5
struct Obj {
  int x;
  int y;
  int z;
};

int Compare(Obj* a, Obj* b) {
  if (a->x > b->x) return 1;
  else if (a->x < b->x) return -1;

  if (a->y > b->y) return 1;
  else if (a->y < b->y) return -1;

  if (a->z > b->z) return 1;
  else if (a->z < b->z) return -1;

  return 0;
}

As shown in the code above, there are at most three conditional branches to compute the comparison result, and the compare function is called by a sort function. How can I optimize the code to eliminate the conditional branches and thereby improve the performance of the compare function?

--- update --- Since the caller is an improved version of quicksort that needs to distinguish greater, less, and equal, the compare function must return three distinct results: -1, 1, and 0.

wlhwai
  • You might combine all conditions that lead to specific return value into a single `if` statement. Use logical OR operator `||`. – vahancho Mar 24 '20 at 15:54
  • 4
    This looks like the same comparison you would get simply comparing `std::tie(a->x, a->y, a->z)` with `std::tie(b->x, b->y, b->z)`. I'm not sure if it would be faster or slower than what you use currently. – François Andrieux Mar 24 '20 at 15:55
  • 1
    Are you sure this is C++? C++-ish `std::sort` takes compare functions that return `bool` and take `const type&` parameters. – KamilCuk Mar 24 '20 at 16:00
  • @KamilCuk But a `bool` return value can't handle three possibilities: `>`, `<` or `==`. – Adrian Mole Mar 24 '20 at 16:10
  • 1
    Aside: you don't actually need any of the `else` statements. – 500 - Internal Server Error Mar 24 '20 at 16:13
  • You could "optimize" it by eliminating the branches: e.g. `(a->x > b->x) - (a->x < b->x)` results in -1, 0, or 1 with only arithmetic and comparisons left. – Scheff's Cat Mar 24 '20 at 16:15 (a sketch of this idea follows the comment thread)
  • 3
    First you need to build with optimization enabled, and *measure* that this is a top-two or top-three bottleneck. Otherwise my suggestion is to not compromise readability (and therefore maintainability) by attempting to optimize the code. – Some programmer dude Mar 24 '20 at 16:32
  • 1
    After some fiddling, I got the impression that my arithmetic solution is slightly faster than branching: [**demo on coliru**](http://coliru.stacked-crooked.com/a/cbee5f4a8d53f91e) ;-) – Scheff's Cat Mar 24 '20 at 17:00
  • 1
    @Scheff: if it's compiling to asm that looks anything like the source (even though it can see the caller just summing the results), arithmetic might do even better if it's not rare for the x components to be equal, and have to actually fall through to more work. (Especially if the y comparison isn't correlated with the xs being equal). I wonder if it's possible to encode something out of `ax - bx` and `ay - by` etc instead of making the compiler booleanize and shift...) – Peter Cordes Mar 25 '20 at 01:27
  • 1
    @Scheff: https://godbolt.org/z/5uejGP shows that gcc9.3 chooses to inline the arithmetic version, but still `call` the branchy version. So that skews the results some. But it is actually branching in the branchy version and doing `setl` / `setg` / `cmov` in the arithmetic version; even `-O3 -march=skylake` doesn't lead to if-conversion into cmov for the stand-alone version of the branchy function. But those options make GCC auto-vectorize your loop with AVX2 compares / shuffles / blends! That makes arith time at 10ms vs. 40ms on Godbolt's Skylake-Xeon servers. 6ms with AVX512 -march=native – Peter Cordes Mar 25 '20 at 01:29
  • 1
    @wlhwai - What compiler(s) for what architecture(s) do you care about the most? C++ isn't assembly language, and a compiler *could* compile your current function to branchless asm if it wanted to. If you're targeting x86, I suspect you could manually vectorize this if you know it's safe to over-read by 4 bytes (loading a 16-byte XMM vector from the 12 byte struct), using [SSE2 `pcmpgtd`](https://www.felixcloutier.com/x86/pcmpgtb:pcmpgtw:pcmpgtd) both ways, or that and `eq`, and using some kind of bit-scan on a bitmask to find the first difference. – Peter Cordes Mar 25 '20 at 01:52
  • 1
    @Scheff: I noticed a typo in your Arith version: you used `y` twice instead of `z` for the last step. The assert detected this when I changed the `rand()` to `rand() & 0x7` so we'd have a good chance of actually getting some elements with equal x and y components. This slows down Arith some because it can no longer CSE some of the extra work. https://godbolt.org/z/jwPEpp benchmarks it against a branchless SSE2 version with good throughput but poor latency, doing a bitscan for the first non-equal element and doing `int64_t` subtraction to get a positive, 0, or negative result without overflow. – Peter Cordes Mar 25 '20 at 02:49
  • 1
    @Scheff: posted that as an answer. – Peter Cordes Mar 25 '20 at 03:31
  • @PeterCordes Too bad that I didn't see the typo with `y` instead of `z`. I trusted the `assert()` too much - couldn't imagine a "false negative" for 10000 checks. :-( This proves again: Not finding an error doesn't mean it's not there... :-) – Scheff's Cat Mar 25 '20 at 04:28
  • 1
    @Scheff: A coverage test would have revealed that random testing never happened to reach that last statement, so at least we have the tools to catch that typo. I noticed it when looking at the asm: why did it only load `[rdi + 0]` and `[rdi + 4]`, but not `[rdi + 8]`? Checked the source, found the problem. Not the first time I've found out that my source didn't mean what I wanted it to from looking at the optimized asm, but that only works for cases where you're going all out optimizing. :P. Re: disabled assert: It fires with `gcc -O3`; you have to manually use `-DNDEBUG` to disable it. – Peter Cordes Mar 25 '20 at 04:37
  • 1
    @Scheff: GCC doesn't have "release" vs. "debug", just optimization levels. `-O0` does have special anti-optimizing behaviour to sync variables to memory for consistent debugging (separate from any `-fblahblah` optimization option), but other than that anything you want to do for extra debugging or less debugging has to be done separately. Unlike MSVC where debug mode poisons uninitialized stack memory with a recognizable value, and stuff like that. – Peter Cordes Mar 25 '20 at 04:41
  • @Scheff thanks for the code demo and performance contrast. Learned a lot. – wlhwai Mar 25 '20 at 11:04
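
A minimal sketch of Scheff's branchless-arithmetic idea from the comments (the "compareArithm" variant benchmarked in the answer below). The function name and the weighted-sum combination are illustrative assumptions, not the exact code from the linked coliru demo:

struct Obj { int x, y, z; };  // same layout as in the question

// Each term is -1, 0, or +1, computed with comparisons and subtraction only (no branches).
int CompareArith(const Obj* a, const Obj* b) {
  int dx = (a->x > b->x) - (a->x < b->x);
  int dy = (a->y > b->y) - (a->y < b->y);
  int dz = (a->z > b->z) - (a->z < b->z);
  int r  = 4*dx + 2*dy + dz;   // sign of r matches lexicographic (x, y, z) order
  return (r > 0) - (r < 0);    // normalize to the -1 / 0 / +1 the caller wants
}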

2 Answers

5

C++ isn't assembly language, and a compiler could compile your current function to branchless asm if it wanted to. (Dereferencing a struct pointer to load one member implies that the entire struct object is present and can thus be speculatively read without risk of faulting even if the C++ abstract machine wouldn't have touched y or z members.) What compiler(s) for what architecture(s) do you care about the most?

Have you tried compiling with profile-guided optimization so the compiler can see that branches are unpredictable? This might lead it to do if-conversion of the if() into branchless cmov or whatever, depending on the target ISA. (Generate your random data with rand() & 0x7 or something so it's not rare for objects to have equal x and equal y, and actually reach the z case.)


It's possible to use SIMD to find the first non-matching element, then return the diff of that element. For example, x86 SIMD has a movemask operation that can turn a vector compare result into an integer bitmask, which we can use with a bitscan instruction to find the first or last set bit.

(This depends on being able to safely read 16 bytes from your 12-byte struct, assuming x86. This is the case as long as your array doesn't end with the last element right at the end of a page, with the next page unmapped. Is it safe to read past the end of a buffer within the same page on x86 and x64? generally yes, and is widely used for efficient implementations of strlen and similar functions.)

(ARM NEON doesn't have a convenient movemask so for ARM / AArch64 you might be better off shuffling data around within a SIMD vector to come up with a result, if SIMD is a win at all. It might not be with ARM's predicated compare instructions, or with AArch64's more limited branchless conditional instructions that are still better than x86 CMOV.)

SIMD can give us good throughput but probably poor latency compared to @Scheff's branchless arithmetic version in comments, especially on a wide pipeline like modern x86 that can do lots of independent work in parallel (like turning separate compare results into boolean integers). High latency might not be ideal in a QSort where you expect branch mispredicts to not be rare; overlapping independent compares with out-of-order execution only works when branches are predicted correctly.

To get a + / 0 / - result from two int values, you can cast one to int64_t and subtract. That avoids the possibility of signed overflow, and is efficient on 64-bit ISAs. (Or, if it can inline, it can ideally compile to just a 32-bit signed compare instead of an actual subtraction. A 32-bit subtraction could have signed overflow, which is UB, and would lose the result on wrapping.) If you don't need to normalize to +1 / 0 / -1, do that, i.e. return the raw 64-bit difference.
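
A minimal scalar sketch of that (hypothetical helper names, not from the question or the benchmark):

#include <stdint.h>

// Widening to int64_t first means the difference of two 32-bit ints always fits,
// so there is no signed-overflow UB and the sign is always correct.
int64_t cmp_int(int a, int b) {
    return a - (int64_t)b;               // +, 0, or - result
}

int cmp_int_normalized(int a, int b) {   // only if the caller needs exactly -1 / 0 / +1
    int64_t d = a - (int64_t)b;
    return (d > 0) - (d < 0);
}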

I used an anonymous struct inside a union with an array to extend @Scheff's handy benchmark framework (with bugfix) without changing everything from a->x to a->vals.x.

#include <stdint.h>
#include <immintrin.h>

union Obj {
  struct { // extension: anonymous struct
    int x;
    int y;
    int z;
  };
  int elems[3];
};



// a better check would be on value ranges; sizeof can include padding
static_assert( sizeof(int64_t) > sizeof(int), "we need int smaller than int64_t");

int64_t compare_x86(const Obj *a, const Obj *b)
{
    __m128i va = _mm_loadu_si128((const __m128i*)a);  // assume over-read is safe, last array object isn't at the end of a page.
    __m128i vb = _mm_loadu_si128((const __m128i*)b);
    __m128i veq = _mm_cmpeq_epi32(va,vb);

    unsigned eqmsk = _mm_movemask_ps(_mm_castsi128_ps(veq));
    unsigned neqmsk = ~eqmsk;   // invert: set bits now mark lanes where the elements *differ*
    neqmsk |= 1<<2;   // set elems[2]'s bit so we'll return that (non)diff if they're all equal
    unsigned firstdiff = __builtin_ctz(neqmsk);   // GNU C extension: count trailing zeros

    // sign-extend to 64-bit first so overflow is impossible, giving a +, 0, or - result
    return a->elems[firstdiff] - (int64_t)b->elems[firstdiff];
}

On Godbolt with GCC9.3 -O3 -march=skylake -fno-tree-vectorize for x86-64, it compiles to this asm for the non-inline case:

compare_x86(Obj const*rdi, Obj const*rsi):
        vmovdqu xmm1, XMMWORD PTR [rsi]
        vpcmpeqd        xmm0, xmm1, XMMWORD PTR [rdi]
        vmovmskps       edx, xmm0               # edx = bitmask of the vector compare result
        not     edx                             # invert: set bits now mark differing lanes
        or      edx, 4
        tzcnt   edx, edx                        # rdx = index of lowest set bit
        mov     edx, edx                        # stupid compiler, already zero-extended to 64-bit
        movsx   rax, DWORD PTR [rdi+rdx*4]      # 32->64 sign extending load
        movsx   rdx, DWORD PTR [rsi+rdx*4]
        sub     rax, rdx                        # return value in RAX
        ret

The latency critical path goes through the SIMD loads + compare, through movemask back to integer, the not/or (1 cycle each), tzcnt/bsf (3 cycles on Intel), then another L1d load-use latency for the movsx loads (5 cycles). (Numbers from https://agner.org/optimize/ and https://uops.info/; see also https://stackoverflow.com/tags/x86/info.) The scalar load addresses aren't known until after tzcnt, so there's very little ILP here. Modern x86 can do 2 loads per clock, so we are at least taking advantage of that. It can overlap nicely across independent compares, though, and the total uop count is low, so the bottleneck on front-end bandwidth isn't too bad.

The unaligned SIMD loads have no penalty on Intel CPUs unless they cross a cache-line boundary. Then latency is an extra 10 cycles or so. Or worse if they cross a 4k boundary, especially on Intel before Skylake made page splits a lot cheaper. For random 4-byte-aligned object addresses, there are 3 out of 16 start positions that lead to a cache-line split load (for 64B cache lines). This further increases the average latency from the input addresses being ready to the compare result being ready, and can't overlap with any work.

Without -march=skylake GCC uses a separate movdqu unaligned load, and rep bsf which is the same instruction as tzcnt. CPUs without BMI1 will decode it as plain bsf. (They differ only when the input is zero; we make sure that doesn't happen. bsf is slow on AMD, same speed as tzcnt on Intel.)

Using @Scheff's benchmark (which counts the results) on Godbolt, this is somewhat faster than the plain scalar "arithmetic" version when you disable auto-vectorization. (GCC can auto-vec the arithmetic version.) Timing results are inconsistent between runs because the test-case is too small and the AWS servers that compiler explorer runs on might have different CPU frequencies, although they're all Skylake-avx512. But within one run, alternating between this and arith, a result like this is typical:

compare_x86() 5. try: 28 mus (<: 3843, >: 3775)
compareArithm() 5. try: 59 mus (<: 4992, >: 5007)
compare_x86() 6. try: 39 mus (<: 3843, >: 3775)
compareArithm() 6. try: 64 mus (<: 4992, >: 5007)
compare_x86() 7. try: 27 mus (<: 3843, >: 3775)
compareArithm() 7. try: 64 mus (<: 4992, >: 5007)

But remember, this is just adding up the <0 and >0 return values, and thus is throughput bound, not latency. A new compare can start without any data dependency or control dependency on the previous compare result.


Hmm, I could have used pmovmskb to get the high bit of every byte, instead of every dword with the ps version, but C makes it inconvenient to use a byte offset into an int array instead of an element offset. In asm you'd tzcnt or BSF and then movsx rax, [rdi + rdx]. This might save a cycle of latency in bypass delay between SIMD-integer pcmpeqd and SIMD-FP movmskps. But to get that from a compiler you'd maybe have to cast to char* for the pointer addition then back to int*.
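
For what it's worth, here's what that might look like (an untested sketch; compare_x86_bytes is a made-up name, reusing the union Obj and the same over-read assumption as compare_x86 above):

int64_t compare_x86_bytes(const Obj *a, const Obj *b)
{
    __m128i va = _mm_loadu_si128((const __m128i*)a);
    __m128i vb = _mm_loadu_si128((const __m128i*)b);
    __m128i veq = _mm_cmpeq_epi32(va, vb);        // each dword lane is all-ones or all-zero

    // One mask bit per *byte* this time; invert so set bits mark differing lanes.
    unsigned neqmsk = ~(unsigned)_mm_movemask_epi8(veq);
    neqmsk |= 1u << 8;                            // byte offset 8 = elems[2]: its (non)diff is 0 if all equal
    unsigned byteoff = __builtin_ctz(neqmsk);     // a multiple of 4, since bytes within a lane agree

    // char* arithmetic uses the byte offset directly, avoiding the *4 scaling in the address.
    const int *pa = (const int *)((const char *)a->elems + byteoff);
    const int *pb = (const int *)((const char *)b->elems + byteoff);
    return *pa - (int64_t)*pb;
}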


I thought at first of using _mm_cmpgt_epi32(va,vb) to get a vector of 0 / -1 compare results for signed greater-than, but then I realized that indexing the original structs would be just as easy as mapping the right element or bit of that into a -1 / +1 integer.

If you wanted to special case the all-equal case, you might set bit #3 instead (|= 1<<3), then branch on that rare case but still do the rest branchlessly.

    neqmsk |= 1<<3;   // set the 4th bit so there's always a set bit to find
    unsigned firstdiff = __builtin_ctz(neqmsk);

    if (firstdiff >= 3)   // handle this rare(?) case with a branch
        return 0;

    ... something with  (a < b) * 2 - 1
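
Filled in, that might look something like this (a sketch under the same definitions as compare_x86 above; compare_x86_eqbranch is a made-up name):

int compare_x86_eqbranch(const Obj *a, const Obj *b)
{
    __m128i va = _mm_loadu_si128((const __m128i*)a);
    __m128i vb = _mm_loadu_si128((const __m128i*)b);
    __m128i veq = _mm_cmpeq_epi32(va, vb);

    unsigned neqmsk = ~(unsigned)_mm_movemask_ps(_mm_castsi128_ps(veq));
    neqmsk |= 1u << 3;                    // bit 3 is past the real elements: it marks "all equal"
    unsigned firstdiff = __builtin_ctz(neqmsk);

    if (firstdiff >= 3)                   // handle the rare(?) all-equal case with a branch
        return 0;

    // The elements are known to differ here, so a single > decides the sign branchlessly.
    return (a->elems[firstdiff] > b->elems[firstdiff]) * 2 - 1;   // +1 or -1
}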

Mixed branchy strategy:

If it's rare that the xs are equal, perhaps consider

   if (a->x != b->x)
       return a->x - (int_fast64_t)b->x;
   else {
       // 8-byte branchless SIMD?
       // or maybe just 2-element branchless scalar
   }

IDK if it's worth doing SIMD at all for only 2 more elements. Probably not.
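
For example, the scalar fallback could look like this (a hypothetical sketch reusing the Obj above; compare_mixed is a made-up name):

int64_t compare_mixed(const Obj *a, const Obj *b)
{
    if (a->x != b->x)                    // hopefully the common, well-predicted case
        return a->x - (int64_t)b->x;

    // x components equal: finish with branchless scalar arithmetic on y and z.
    int dy = (a->y > b->y) - (a->y < b->y);
    int dz = (a->z > b->z) - (a->z < b->z);
    return 2*dy + dz;                    // nonzero dy outweighs dz; the sign is the answer
}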

Or perhaps consider doing branchless for x and y, and branching on y components being equal to skip scalar z? If your objects are random over most of the range of int, it's going to be rare that you find two that only differ in the last component.

I think the way good sorting algorithms do fewer comparisons by avoiding redundant comparisons probably creates more entropy in the pattern of results, and probably also increases the amount of comparisons done with elements that are "close" to each other in the final sort order. So QSort could be doing more comparisons that do need to check y elements if there are many elements with equal x.

Peter Cordes
0

Here's a generic emulation of three-way comparison:

#include <cstddef>  // std::size_t
#include <tuple>    // std::get, std::tie, std::tuple_size_v
#include <utility>  // std::index_sequence, std::make_index_sequence

namespace util {
    template <typename T>
    int compare(const T& lhs, const T& rhs)
    {
        if (lhs == rhs) {
            return 0;
        } else if (lhs < rhs) {
            return -1;
        } else {
            return 1;
        }
    }

    namespace detail {
        template <typename Tuple>
        int compare_tuples(const Tuple&, const Tuple&, std::index_sequence<>)
        {
            return 0;
        }
        template <typename Tuple, std::size_t I, std::size_t... Is>
        int compare_tuples(const Tuple& lhs, const Tuple& rhs, std::index_sequence<I, Is...>)
        {
            if (auto cmp = compare(std::get<I>(lhs), std::get<I>(rhs))) {
                return cmp;
            } else {
                return compare_tuples(lhs, rhs, std::index_sequence<Is...>{});
            }
        }
    }

    template <typename Tuple>
    int compare_tuples(const Tuple& lhs, const Tuple& rhs)
    {
        return detail::compare_tuples(
            lhs, rhs, std::make_index_sequence<std::tuple_size_v<Tuple>>{}
        );
    }
}

Then you can use it with std::tie, which forms a tuple of references to the members:

struct Object {
    int x, y, z;
};

int compare(const Object& lhs, const Object& rhs)
{
    return util::compare_tuples(
        std::tie(lhs.x, lhs.y, lhs.z),
        std::tie(rhs.x, rhs.y, rhs.z)
    );
}

(live demo)

The compare function is eventually optimized to this by GCC:

compare(Object const&, Object const&):
        mov     eax, DWORD PTR [rsi]
        cmp     DWORD PTR [rdi], eax
        je      .L11
.L2:
        setge   al
        movzx   eax, al
        lea     eax, [rax-1+rax]
        ret
.L11:
        mov     eax, DWORD PTR [rsi+4]
        cmp     DWORD PTR [rdi+4], eax
        jne     .L2
        mov     edx, DWORD PTR [rsi+8]
        xor     eax, eax
        cmp     DWORD PTR [rdi+8], edx
        jne     .L2
        ret

and Clang:

compare(Object const&, Object const&):                 # @compare(Object const&, Object const&)
        mov     ecx, dword ptr [rdi]
        xor     eax, eax
        cmp     ecx, dword ptr [rsi]
        setge   cl
        jne     .LBB0_1
        mov     ecx, dword ptr [rdi + 4]
        xor     eax, eax
        cmp     ecx, dword ptr [rsi + 4]
        setge   cl
        jne     .LBB0_1
        mov     eax, dword ptr [rdi + 8]
        xor     ecx, ecx
        xor     edx, edx
        cmp     eax, dword ptr [rsi + 8]
        setge   dl
        lea     eax, [rdx + rdx - 1]
        cmove   eax, ecx
        ret

Since C++20, this problem can be solved easily with a defaulted spaceship operator (operator<=>):

#include <compare>

struct Obj {
    int x;
    int y;
    int z;

    constexpr auto operator<=>(const Obj&) const = default;
};

int to_int(std::partial_ordering cmp) noexcept
{
    if (cmp == 0) {
        return 0;
    } else if (cmp < 0) {
        return -1;
    } else {
        return 1;
    }
}

int Compare(Obj* a, Obj* b)
{
    return to_int(*a <=> *b);
}
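
A usage sketch (hypothetical caller; the OP's "improved quicksort" isn't shown in the question, so std::sort stands in for the less-than side and a direct call shows the three-way side, assuming the definitions just above):

#include <algorithm>
#include <vector>

int main()
{
    std::vector<Obj> v{{1, 2, 3}, {1, 2, 2}, {0, 9, 9}};

    // A three-way caller (like the OP's quicksort) gets -1 / 0 / +1:
    int r = Compare(&v[0], &v[1]);       // +1 here: x and y tie, z 3 > 2

    // std::sort only needs "less than", which the defaulted operator<=> also provides:
    std::sort(v.begin(), v.end());       // sorted order: {0,9,9}, {1,2,2}, {1,2,3}

    return r == 1 ? 0 : 1;               // exit status 0 if Compare behaved as expected
}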
L. F.