
Edit: I wish SO let me accept 2 answers because neither is complete without the other. I suggest reading both!

I am trying to come up with a fast implementation of a function that, given an unsigned 32-bit integer x, returns the sum of 2^trailing_zeros(i) for i=1..x-1, where trailing_zeros is the count-trailing-zeros operation, defined as the number of 0 bits below the least significant 1 bit. This seems like the kind of problem that should lend itself to a clever bit manipulation implementation that takes the same number of instructions regardless of the input, but I haven't been able to derive it.

Mathematically, 2^trailing_zeros(i) is equivalent to the largest factor of 2 that exactly divides i. So we are summing those largest factors for 1..x-1.

i                    | 1    2    3    4    5    6    7    8    9    10
----------------------------------------------------------------------
2^trailing_zeroes(i) | 1    2    1    4    1    2    1    8    1    2
----------------------------------------------------------------------
Sum (desired value)  | 0    1    3    4    8    9    11   12   20   21
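
As a reference point, the table can be reproduced with a brute-force sketch. It relies on the standard identity that `i & -i` isolates the lowest set bit of a positive integer, which equals 2^trailing_zeros(i). The function names `lowest_set_bit` and `brute_force` are illustrative and are not part of the script further down.

```python
def lowest_set_bit(i: int) -> int:
    """Largest power of two dividing i (for i > 0), i.e. 2**trailing_zeros(i)."""
    return i & -i


def brute_force(x: int) -> int:
    """Sum of 2**trailing_zeros(i) for i = 1 .. x-1, computed term by term."""
    return sum(lowest_set_bit(i) for i in range(1, x))


# Second and third rows of the table, for i (respectively x) from 1 to 10.
print([lowest_set_bit(i) for i in range(1, 11)])
print([brute_force(x) for x in range(1, 11)])
```

This is O(x) and only useful for checking faster versions against.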

It is a little easier to see the structure of 2^trailing_zeroes(i) if we 'plot' the values -- horizontal position increasing from left to right corresponding to i and vertical position increasing from top to bottom corresponding to trailing_zeroes(i).

 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2 
    4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4       4   
        8               8               8               8               8               8               8               8               8               8               8               8               8               8               8               8       
                16                               16                               16                               16                               16                               16                               16                               16               
                                32                                                               32                                                               32                                                               32                               
                                                                64                                                                                                                               64                                                               

Here it is easier to see the pattern: 2's are always 4 apart, 8's are always 16 apart, etc. However, each pattern starts at a different point: 8's don't begin until i=8, 16's don't begin until i=16, etc. If you don't take into account that the patterns don't start right away, you can come up with formulas that don't work. For example, you might think that to determine the number of 8's going into the total you should just compute floor(x/16), but x=25 is far enough to the right to include both of the first two 8's (at i=8 and i=24), while floor(25/16) is only 1.
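
To make the pitfall concrete, here is a small sketch comparing the naive count with an offset-corrected one. The corrected form floor((x - 1 + 2^k) / 2^(k+1)) is the same expression that g(x) in the script below uses; the helper names are made up for this illustration.

```python
def count_naive(x: int, k: int) -> int:
    """Naive guess: one occurrence of 2**k per full period of 2**(k+1)."""
    return x >> (k + 1)


def count_offset(x: int, k: int) -> int:
    """Occurrences of the value 2**k among i = 1 .. x-1, accounting for the
    fact that the first occurrence is at i = 2**k, not at i = 2**(k+1)."""
    return (x - 1 + (1 << k)) >> (k + 1)


def count_brute(x: int, k: int) -> int:
    """Direct count, for checking."""
    return sum(1 for i in range(1, x) if i & -i == 1 << k)


# x = 25, counting 8's (k = 3): they occur at i = 8 and i = 24.
print(count_naive(25, 3), count_offset(25, 3), count_brute(25, 3))
```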

The best solution I have come up with so far is:

  • Set n = floor(log2(x)). This can be computed quickly using the count leading zeros operation. This tells us the highest power of two that is going to be involved in the sum.
  • Set sum = 0
  • for i = 0..n (i = 0 covers the 1's, as z(x) in the script below does)
    • sum += floor((x - 2^i) / 2^(i+1))*2^i + 2^i

The way this works is that, for each power, it calculates the horizontal distance on the plot between x and the first appearance of that power, e.g. the distance between x and the first 8 is (x-8). It then divides by the distance between repeating instances of that power, e.g. floor((x-8)/16), which gives us how many more times that power appeared, and multiplies by the power to get the sum for that power, e.g. floor((x-8)/16)*8. Finally we add one instance of the given power, because that calculation excludes the very first time that power appears.

In practice this implementation should be pretty fast because the division/floor can be done by a right bit shift and powers of two can be done with 1 bit-shifted to the left. However it seems like it should still be possible to do better. This implementation will loop more for larger inputs, up to 32 times (it is O(log x)). Ideally we want O(1) without a gigantic lookup table using up all the CPU cache. I've been eyeing the BMI/BMI2 intrinsics but I don't see an obvious way to apply them.
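
The shift-based variant mentioned above might look like the following sketch. It assumes the loop runs over every power from 2^0 to 2^n, and that the per-power count includes the term at i = x itself, so the caller passes x-1 to get the sum over 1..x-1 (which is how the script below calls z). The name `z_shifts` is hypothetical.

```python
def z_shifts(x: int) -> int:
    """Sum of 2**trailing_zeros(i) for i = 1 .. x (inclusive), integer ops only."""
    total = 0
    n = x.bit_length() - 1  # floor(log2(x)) for x > 0; -1 when x == 0
    for i in range(n + 1):
        # Occurrences of 2**i after the first one, plus the first one itself,
        # then scaled by 2**i via a left shift.
        total += (((x - (1 << i)) >> (i + 1)) + 1) << i
    return total


# Desired value for x = 10 is the sum over i = 1..9.
print(z_shifts(10 - 1))
```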

Although my goal is to implement this in a compiled language like C++ or Rust with real bit shifting and intrinsics, I've been prototyping in Python. Included below is my script that includes the implementation I described, z(x), and the code for generating the plot, tower(x).

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from math import pow, floor, log, ceil

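def leading_zeros(x):
    # Misnamed: this counts the *trailing* zeros of x, i.e. the length of
    # the run of '0' characters after the last '1' in bin(x).
    return len(bin(x).split('b')[-1].split('1')[-1])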

def f(x):
    s = 0
    for c, i in enumerate(range(1,x)):
        a = pow(2, len(bin(i).split('b')[-1].split('1')[-1]))
        s += a
    return s

def g(x): return sum([pow(2,i)*floor((x+pow(2,i)-1)/pow(2,i+1)) for i in range(0,32)])

def h(x):
    s = 0
    extra = 0
    extra_s = 0
    for i in range(0,32):
        num = (x+pow(2,i)-1)
        den = pow(2,i+1)
        fraction = num/den
        floored = floor(num/den)
        power = pow(2,i)
        product = power*floored
        if product == 0:
            break
        s += product
        extra += (fraction - floored)
        extra_s += power*fraction
        #print(f"i={i} s={s} num={num} den={den} fraction={fraction} floored={floored} power={power} product={product} extra={extra} extra_s={extra_s}")
    return s

def z(x):
    upper_bound = floor(log(x,2)) if x > 0 else 0
    s = 0
    for i in range(upper_bound+1):
        num = (x - pow(2,i))
        den = pow(2,i+1)
        fraction = num/den
        floored = floor(fraction)
        added = pow(2,i)
        s += floored * added
        s += added
        print(f"i={i} s={s} upper_bound={upper_bound} num={num} den={den} floored={floored} added={added}")
    return s
#    return sum([floor((x - pow(2,i))/pow(2,i+1) + pow(2,i)) for i in range(floor(log(x, 2)))])

def tower(x):
    table = [[" " for i in range(x)] for j in range(ceil(log(x,2)))]
    for i in range(1,x):
        p = leading_zeros(i)
        table[p][i] = 2**p
    for row in table:
        for col in row:
            print(col,end='')
        print()


# h(9000)
for i in range(1,16):
    tower(i)
    print((i, f(i), g(i), h(i), z(i-1)))
Joseph Garvin
  • Doesn't this actually use 'count *trailing* zeroes' (at least that's what BMI calls it) E: also, this sequence is [A006520](https://oeis.org/A006520) on OEIS but the definitions given there don't look very useful from a computational point of view – harold Jun 05 '21 at 21:48
  • @harold You're right I always get them mixed up, corrected – Joseph Garvin Jun 05 '21 at 21:53
  • Related: [How can I get the value of the least significant bit in a number?](https://stackoverflow.com/q/18806481/2402272) – John Bollinger Jun 05 '21 at 21:59
  • @harold: Well, the recursive formula `a(n) = b(n+1), with b(2n) = 2b(n) + n, b(2n+1) = 2b(n) + n + 1` from OEIS will give you an answer in `log(x)` steps, which is pretty good. And if you have to do it a lot, you can memoize. – Nate Eldredge Jun 05 '21 at 22:59

2 Answers

Observe that if we count from 1 to x instead of to x−1, we have a pattern:

x sum sum/x
1 1 1
2 3 1.5
4 8 2
8 20 2.5
16 48 3

So we can easily calculate the sum for any power of two p as p • (1 + ½b), where b is the power (equivalently, the number of the bit that is set, or the log2 of p). We can see this by induction: if the sum from 1 to 2^b is 2^b•(1+½b) (which it is for b=0), then the sum from 1 to 2^(b+1) reprises the individual term contributions twice, except that the last term adds 2^(b+1) instead of 2^b, so the sum is 2•2^b•(1+½b) − 2^b + 2^(b+1) = 2^(b+1)•(1+½b) + ½•2^(b+1) = 2^(b+1)•(1+½(b+1)).
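
The closed form for powers of two can be checked numerically. The sketch below writes p • (1 + ½b) in integer form as p + b • p/2 and compares it with a direct sum; both function names are illustrative only.

```python
def sum_through(x: int) -> int:
    """Sum of 2**trailing_zeros(i) for i = 1 .. x (inclusive), term by term."""
    return sum(i & -i for i in range(1, x + 1))


def power_of_two_sum(b: int) -> int:
    """Closed form p * (1 + b/2) for p = 2**b, kept in integers.
    For b == 0 the halved term is multiplied by b == 0, so nothing is lost."""
    p = 1 << b
    return p + b * (p >> 1)


# Matches the x / sum table above: 1, 3, 8, 20, 48.
print([power_of_two_sum(b) for b in range(5)])
```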

Further, between any two powers of two, the lower bits reprise the previous partial sums. Thus, for any x, we can compute the cumulative sum of 2^trailing_zeros by adding up the power-of-two sums for the bits set in it. Recalling that this provides the sum for numbers from 1 to x, we get the desired sum from 1 to x−1 by subtracting one from x before the computation:

#include <limits.h>  /* for CHAR_BIT */

unsigned CountCumulative(unsigned x)
{
    --x;
    unsigned sum = 0;
    for (unsigned bit = 0; bit < sizeof x * CHAR_BIT; ++bit)
        sum += (x & 1u << bit) * (1 + bit * .5);
    return sum;
}

We can terminate the loop when x is exhausted:

unsigned CountCumulative(unsigned x)
{
    --x;
    unsigned sum = 0;
    for (unsigned bit = 0; x; ++bit, x >>= 1)
        sum += ((x & 1) << bit) * (1 + bit * .5);
    return sum;
}

As harold points out, we can factor out the 1, as summing the value of each bit of x equals x:

unsigned CountCumulative(unsigned x)
{
    --x;
    unsigned sum = x;
    for (unsigned bit = 0; x; ++bit, x >>= 1)
        sum += ((x & 1) << bit) * bit * .5;
    return sum;
}

Then eliminate the floating-point:

unsigned CountCumulative(unsigned x)
{
    unsigned sum = --x;
    for (unsigned bit = 0; x; ++bit, x >>= 1)
        sum += ((x & 1) << bit) / 2 * bit;
    return sum;
}

Note that when bit is zero, ((x & 1) << bit) / 2 will lose the fraction, but this is irrelevant as * bit makes the contribution zero anyway. For all other values of bit, (x & 1) << bit is even, so the division does not lose anything.

This will overflow unsigned at some point, so one might want to use a wider type for the calculations.
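
To see roughly where the wrap happens, here is a Python sketch of the integer loop above that can optionally emulate 32-bit unsigned arithmetic by masking after each addition. It assumes unsigned is 32 bits wide and that x >= 1; `count_cumulative` and `MASK32` are names chosen for this illustration.

```python
MASK32 = 0xFFFFFFFF


def count_cumulative(x: int, mask=None) -> int:
    """Port of the integer-only loop. With mask=None the result is exact
    (Python integers do not overflow); with mask=MASK32 every addition is
    reduced modulo 2**32, as 32-bit unsigned arithmetic in C would be."""
    x -= 1
    total = x
    bit = 0
    while x:
        total += ((x & 1) << bit) // 2 * bit
        if mask is not None:
            total &= mask
        bit += 1
        x >>= 1
    return total


for x in (1 << 28, 1 << 30):
    print(x, count_cumulative(x), count_cumulative(x, MASK32))
```

For x = 2^28 the two results agree, while for x = 2^30 the exact value no longer fits in 32 bits and the masked result differs.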

More Code Golf

Another way to add half the values of the bits of x repeatedly depending on their bit position is to shift x (to halve its bit values) and then add that repeatedly while removing successive bits from low to high:

unsigned CountCumulative(unsigned x)
{
    unsigned sum = --x;
    for (unsigned bit = 0; x >>= 1; ++bit)
        sum += x << bit;
    return sum;
}
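
For checking, the shift-and-add loop can be ported to Python and compared against a term-by-term sum. This is a sketch that assumes x >= 1 and uses unbounded Python integers, so it does not reproduce the C version's wrapping; `count_shift_add` is a hypothetical name.

```python
def count_shift_add(x: int) -> int:
    """Port of the final loop: start from x-1, then repeatedly drop the
    lowest remaining bit and add what is left, shifted back up by `bit`."""
    x -= 1
    total = x
    bit = 0
    x >>= 1
    while x:
        total += x << bit
        bit += 1
        x >>= 1
    return total


def brute_force(x: int) -> int:
    """Sum of 2**trailing_zeros(i) for i = 1 .. x-1."""
    return sum(i & -i for i in range(1, x))


print(all(count_shift_add(x) == brute_force(x) for x in range(1, 1000)))
```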
Eric Postpischil

Based on the method of Eric Postpischil, here is a way to do it without a loop.

Note that every bit is being multiplied by its position, and the results are summed (sort of; there is also a factor of 0.5 in it, but let's put that aside for now). Let's call the values that are being added up "the partial products", just to call them something; the name is not really accurate, but I can't come up with anything better. If we transpose that a little bit, then it's built up like this: the lowest bit of every partial product is the lowest bit of the position of every bit, multiplied by that bit. Single-bit products are bitwise-AND, and the lowest bits of the positions are 0,1,0,1 etc., so it works out to x & 0xAAAAAAAA. The second bit of every partial product is x & 0xCCCCCCCC (and has a "weight" of 2, so this must be multiplied by 2), etc.

Then the whole thing needs to be shifted right by 1, to account for the factor of 0.5.

So in total:

unsigned CountCumulativeTrailingZeros(unsigned x)
{
    --x;
    unsigned sum = x;
    sum += (x >> 1) & 0x55555555;
    sum += x & 0xCCCCCCCC;
    sum += (x & 0xF0F0F0F0) << 1;
    sum += (x & 0xFF00FF00) << 2;
    sum += (x & 0xFFFF0000) << 3;
    return sum;
}
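
A Python port of the loop-free function can be used to check it against a term-by-term sum. This sketch keeps the same masks and shifts but returns an exact Python integer rather than a value wrapped to 32 bits; it assumes 1 <= x < 2^32, and the function names are illustrative.

```python
def count_branch_free(x: int) -> int:
    """Same masks and shifts as the C function above, without 32-bit wrapping."""
    x = (x - 1) & 0xFFFFFFFF
    total = x
    total += (x >> 1) & 0x55555555      # position bit 0, weight 1/2
    total += x & 0xCCCCCCCC             # position bit 1, weight 1
    total += (x & 0xF0F0F0F0) << 1      # position bit 2, weight 2
    total += (x & 0xFF00FF00) << 2      # position bit 3, weight 4
    total += (x & 0xFFFF0000) << 3      # position bit 4, weight 8
    return total


def brute_force(x: int) -> int:
    """Sum of 2**trailing_zeros(i) for i = 1 .. x-1."""
    return sum(i & -i for i in range(1, x))


print(all(count_branch_free(x) == brute_force(x) for x in range(1, 1000)))
```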

For an additional explanation, here is a more visual example. Let's temporarily drop the factor of 0.5 again, it doesn't fundamentally change the algorithm but adds some complication.

First I write above every bit of v (some example value), the position of that bit in binary (p0 is the least significant bit of the position, p1 the second bit etc). Read the ps vertically, every column is a number:

p0: 10101010101010101010101010101010
p1: 11001100110011001100110011001100
p2: 11110000111100001111000011110000
p3: 11111111000000001111111100000000
p4: 11111111111111110000000000000000
v : 00000000100001000000001000000000

So for example bit 9 is set, and it has (reading from bottom to top) 01001 above it (9 in binary).

What we want to do (why this works has been explained by Eric's answer), is take the indexes of the bits that are set, shift them to their corresponding positions, and add them. In this case, they are already at their own positions (by construction, the numbers were written at their own positions), so there is no shift, but they still need to be filtered so only the numbers that correspond to set bits survive. This is what I meant by the "single bit products": take a bit of v and multiply it by the corresponding bits of p0, p1, etc.

You can also look at that as multiplying the bit value by its index, so 2^bit * bit, as mentioned in the comments. That is not how it is done here, but that is effectively what is computed.

Back to the example, applying bitwise-AND results in these partial products:

pp0: 00000000100000000000001000000000
pp1: 00000000100001000000000000000000
pp2: 00000000100000000000000000000000
pp3: 00000000000000000000001000000000
pp4: 00000000100001000000000000000000
v  : 00000000100001000000001000000000

The only values that are left are 01001, 10010, 10111, and they are at their corresponding positions (so, already shifted to where they need to go).
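
The partial products for this example can be reproduced with a few lines of Python. This is only an illustration of the masking step; `MASKS`, `partial`, and `index_at` are names invented here.

```python
# p0..p4 written as 32-bit masks, and the example value v (bits 23, 18, 9 set).
MASKS = [0xAAAAAAAA, 0xCCCCCCCC, 0xF0F0F0F0, 0xFF00FF00, 0xFFFF0000]
v = 0b00000000100001000000001000000000

# Each partial product is a plain bitwise AND of v with one mask.
partial = [v & m for m in MASKS]
for k, pp in enumerate(partial):
    print(f"pp{k}: {pp:032b}")


def index_at(j: int) -> int:
    """Read column j of the partial products vertically (pp0 is the low bit)."""
    return sum(((pp >> j) & 1) << k for k, pp in enumerate(partial))


# For every set bit of v, the column holds that bit's own index.
print([index_at(j) for j in range(32) if (v >> j) & 1])
```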

Those values must be added, while keeping them at their positions. They don't need to be extracted from the strange form which they are in; addition is freely reorderable (associative and commutative), so it's OK to add all the least significant bits of the partial products to the sum first, then all the second bits, and so on. But they have to be added with the right "weight"; after all, a set bit in pp0 corresponds to a 1 at that position, but a set bit in pp1 really corresponds to a 2 at that position (since it's the second bit of the number that it is part of). So pp0 is used directly, but pp1 is shifted left by 1, pp2 is shifted left by 2, etc.

The factor of 0.5 must still be accounted for, which I did mostly by shifting the bits of the partial products by one less than what their weight would imply. pp0 was shifted left by zero, so it must be shifted right by 1 now. This could be done with less complication by just putting return sum >> 1; at the end, but that would reduce the range of values that the function can handle before running into integer wrapping modulo 2^32 (also it would cost an extra operation, and doing it the weird way does not).
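
The range difference can be illustrated with a sketch that emulates 32-bit wrapping in Python. `late_shift_u32` is one possible reading of the "shift at the end" alternative (full-weight partial products summed modulo 2^32, halved once, then added to x); it is an assumption made for this comparison, not code from the answer.

```python
MASK32 = 0xFFFFFFFF


def direct_u32(x: int) -> int:
    """The function as given, with the final result reduced modulo 2**32."""
    x = (x - 1) & MASK32
    total = x
    total += (x >> 1) & 0x55555555
    total += x & 0xCCCCCCCC
    total += (x & 0xF0F0F0F0) << 1
    total += (x & 0xFF00FF00) << 2
    total += (x & 0xFFFF0000) << 3
    return total & MASK32


def late_shift_u32(x: int) -> int:
    """Hypothetical variant: add the partial products at full weight in
    32 bits, then halve once at the end."""
    x = (x - 1) & MASK32
    doubled = x & 0xAAAAAAAA
    doubled += (x & 0xCCCCCCCC) << 1
    doubled += (x & 0xF0F0F0F0) << 2
    doubled += (x & 0xFF00FF00) << 3
    doubled += (x & 0xFFFF0000) << 4
    doubled &= MASK32  # wrapping happens before the halving
    return (x + (doubled >> 1)) & MASK32


x = (1 << 28) + 1
print(direct_u32(x), late_shift_u32(x))
```

For this input the true sum is 15 • 2^28, which still fits in 32 bits: the direct version returns it, while the late-shift variant has already lost the top bit before halving.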

harold
  • This expanded explanation helped a ton, thanks! Is there a reason you did ‘(x >> 1) & 0x55555555’ instead of ‘(x & 0xAAAAAAAA) >> 1’? Breaks the pattern you use on subsequent lines. – Joseph Garvin Jun 08 '21 at 00:27
  • @JosephGarvin not really a specific reason, it was just the first thing I came up with, `(x & 0xAAAAAAAA) >> 1` may look nicer – harold Jun 08 '21 at 00:34