
There is a recent publication in Nature, "Faster sorting algorithms discovered using deep reinforcement learning", that describes how AlphaDev discovered a faster sorting algorithm. This caught my interest, and I've been trying to understand the discovery.

Other articles about the topic are:

Here is the pseudo-code of the original sort3 algorithm compared with the improved algorithm that AlphaDev discovered.


Original Pseudo-Code

Memory [0] = A
Memory [1] = B
Memory [2] = C

mov Memory[0] P  // P = A
mov Memory[1] Q  // Q = B
mov Memory[2] R  // R = C

mov R S
cmp P R
cmovg P R  // R = max(A, C)
cmovl P S  // S = min(A, C)
mov S P    // P = min(A, C)
cmp S Q
cmovg Q P  // P = min(A, B, C)
cmovg S Q  // Q = max(min(A, C), B)

mov P Memory[0]  // = min(A, B, C)
mov Q Memory[1]  // = max(min(A, C), B)
mov R Memory[2]  // = max(A, C)

AlphaDev Pseudo-Code

Memory [0] = A
Memory [1] = B
Memory [2] = C

mov Memory[0] P  // P = A
mov Memory[1] Q  // Q = B
mov Memory[2] R  // R = C

mov R S
cmp P R
cmovg P R  // R = max(A, C)
cmovl P S  // S = min(A, C)

cmp S Q
cmovg Q P  // P = min(A, B)
cmovg S Q  // Q = max(min(A, C), B)

mov P Memory[0]  // = min(A, B)
mov Q Memory[1]  // = max(min(A, C), B)
mov R Memory[2]  // = max(A, C)

The improvement centers on omitting a single move instruction, mov S P. To help me understand it, I wrote the following assembly code. However, my testing shows that the sorting algorithm does not work when A=3, B=2, and C=1, but it does work when A=3, B=1, and C=2.
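The two test cases can also be reproduced without assembly. Below is a rough C model of both listings (sort3_listing and check_question_cases are names made up for this sketch, not anything from the paper). It assumes signed comparisons and reads the notation the way the listing's comments imply: the flags from "cmp X Y" describe X compared with Y, and "cmovg U V" means "if greater, copy U into V".

```c
#include <assert.h>

/* Values left in P, Q, R, i.e. what the listing stores to Memory[0..2]. */
typedef struct { int p, q, r; } Pqr;

/* C model of the pseudo-code listings (an illustration, not the paper's code).
 * keep_mov_s_p = 1 models the original listing; 0 models the AlphaDev
 * listing, which omits mov S P. */
static Pqr sort3_listing(int a, int b, int c, int keep_mov_s_p) {
    int P = a, Q = b, R = c;      /* mov Memory[i] -> P, Q, R */
    int S = R;                    /* mov R S */
    int gt = P > R, lt = P < R;   /* cmp P R (signed compare) */
    if (gt) R = P;                /* cmovg P R: R = max(A, C) */
    if (lt) S = P;                /* cmovl P S: S = min(A, C) */
    if (keep_mov_s_p) P = S;      /* mov S P: P = min(A, C) */
    gt = S > Q;                   /* cmp S Q */
    if (gt) P = Q;                /* cmovg Q P */
    if (gt) Q = S;                /* cmovg S Q */
    return (Pqr){P, Q, R};
}

/* The observations from the question, written as checks. */
static void check_question_cases(void) {
    Pqr x = sort3_listing(3, 2, 1, 0);  /* AlphaDev, A=3 B=2 C=1: P never updated */
    assert(x.p == 3 && x.q == 2 && x.r == 3);
    Pqr y = sort3_listing(3, 1, 2, 0);  /* AlphaDev, A=3 B=1 C=2: sorted */
    assert(y.p == 1 && y.q == 2 && y.r == 3);
    Pqr z = sort3_listing(3, 2, 1, 1);  /* original listing, A=3 B=2 C=1: sorted */
    assert(z.p == 1 && z.q == 2 && z.r == 3);
}
```

For A=3, B=2, C=1 the AlphaDev model leaves P at 3, giving P, Q, R = 3, 2, 3, which is what the NASM program should print; the same input goes through the original listing as 1, 2, 3.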

This is written, compiled, and run on Ubuntu 20.04 Desktop.

$ lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.6 LTS
Release:    20.04
Codename:   focal
$ nasm -v
NASM version 2.14.02
$ ld -v
GNU ld (GNU Binutils for Ubuntu) 2.34

My assembly code test...

; -----------------------------------------------------------------
;
; sort_test.asm
;
; Test for AlphaDev sorting algorithm
;
; My findings are that AlphaDev's removal of 'mov S P' doesn't work when:
;   a = 3, b = 2, c = 1
; But it does work with:
;   a = 3, b = 1, c = 2
;
; Output: The sorted values of a, b & c printed to stdout with spaces
;
; Compile & run with:
;
; nasm -f elf32 sort_test.asm && ld -m elf_i386 sort_test.o -o sort_test && ./sort_test
;
; -----------------------------------------------------------------

global _start

section .data
  a equ 3
  b equ 2
  c equ 1

section .bss
  buffer resb 5

section .text
_start:
; ------------------- ; AlphaDev pseudo-code

  mov eax, a          ; P = A
  mov ecx, b          ; Q = B
  mov edx, c          ; R = C
  mov ebx, edx        ; mov R S

  cmp eax, edx        ; cmp P R
  cmovg edx, eax      ; cmovg P R  // R = max(A, C)
  cmovl ebx, eax      ; cmovl P S  // S = min(A, C)

; The following line was in original sorting algorithm,
; but AlphaDev determined it wasn't necessary
;  mov eax, ebx       ; mov S P   // P = min(A, C)

  cmp ebx, ecx        ; cmp S Q
  cmovg eax, ecx      ; cmovg Q P  // P = min(A, B)
  cmovg ecx, ebx      ; cmovg S Q  // Q = max(min(A, C), B)

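; add values to buffer with spaces
; (store only the low byte of each register; a 32-bit store
;  would write past the end of the 5-byte buffer)
  add eax, 30h
  mov [buffer], al
  mov [buffer+1], byte 0x20

  add ecx, 30h
  mov [buffer+2], cl
  mov [buffer+3], byte 0x20

  add edx, 30h
  mov [buffer+4], dl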

; write buffer to stdout
  mov eax, 4      ; sys_write system call
  mov ebx, 1      ; stdout file descriptor
  mov ecx, buffer ; buffer to write
  mov edx, 5      ; number of bytes to write
  int 0x80

  mov eax, 1      ; sys_exit system call
  mov ebx, 0      ; exit status 0
  int 0x80

I've run this test on the command line to print the results of the sort, but I also used gdb to step through this executable line-by-line. During this debugging, I clearly see that the register for "A", aka "P", aka "eax", is never updated when A=3, B=2, and C=1 but is updated when A=3, B=1, and C=2.

Full disclosure... I'm not an assembly programmer. I'm also not proficient in any other specific language, but I've experimented with C, C++, Javascript, PHP, & HTML to get small projects done. Basically, I'm self-taught in what I do know. To get to the point of writing this test, I've had to learn quite a bit. Therefore, I could certainly be making mistakes or not understanding the problem.

Anyway, please help me understand why I'm observing what I am.

  • Am I misunderstanding the problem?
  • Am I misunderstanding the pseudo-code?
  • Am I making a mistake transforming the pseudo-code into assembly?
  • Is there a mistake with my assembly code?
  • Is the pseudo-code wrong?
halfer
mpboden

3 Answers


TL;DR: they're confusingly only showing the last 2 of 3 comparators in a 3-element sorting network, not a complete 3-element sort. This is presented very misleadingly, including in the diagram in their paper.


I'd have used AT&T syntax (like cmovg %ecx, %eax in a .s file assembled with GCC) so the operand order can match the pseudocode, destination on the right.

You're correct, I had a look at the article and the 3-element pseudocode doesn't sort correctly when C is the smallest element. I know x86-64 asm backwards and forwards, and I don't just mean Intel vs. AT&T syntax :P Even looking at the real code, not just the comments, there's no way for the smallest element to end up in memory[0] = P if it started in R = memory[2] = C.

I opened the article before really reading what your question was asking, and noticed that problem myself after skimming the article until getting to the part about the actual improvement, so I haven't looked at your attempt to replicate it. But I didn't have any bias towards seeing a problem in it, I just wanted to understand it myself. There aren't any instructions writing P that read from values that could contain the starting R value, so there's no way it can get that value.


The article indirectly links their paper published in Nature (Faster sorting algorithms discovered using deep reinforcement learning, by Daniel J. Mankowitz et al.). The full text is available at the Nature link.

They use the same image of code in the actual paper, but with some explanatory text and diagram in terms of a 3-element sorting network.


Figure 3a presents an optimal sorting network for three elements (see Methods for an overview of sorting networks). We will explain how AlphaDev has improved the circled network segment. There are many variants of this structure that are found in sorting networks of various sizes, and the same argument applies in each case.

The circled part of the network (last two comparators) can be seen as a sequence of instructions that takes an input sequence ⟨A, B, C⟩ and transforms each input as shown in Table 2a (left). However, a comparator on wires B and C precedes this operator and therefore input sequences where B ≤ C are guaranteed. This means that it is enough to compute min(A, B) as the first output instead of min(A, B, C) as shown in Table 2a (right). The pseudocode difference between Fig. 3b,c demonstrates how the AlphaDev swap move saves one instruction each time it is applied.

So this pseudocode is just for the circled part of the sorting network, the last 2 of 3 compare-and-swap steps. In their blog article, and even in other parts of the paper like Table 2, they make it sound like this is the whole sort, not just the last 2 steps. The pseudocode even confusingly starts with values in memory, which wouldn't be the case after conditionally swapping B and C to ensure B <= C.
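To make the missing first step concrete, here is a small C sketch that runs the B,C compare-and-swap first and then the AlphaDev listing exactly as printed. The function name sort3_network is made up for this illustration, and the comparator order (B,C), then (A,C), then (A,B) is an assumption read off the quoted text and the listing's comments. It is only exercised here on distinct values; equal keys are a separate question.

```c
#include <assert.h>

/* Hypothetical sketch: full 3-element network = B,C comparator, then the
 * AlphaDev listing as printed. m = {A, B, C} on entry. */
static void sort3_network(int m[3]) {
    int A = m[0], B = m[1], C = m[2];

    /* Step 1 (not part of the listing): compare-and-swap B and C. */
    if (B > C) { int t = B; B = C; C = t; }
    assert(B <= C);               /* the precondition the listing relies on */

    /* Steps 2-3: the AlphaDev listing, modeled in C (no mov S P). */
    int P = A, Q = B, R = C;
    int S = R;                    /* mov R S */
    int gt = P > R, lt = P < R;   /* cmp P R */
    if (gt) R = P;                /* cmovg P R: R = max(A, C) */
    if (lt) S = P;                /* cmovl P S: S = min(A, C) */
    gt = S > Q;                   /* cmp S Q */
    if (gt) P = Q;                /* cmovg Q P: P = min(A, B) for distinct values */
    if (gt) Q = S;                /* cmovg S Q: Q = max(min(A, C), B) */

    m[0] = P; m[1] = Q; m[2] = R;
}
```

All six orderings of {1, 2, 3} come out as 1, 2, 3. Dropping the first compare-and-swap brings back the question's failure on A=3, B=2, C=1.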


Also, it's unlikely that removing just one mov instruction gives a huge speedup in a 3-element sort. As covered in "Can x86's MOV really be "free"? Why can't I reproduce this at all?", a mov is never free (it costs front-end bandwidth), but it has zero latency on most recent microarchitectures other than Ice Lake. I'm guessing this wasn't the case where they got a 70% speedup!


With AVX SIMD instructions like vpminsd dst, src1, src2 (https://www.felixcloutier.com/x86/pminsd:pminsq) / vpmaxsd, which compute the min and max of signed dword (32-bit) elements with a non-destructive separate destination, there's no saving except critical-path latency. min(B, prev_result) is still just one instruction, with no separate register copy (vmovdqa xmm0, xmm1) needed, as there could be with just SSE4.1 if you were doing a sorting network. But latency could perhaps be significant when building a sorting network out of shuffles and SIMD min/max comparators, which last I heard was the state of the art for sorting large integer or FP arrays on x86-64, not just saving a mov in scalar cmov code!

But lots of programs are compiled not to assume AVX is available, because unfortunately it's not universally supported: it's missing on some low-power x86 CPUs from as recently as the past couple of years, and on Pentium / Celeron CPUs before Ice Lake (so maybe as recent as 2018 or so for low-budget desktop CPUs).

Their paper in Nature mentions SIMD sorting networks, but points out that libc++ std::sort doesn't take advantage of it, even for the case where the input is an array of float or int, rather than classes with an overloaded operator <.


This 3-element tweak is a micro-optimization, not a "new sorting algorithm". It might still save latency on AArch64, but only instructions on x86

It's nice that AI can find these micro-optimizations, but they wouldn't be nearly as surprising if presented as having a choice between selecting from min(A,C) or min(B,C) because the latter is what B actually is at that point.
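Written out, the identity behind that choice is short. Here B' (a label introduced just for this note) is the value in B's register after the first comparator, so B' = min(B, C):

```latex
\min(A, B, C) = \min\bigl(A, \min(B, C)\bigr) = \min(A, B')
```

So reading B's register gives the same minimum as reading min(A, C) and then comparing against B, without needing a copy of min(A, C) in P first.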

Avoiding register-copy instructions with careful choice of higher-level source is something humans can do, e.g. the choice of _mm_movehl_ps merge destination (first source operand) in my 2016 answer on Fastest way to do horizontal SSE vector sum (or other reduction) - see the comment on the compiler-generated asm # note the reuse of shuf, avoiding a movaps.

Previous work in automated micro-optimization includes STOKE, a stochastic superoptimizer that randomly tries instruction sequences hoping to find cheap sequences that match the outputs of a test function you give it. The search space is so large that it tends to miss possible sequences when they take more than 3 or 4 instructions (STOKE's own page says it's not production-ready, just a research prototype). So AI is helpful. And it's a lot of work to look at asm by hand for possible missed optimizations that could be fixed by tweaking the source.

But at least for this 3-element subproblem, it is just a micro-optimization, not truly algorithmically new. It's still just a 3-comparator sorting network. One that compiles more cheaply for x86-64, which is nice. But on some 3-operand ISAs with a separate destination for their equivalent of cmov, like AArch64's csel dst, src1, src2, flag_condition conditional-select, there's no mov to save. It could still save latency on the critical path, though.

Their paper in Nature also shows an algorithmic difference for sorting a variable number of elements, where the >= 3 cases both start by sorting the first 3. Maybe this helps branch prediction since that work can be in flight while a final branch on len > 3 is resolving to see whether they need to do a simplified 4-element sort that can assume the first 3 elements are sorted. They say "It is this part of the routine that results in significant latency savings." (They also call this a "fundamentally new" algorithm, which I presume is true for the problem of using sorting networks on short unknown-length inputs.)
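The control-flow shape described there (lengths of 3 or more share a 3-element sort before branching on whether a fourth element exists) can be sketched in C. This is a hypothetical illustration, not the paper's variable-length routine: the names var_sort4 and cswap are made up, compare-and-swap is written with plain ternaries instead of cmov, and the 4-element tail is simply an insertion of the last element into the already-sorted first three.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: compare-and-swap so that *x <= *y afterwards. */
static void cswap(int *x, int *y) {
    int lo = *x < *y ? *x : *y;
    int hi = *x < *y ? *y : *x;
    *x = lo;
    *y = hi;
}

/* Hypothetical sketch for n in 0..4: sort the first 3 whenever n >= 3,
 * then finish the 4-element case assuming those 3 are already sorted. */
static void var_sort4(int *a, size_t n) {
    assert(n <= 4);
    if (n < 2) return;
    if (n == 2) { cswap(&a[0], &a[1]); return; }

    /* n >= 3: the same 3-element network runs for both n == 3 and n == 4. */
    cswap(&a[1], &a[2]);
    cswap(&a[0], &a[2]);
    cswap(&a[0], &a[1]);

    if (n == 4) {
        /* a[0] <= a[1] <= a[2] here: insert a[3] with three comparators. */
        cswap(&a[2], &a[3]);
        cswap(&a[1], &a[2]);
        cswap(&a[0], &a[1]);
    }
}
```

For example, {4, 3, 2, 1} with n = 4 ends up as {1, 2, 3, 4}. In real code the interesting part is that the 3-element network can start executing before the n == 4 branch resolves, which C source alone does not express.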

Peter Cordes
  • Interesting. So I'm not off-base here. Thank you. But I also notice that a patch has been integrated into the [LLVM standard C++ sort library](https://reviews.llvm.org/D118029). So it appears that there is definitely an improvement with AlphaDev's discovery. I attempted to look at the patches, but the C++ quickly went over my head. Anyway, the third link I provided attempts to explain the improvement. They literally break down the pseudo-code, but the logic falls apart. – mpboden Jun 22 '23 at 04:18
  • @mpboden: After some more reading, it turns out their diagram is just misleading, and is showing the last 2 of 3 comparators of a sorting network. There is an earlier step that ensures `B <= C`. It's bad enough in their paper, and much worse in the article you linked, which leaves out that context and appears to show that 2-step sorting network as a stand-alone 3-element sort. See the update to my answer. – Peter Cordes Jun 22 '23 at 04:19
  • @mpboden: If they hadn't left that out, it wouldn't "look like a mistake", and would just be a pretty normal optimization that an asm expert could have made by looking at the compiler output and thinking about how to save some of the `mov` register-copy instructions. I've done optimizations like that in answers like [SSE horizontal sums](https://stackoverflow.com/q/6996764) in 2016 (note the `// Without AVX, we can save a MOVAPS with MOVHLPS into a dead register` part in the `__m128d` version, and the similar careful choice of merge destination for `_mm_movehl_ps` in `hsum_ps_sse1`). – Peter Cordes Jun 22 '23 at 04:24
  • Ok, I now see where I misunderstood the problem. I first found out about this topic from a YouTube video, then from the article that my 2nd link points to. Those two pieces sent me down the research path. Next, I found the other two articles, but overlooked that one detail of the "comparator on wires B and C precedes this operator and therefore input sequences where B ≤ C are guaranteed". This is a key detail for sure!! Thanks again. You've given me lots more to read and learn about. – mpboden Jun 22 '23 at 04:34
  • I had such high hopes when I read the title above this Q/A. What a bummer... – Sep Roland Jun 24 '23 at 21:19
  • @SepRoland: Same, I was expecting something like an alternative to QuickSort / MergeSort / etc., not just micro-optimizations for small (sub)-problems. – Peter Cordes Jun 24 '23 at 21:32
  • You might be interested in this related post on LinkedIn and others mentioned therein: https://www.linkedin.com/posts/cassioneri_compiler-explorer-c-activity-7081255939965612032-coPd – Cassio Neri Jul 02 '23 at 16:24

The pseudocode for the AlphaDev Sort3 algorithm in DeepMind's Nature article has a bug that affects sorting three numbers when the first is larger than the other two and those two are equal, that is, A>B=C (which satisfies the precondition B<=C). For example, if the input is (2,1,1), the output is (2,1,2): not sorted, and actually corrupted.

The bug is in line 12: cmovg Q P, which should be cmovge Q P, to ensure P is overwritten in this case. Reintroducing the original mov S P instruction would also fix the problem, but would defeat the purpose of the change. A similar bug and fix apply to the pseudocode for the Sort8 snippet in figure 3 d,e,f of the article (cmovl R T should be cmovle R T).
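The (2,1,1) case and the cmovge fix can be checked with a small, hypothetical C model of the listing (sort3_tail and use_cmovge are made-up names; this is not DeepMind's code). It keeps the precondition B <= C and changes only the condition used by line 12's cmovg Q P:

```c
#include <assert.h>

/* C model of the two-comparator listing. m = {A, B, C} with B <= C, which
 * the earlier B,C comparator guarantees.
 * use_cmovge = 0 models line 12 as printed (cmovg Q P);
 * use_cmovge = 1 models the proposed fix (cmovge Q P). */
static void sort3_tail(int m[3], int use_cmovge) {
    int P = m[0], Q = m[1], R = m[2];
    assert(Q <= R);                      /* precondition B <= C */
    int S = R;                           /* mov R S */
    int gt = P > R, lt = P < R;          /* cmp P R */
    if (gt) R = P;                       /* cmovg P R: R = max(A, C) */
    if (lt) S = P;                       /* cmovl P S: S = min(A, C) */
    gt = S > Q;                          /* cmp S Q */
    int ge = S >= Q;
    if (use_cmovge ? ge : gt) P = Q;     /* line 12: cmovg or cmovge Q P */
    if (gt) Q = S;                       /* cmovg S Q, unchanged */
    m[0] = P; m[1] = Q; m[2] = R;
}
```

With use_cmovge = 0 the input {2, 1, 1} ends up as {2, 1, 2}; with use_cmovge = 1 it becomes {1, 1, 2}, and every input with B <= C drawn from {0, 1, 2} comes out sorted.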

The Sort3 code on AlphaDev's GitHub and in Supplement G to the Nature article does not have this problem, and presumably neither does the code in the LLVM libc++ library.


Step by step:

# before first shown step  (after the B,C comparator)
P = 2
Q = 1
R = 1

# after first step
  # P = 2 still: unmodified by this step, still just A
  # Q = 1 still: still just B
R = 2   # max(A,C) = max(A, max(B, C)) since B<=C
S = 1   # min(A,C) = min(A, max(B, C))

 cmp S, Q(B)  is  cmp 1, 1  # the Greater condition is false
# after final step
P = 2 unchanged by cmovg Q(1), P(2)   !!!! Bug here
Q = 1 unchanged by cmovg S(1), Q(1)
# R = 2 still

cmovge Q, P would have copied the 1 instead of keeping the 2. For an ascending sort, we wanted PQR = 1,1,2.

Peter Cordes
JWel

I added this blog post to the AlphaDev Wikipedia article:

Understanding DeepMind's Sorting Algorithm

It translates the pseudo-assembly into actual assembly:


/   move37.S
        .equ    P,%rax
        .equ    Q,%rcx
        .equ    R,%rdx
        .equ    S,%rsi
move37: mov (%rdi),P
        mov 8(%rdi),Q
        mov 16(%rdi),R
        mov R,S
        cmp P,R
        cmovg   P,R
        cmovl   P,S
        cmp S,Q
        cmovg   Q,P
        cmovg   S,Q
        mov R,(%rdi)
        mov Q,8(%rdi)
        mov P,16(%rdi)
        ret
        .type   move37,@function
        .size   move37,.-move37
        .globl  move37
Bastiaan Quast