5

Consider two numbers written in binary (MSB at left):

X = x7 x6 x5 x4 x3 x2 x1 x0

and

Y = y7 y6 y5 y4 y3 y2 y1 y0

These numbers can have an arbitrary number of bits but both are of the same type. Now consider that x7 == y7, x6 == y6, x5 == y5, but x4 != y4.

How to compute:

Z = x7 x6 x5 0 0 0 0 0

or in other words, how to efficiently compute a number that keeps the common part to the left of the most significant differing bit?

template <typename T>
inline T f(const T x, const T y) 
{
    // Something here
}

For example, for:

x = 10100101
y = 10110010

it should return

z = 10100000

Note: this is for supercomputing purposes and the operation will be executed hundreds of billions of times, so scanning the bits one by one should be avoided...

Tomas
Vincent

5 Answers

6

My answer is based on @JerryCoffin's one.

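unsigned d = x ^ y;   // unsigned, so that >> is guaranteed to shift in zeros
d = d | (d >> 1);
d = d | (d >> 2);
d = d | (d >> 4);
d = d | (d >> 8);
d = d | (d >> 16);    // enough for 32 bits; a 64-bit type needs one more step (d >> 32)
unsigned z = x & (~d);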
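Since the question asks for a template, here is a minimal sketch of the same shift-and-OR idea written for any unsigned integer type. The name `common_prefix` is made up for illustration and the `static_assert` is my own addition; it is one possible way to generalize the code above, not a tuned implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <type_traits>

// Hypothetical helper (name chosen for this sketch): keeps the bits that x and y
// share above their most significant differing bit and clears everything below.
template <typename T>
inline T common_prefix(const T x, const T y)
{
    static_assert(std::is_unsigned<T>::value, "T must be an unsigned integer type");
    T d = static_cast<T>(x ^ y);               // 1 wherever x and y differ
    // Smear the highest set bit of d into all lower positions.
    // The shift amount doubles each step: 1, 2, 4, ... up to half the width of T.
    for (int s = 1; s < std::numeric_limits<T>::digits; s *= 2)
        d = static_cast<T>(d | (d >> s));
    return static_cast<T>(x & ~d);             // clear the differing bit and all below
}
```

The loop bound is a compile-time constant, so a compiler can unroll it fully; for a 32-bit `T` that gives the same five shift/OR steps as the code above.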
Egor Skriptunoff
  • That's actually pretty cool. It basically finds the MSB in `O(log w)` time and does it in a way that in the process gives you the final bitmask you need. Very elegant – Niklas B. Feb 03 '14 at 06:48
  • Well, but this optimization is only optical. If you count the total number of one-bit shifts `>>=1` and the other single operations needed for this and for my solution, you will find this one more demanding. – Tomas Feb 03 '14 at 06:58
  • @Tomas - Modern CPUs perform bitshift in constant time independently of shift counter. – Egor Skriptunoff Feb 03 '14 at 07:17
3

Part of this problem shows up semi-regularly in bit-manipulation: "parallel suffix with OR", or "prefix" (that is, depending on who you listen to, the low bits are either called a suffix or a prefix). Obviously once you have a way to do that, it's trivial to extend it to what you want (as shown in the other answers).

Anyway, the obvious way is:

x |= x >> 1
x |= x >> 2
x |= x >> 4
x |= x >> 8
x |= x >> 16

But you're probably not constrained to simple operators.

For Haswell, the fastest way I found was:

lzcnt rax, rax     ; number of leading zeroes, sets carry if rax=0
mov edx, 64
sub edx, eax
mov rax, -1
bzhi rax, rax, rdx ; reset the bits in rax starting at position rdx

Other contenders were:

mov rdx, -1
bsr rax, rax       ; position of the highest set bit, set Z flag if no bit
cmovz rdx, rax     ; set rdx=rax iff Z flag is set
xor eax, 63
shrx rax, rdx, rax ; rax = rdx >> rax

And

lzcnt rax, rax
sbb rdx, rdx       ; rdx -= rdx + carry (so 0 if no carry, -1 if carry)
not rdx
shrx rax, rdx, rax

But they were not as fast.

I've also considered

lzcnt rax, rax
mov rax, [table+rax*8]

But it's hard to compare it fairly, since it's the only one that spends cache space, which has non-local effects.

Benchmarking various ways to do this led to this question about some curious behaviour of lzcnt.

They all rely on some fast way to determine the position of the highest set bit, which you could do with a cast to float and exponent extraction if you really had to, so probably most platforms can use something like it.
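As a rough sketch of that floating-point fallback (my own illustration, with hypothetical function names): converting a 32-bit value to `double` is exact, and `std::ilogb` then returns the exponent, which is the index of the highest set bit. I use `double` rather than `float` here because a `float` conversion can round large 32-bit values up to the next power of two.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper: index of the highest set bit of v, for v != 0.
// Every 32-bit integer is exactly representable as a double, so the
// binary exponent of the converted value is floor(log2(v)).
inline int highest_bit_via_double(std::uint32_t v)
{
    return std::ilogb(static_cast<double>(v));
}

// Hypothetical helper: common leading bits of x and y, built on the above.
inline std::uint32_t common_prefix_fp(std::uint32_t x, std::uint32_t y)
{
    const std::uint32_t d = x ^ y;
    if (d == 0) return x;                          // nothing differs, keep everything
    const int b = highest_bit_via_double(d);       // 0..31
    const std::uint32_t low = ~std::uint32_t(0) >> (31 - b);  // bits 0..b set
    return x & ~low;
}
```

This is only meant for platforms without a bit-scan instruction; where one exists it will almost certainly be faster.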

A shift that gives zero if the shift-count is equal to or bigger than the operand size would be very nice to solve this problem. x86 doesn't have one, but maybe your platform does.
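To show why such a shift would help, here is a small sketch in C++ that emulates it with a comparison. Both `shr_or_zero` and `leading_zeros64` are hypothetical helpers written for this example; the latter is a portable stand-in for `lzcnt` that returns 64 for a zero input.

```cpp
#include <cassert>
#include <cstdint>

// Portable stand-in for lzcnt (hypothetical helper): number of leading zero
// bits of v, and 64 when v == 0.
inline unsigned leading_zeros64(std::uint64_t v)
{
    if (v == 0) return 64;
    unsigned n = 0;
    if (!(v >> 32)) { n += 32; v <<= 32; }
    if (!(v >> 48)) { n += 16; v <<= 16; }
    if (!(v >> 56)) { n += 8;  v <<= 8;  }
    if (!(v >> 60)) { n += 4;  v <<= 4;  }
    if (!(v >> 62)) { n += 2;  v <<= 2;  }
    if (!(v >> 63)) { n += 1; }
    return n;
}

// Emulates a right shift that yields zero once the count reaches the operand size.
inline std::uint64_t shr_or_zero(std::uint64_t v, unsigned n)
{
    return n >= 64 ? 0 : v >> n;
}

// With that shift, the mask needs no special case for equal inputs:
// a count of 64 simply produces an empty mask.
inline std::uint64_t common_prefix64(std::uint64_t x, std::uint64_t y)
{
    const std::uint64_t low = shr_or_zero(~std::uint64_t(0), leading_zeros64(x ^ y));
    return x & ~low;
}
```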

If you had a fast bit-reversal instruction, you could do something like: (this isn't intended to be ARM asm)

rbit r0, r0
neg r1, r0
or r0, r1, r0
rbit r0, r0
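
The same reversal trick can be sketched in portable C++, with the bit reversal done in software (so this version is only illustrative and will not be fast). The function names are made up for this example. After reversing, `r | -r` sets the lowest set bit of `r` and everything above it, which after reversing back is the highest set bit of the input and everything below it.

```cpp
#include <cassert>
#include <cstdint>

// Software bit reversal of a 32-bit value (stand-in for an rbit-style instruction).
inline std::uint32_t reverse_bits(std::uint32_t v)
{
    v = ((v >> 1) & 0x55555555u) | ((v & 0x55555555u) << 1);
    v = ((v >> 2) & 0x33333333u) | ((v & 0x33333333u) << 2);
    v = ((v >> 4) & 0x0F0F0F0Fu) | ((v & 0x0F0F0F0Fu) << 4);
    v = ((v >> 8) & 0x00FF00FFu) | ((v & 0x00FF00FFu) << 8);
    return (v >> 16) | (v << 16);
}

// Hypothetical helper: sets every bit at or below the highest set bit of v.
inline std::uint32_t smear_right_via_reversal(std::uint32_t v)
{
    const std::uint32_t r = reverse_bits(v);
    return reverse_bits(r | (0u - r));   // 0u - r is the unsigned negation of r
}
```

The result is the "parallel suffix with OR" mask; the final answer is then `x & ~smear_right_via_reversal(x ^ y)`.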
harold
2

Comparing several algorithms leads to this ranking:

Having an inner loop of 1 or 10 in the test below:

  1. Utilizing a built in bit scan function.
  2. Filling least significant bits with or and shift (The function of @Egor Skriptunoff).
  3. Involving a lookup table.
  4. Scanning the most significant bit (The second function of @Tomas).

InnerLoops = 10 (Timing N is the total time of functionN in the test below, in seconds):

Timing 1: 0.101284
Timing 2: 0.108845
Timing 3: 0.102526
Timing 4: 0.191911

An inner loop of 100 or greater:

  1. Utilizing a built in bit scan function.
  2. Involving a lookup table.
  3. Filling least significant bits with or and shift (The function of @Egor Skriptunoff).
  4. Scanning the most significant bit (The second function of @Tomas).

InnerLoops = 100 (Timing N is the total time of functionN in the test below, in seconds):

Timing 1: 0.441786
Timing 2: 0.507651
Timing 3: 0.548328
Timing 4: 0.593668

The test:

#include <algorithm>
#include <chrono>
#include <cstdint>   // std::uint8_t
#include <cstdlib>   // std::rand, exit
#include <limits>
#include <iostream>
#include <iomanip>

// Functions
// =========

inline unsigned function1(unsigned  a, unsigned b)
{
    a ^= b;
    if(a) {
        int n = __builtin_clz (a);
        a = (~0u) >> n;
    }
    return ~a & b;
}

typedef std::uint8_t byte;
static byte msb_table[256] = {
    0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
    6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
};

inline unsigned function2(unsigned a, unsigned  b)
{
    a ^= b;
    if(a) {
        unsigned n = 0;
        if(a >> 24) n = msb_table[byte(a >> 24)] + 24;
        else if(a >> 16) n = msb_table[byte(a >> 16)] + 16;
        else if(a >> 8) n = msb_table[byte(a >> 8)] + 8;
        else n = msb_table[byte(a)];
        a = (~0u) >> (32-n);
    }
    return ~a & b;
}

inline unsigned function3(unsigned  a, unsigned  b)
{
    unsigned d = a ^ b;
    d = d | (d >> 1);
    d = d | (d >> 2);
    d = d | (d >> 4);
    d = d | (d >> 8);
    d = d | (d >> 16);
    return a & (~d);
}

inline unsigned function4(unsigned  a, unsigned  b)
{
    const unsigned maxbit = 1u << (std::numeric_limits<unsigned>::digits - 1);
    unsigned msb = maxbit;
    a ^= b;
    // Equal inputs have no differing bit; without this check the scan
    // below would never terminate.
    if(a == 0) return b;
    while( ! (a & msb))
        msb >>= 1;
    if(msb == maxbit) return 0;
    else {
        msb <<= 1;
        msb  -= 1;
        return ~msb & b;
    }
}


// Test
// ====

inline double duration(
    std::chrono::system_clock::time_point start,
    std::chrono::system_clock::time_point end)
{
    return double((end - start).count())
        / std::chrono::system_clock::period::den;
}

int main() {
    typedef unsigned (*Function)(unsigned , unsigned);
    Function fn[] = {
        function1,
        function2,
        function3,
        function4,
    };
    const unsigned N = sizeof(fn) / sizeof(fn[0]);
    std::chrono::system_clock::duration timing[N] = {};
    const unsigned OuterLoops = 1000000;
    const unsigned InnerLoops = 100;
    const unsigned Samples = OuterLoops * InnerLoops;
    unsigned* A = new unsigned[Samples];
    unsigned* B = new unsigned[Samples];
    for(unsigned i = 0; i < Samples; ++i) {
        A[i] = std::rand();
        B[i] = std::rand();
    }
    unsigned F[N];
    for(unsigned f = 0; f < N; ++f) F[f] = f;
    unsigned result[N];
    for(unsigned i = 0; i < OuterLoops; ++i) {
        std::random_shuffle(F, F + N);
        for(unsigned f = 0; f < N; ++f) {
            unsigned g = F[f];
            auto start = std::chrono::system_clock::now();
            for(unsigned j = 0; j < InnerLoops; ++j) {
                unsigned index = i + j;
                unsigned a = A[index];
                unsigned b = B[index];
                result[g] = fn[g](a, b);
            }
            auto end = std::chrono::system_clock::now();
            timing[g] += (end-start);
        }
        for(unsigned f = 1; f < N; ++f) {
            if(result[0] != result[f]) {
                std::cerr << "Different Results\n" << std::hex;
                for(unsigned g = 0; g < N; ++g)
                    std::cout << "Result " << g+1 << ": " << result[g] << '\n';
                exit(-1);
            }
        }
    }

    for(unsigned i = 0; i < N; ++i) {
        std::cout
            << "Timing " << i+1 << ": "
            << double(timing[i].count()) / std::chrono::system_clock::period::den
            << "\n";
    }
}

Compiler:

g++ 4.7.2

Hardware:

Intel® Core™ i3-2310M CPU @ 2.10GHz × 4, 7.7 GiB RAM

  • You can improve `function2` by directly precomputing the bit masks rather than the bit counts (1s in the MSb's and 0s in the LSb's will spare the negation `~a`), and removing the test for `a != 0` (no need to treat this specifically). Immediate `return` after every `if` condition will also spare unconditional jumps. –  Feb 03 '14 at 22:48
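
A possible reading of that suggestion, as a sketch only: the table stores ready-made per-byte keep masks instead of bit counts, a zero difference needs no special case, and every branch returns immediately. The names `KeepTable`, `keep8` and `common_prefix_lut` are hypothetical, and I have not benchmarked this against `function2`.

```cpp
#include <cassert>
#include <cstdint>

// keep8[v] holds, within one byte, the bits strictly above the highest set
// bit of v (and 0xFF for v == 0, i.e. keep the whole byte).
struct KeepTable {
    std::uint8_t keep8[256];
    KeepTable() {
        for (unsigned v = 0; v < 256; ++v) {
            unsigned s = v;
            s |= s >> 1; s |= s >> 2; s |= s >> 4;   // smear the highest bit downward
            keep8[v] = static_cast<std::uint8_t>(~s);
        }
    }
};
static const KeepTable table;

inline std::uint32_t common_prefix_lut(std::uint32_t a, std::uint32_t b)
{
    const std::uint32_t d = a ^ b;
    // The first non-zero byte of d (from the top) selects the mask; all bytes
    // above it are identical in a and b and are kept whole.
    if (d >> 24) return b & (std::uint32_t(table.keep8[d >> 24]) << 24);
    if (d >> 16) return b & (0xFF000000u | (std::uint32_t(table.keep8[d >> 16]) << 16));
    if (d >> 8)  return b & (0xFFFF0000u | (std::uint32_t(table.keep8[d >> 8]) << 8));
    return b & (0xFFFFFF00u | table.keep8[d]);       // also covers d == 0
}
```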
1

You may reduce it to the much easier problem of finding the highest set bit (highest 1), which is the same as finding floor(log2 X).

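unsigned int f(unsigned int x, unsigned int y)
{
    unsigned int c = x ^ y;   // xor : 00010111
    if (c == 0) return x;     // no differing bit: keep everything

    // now it comes: b = index of the highest set bit in c
    // perhaps some special operation or instruction exists for that
    int b = -1;
    while (c) {
        b++;
        c = c >> 1;
    }                         // b == 4

    // mask of bits 0..b: 00011111; written as two shifts so that the top
    // bit never causes a shift by the full width of the type (undefined)
    unsigned int m = ((1u << b) << 1) - 1;
    return x & ~m;            // x AND NOT m; y & ~m gives the same result
}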

In fact, if you can compute b = floor(log2 c) directly, then m = 2^(b+1) - 1 follows immediately, without the need for computing b using the loop above.

If you don't have such functionality, simple optimized code, which uses just basic assembly level operations (bit shifts by one bit: <<=1, >>=1) would look like this:

c = x ^ y;        // c == 00010111 (xor)
m = 1;
while (c) {
    m <<= 1; 
    c >>= 1;
}                 // m == 00100000
m--;              // m == 00011111 (mask)
return x & ~m;    // x AND NOT M

This compiles to very fast code, roughly one or two machine instructions per line.

Tomas
  • Finding either (the position of the highest set bit), or (a number where the only MSB remains) suffices. (And by the way, use unsigned integer please...) – nhahtdh Feb 03 '14 at 06:23
  • I think it is UB when the resulting mask is `0xFFFFFFFF` (`(1 << 32) - 1`). – Jarod42 Feb 03 '14 at 12:27
  • @Jarod42 b is always < 32: If c is zero and b is -1. –  Feb 03 '14 at 12:39
  • @DieterLücking: In the last code (without `b`), if `c == 0xFFFFFFFF`, don't we have `m <<= 32` (and `c >>= 32`) ? – Jarod42 Feb 03 '14 at 12:48
  • Overflow doesn't mean it's *undefined* - it's in fact pretty well defined. `c` will be zero and `m` will be all zeros after the loop, `m` will be all ones after `m--`, and the result will be 0. – Tomas Feb 03 '14 at 13:10
  • @Tomas 5.8 Shift operators: "The type of the result is that of the promoted left operand. The behavior is undefined if the right operand is negative, or greater than or equal to the length in bits of the promoted left operand." –  Feb 03 '14 at 13:24
1

It's a little ugly, but assuming 8-bit inputs, you can do something like this:

int x = 0xA5; // 1010 0101
int y = 0xB2; // 1011 0010
unsigned d = x ^ y;

int mask = ~(d | (d >> 1) | (d >> 2) | (d >> 3) | (d >> 4) | (d >> 5) | (d >> 6) | (d >> 7));

int z = x & mask;

We start by computing the exclusive-or of the numbers, which will give a 0 where they're equal, and a 1 where they're different. For your example, that gives:

00010111

We then shift that right by each of the 7 possible shift amounts and inclusive-or the results together (only the non-zero ones are shown):

00010111
00001011
00000101
00000010
00000001

That gives:

00011111

Which is 0's in the leading part where the original numbers were equal, and 1's from the first differing bit downward. We then invert that to get:

11100000

Then we and that with one of the original inputs (doesn't matter which) to get:

10100000

...exactly the result we wanted (and unlike a simple x & y, it'll also work for other values of x and y).

Of course, this can be extended out to an arbitrary width, but if you were working with (say) 64-bit numbers, the d | (d>>1) | ... | (d>>63); would be a little on the long and clumsy side.
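
For wider types, the long expression can be folded into a loop with a fixed trip count. The sketch below does that for 64-bit values; the function name is made up for this example, and it is meant as an illustration of the idea rather than the fastest form.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: same idea as the 8-bit expression above, for 64 bits.
// The loop always runs 63 times regardless of the data, so its branch is
// trivially predictable and a compiler is free to unroll it.
inline std::uint64_t common_prefix_linear(std::uint64_t x, std::uint64_t y)
{
    const std::uint64_t d = x ^ y;   // 1 wherever x and y differ
    std::uint64_t below = d;
    for (int s = 1; s < 64; ++s)
        below |= d >> s;             // d | (d>>1) | ... | (d>>63)
    return x & ~below;
}
```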

Jerry Coffin
  • Nice idea, but is this really optimal? This looks like an optimization at the first sight (you "save a while loop"), but might actually be worse. `>> 5` might actually be a hidden while loop at machine instruction level. And by repeatedly doing `d >> 4`, `d >> 5`, `d >> 6`, you in fact do the same job multiple times. – Tomas Feb 03 '14 at 06:44
  • @Tomas: With an old enough (or small enough) CPU, a right shift may be linear on the number of bits shifted. Given that he specifies that this is for supercomputing, I think we can safely guess that the CPU will include a barrel shifter, which does arbitrary shifts in constant time (introduced in the first supercomputer, the CDC 6600, in 1964). – Jerry Coffin Feb 03 '14 at 07:11
  • @Tomas: As to whether it's really an optimization: my guess is yes. If you were working with more than 8 bits, you could use a loop instead of repeated code, but it'd still be an optimization, because the loop executes a fixed number of times, meaning the branch is completely predictable. A branch like `while (c)` is likely to be predicted incorrectly fairly frequently. A mis-predicted branch (alone) could take longer than the whole operation I've outlined above. – Jerry Coffin Feb 03 '14 at 07:18