Edit: I wish SO let me accept 2 answers because neither is complete without the other. I suggest reading both!
I am trying to come up with a fast implementation of a function that, given an unsigned 32-bit integer x, returns the sum of 2^trailing_zeros(i) for i=1..x-1, where trailing_zeros is the count trailing zeros operation, defined as returning the number of 0 bits below the least significant 1 bit. This seems like the kind of problem that should lend itself to a clever bit manipulation implementation that takes the same number of instructions regardless of the input, but I haven't been able to derive it.
Mathematically, 2^trailing_zeros(i) is the largest power of 2 that exactly divides i. So we are summing those largest powers for i=1..x-1.
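To make the definition concrete, here is a minimal Python sketch. The function names are hypothetical, not part of the script further down. It checks the explicit bit count against the identity i & -i, which isolates the lowest set bit of a positive integer and therefore equals 2^trailing_zeros(i).

```python
def trailing_zeros(i: int) -> int:
    """Count the 0 bits below the least significant 1 bit of i (requires i > 0)."""
    count = 0
    while i & 1 == 0:
        i >>= 1
        count += 1
    return count


def lowest_set_bit(i: int) -> int:
    """Largest power of 2 dividing i, using the identity i & -i."""
    return i & -i


for i in range(1, 11):
    # Both formulations should agree, and the result should divide i.
    assert lowest_set_bit(i) == 1 << trailing_zeros(i)
    assert i % lowest_set_bit(i) == 0

print([lowest_set_bit(i) for i in range(1, 11)])
# [1, 2, 1, 4, 1, 2, 1, 8, 1, 2]
```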
i                     |   1   2   3   4   5   6   7   8   9  10
---------------------------------------------------------------
2^trailing_zeros(i)   |   1   2   1   4   1   2   1   8   1   2
---------------------------------------------------------------
Sum for x=i (desired) |   0   1   3   4   8   9  11  12  20  21
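The Sum row can be reproduced with a term-by-term reference implementation. This is only a sketch for checking faster versions against; brute_force_sum is a hypothetical name, and it relies on i & -i being the largest power of 2 dividing i.

```python
def brute_force_sum(x: int) -> int:
    """Sum of 2^trailing_zeros(i) for i = 1..x-1, computed one term at a time."""
    return sum(i & -i for i in range(1, x))


print([brute_force_sum(x) for x in range(1, 11)])
# [0, 1, 3, 4, 8, 9, 11, 12, 20, 21]
```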
It is a little easier to see the structure of 2^trailing_zeros(i) if we 'plot' the values, with horizontal position increasing from left to right corresponding to i and vertical position increasing from top to bottom corresponding to trailing_zeros(i).
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4 4
8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8
16 16 16 16 16 16 16 16
32 32 32 32
64 64
Here it is easier to see the pattern that 2's are always 4 apart, 8's are always 16 apart, etc. However, each pattern starts at a different point: 8's don't begin until i=8, 16's don't begin until i=16, etc. If you don't take into account that the patterns don't start right away, you can come up with formulas that don't work. For example, you might think that to determine the number of 8's going into the total you should just compute floor(x/16), but x=25 gives floor(25/16) = 1 even though it is far enough to the right to include both of the first two 8's (at i=8 and i=24).
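The pitfall is easy to check numerically. The sketch below (count_of_power is a hypothetical helper) counts how many i in 1..x-1 contribute a given power and compares that with the naive floor(x/16) guess for the 8's.

```python
def count_of_power(x: int, power: int) -> int:
    """How many i in 1..x-1 have `power` as their largest power-of-2 divisor."""
    return sum(1 for i in range(1, x) if (i & -i) == power)


x = 25
naive = x // 16                 # ignores that the 8's start at i=8, not i=16
actual = count_of_power(x, 8)   # i=8 and i=24 both contribute an 8
print(naive, actual)
# 1 2
```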
The best solution I have come up with so far is:
- Set n = floor(log2(x)). This can be computed quickly using the count leading zeros operation. This tells us the highest power of two that is going to be involved in the sum.
- Set sum = 0
- for i = 0..n: sum += floor((x - 2^i) / 2^(i+1))*2^i + 2^i
The way this works is that for each power, it calculates the horizontal distance on the plot between x and the first appearance of that power, e.g. the distance between x and the first 8 is (x-8). It then divides by the distance between repeating instances of that power, e.g. floor((x-8)/16), which gives how many more times that power appears after the first. Multiplying by the power gives the sum for those appearances, e.g. floor((x-8)/16)*8. Then we add one instance of the given power because that calculation excludes the very first time that power appears. Note that as written this counts i=x itself, which is why the script below compares against z(x-1).
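The steps above can be sketched with integer shifts in place of floor and pow. This is an illustrative version rather than the script's z(x); the function names are hypothetical. It assumes the loop starts at i=0 so that the 1's are counted, and it uses int.bit_length() for floor(log2(x)). Because the formula as described includes i=x, the wrapper passes x-1 to get the sum over 1..x-1.

```python
def sum_inclusive(x: int) -> int:
    """Sum of 2^trailing_zeros(i) for i = 1..x, one loop iteration per power."""
    if x <= 0:
        return 0
    n = x.bit_length() - 1                  # floor(log2(x)) without floating point
    total = 0
    for i in range(n + 1):
        power = 1 << i                      # 2^i
        repeats = (x - power) >> (i + 1)    # floor((x - 2^i) / 2^(i+1))
        total += repeats * power + power    # repeats plus the first appearance
    return total


def sum_exclusive(x: int) -> int:
    """Sum of 2^trailing_zeros(i) for i = 1..x-1 (the desired value)."""
    return sum_inclusive(x - 1)


print([sum_exclusive(x) for x in range(1, 11)])
# [0, 1, 3, 4, 8, 9, 11, 12, 20, 21]
```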
In practice this implementation should be pretty fast because the division/floor can be done by right bit shift and powers of two can be done with 1 bit-shifted to the left. However it seems like it should still be possible to do better. This implementation will loop more for larger inputs, up to 32 times (it's O(log2(x)); ideally we want O(1) without a gigantic lookup table using up all the CPU cache). I've been eyeing the BMI/BMI2 intrinsics but I don't see an obvious way to apply them.
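For comparison, the g(x) variant in the script below always runs a fixed 32 iterations, so its iteration count does not depend on the input, although it is still a loop rather than a closed form. Here is an integer-only sketch of that same formula, with a hypothetical function name and shifts in place of pow and floor:

```python
def sum_fixed_iterations(x: int) -> int:
    """Sum of 2^trailing_zeros(i) for i = 1..x-1 using exactly 32 iterations."""
    total = 0
    for i in range(32):
        power = 1 << i
        # floor((x + 2^i - 1) / 2^(i+1)) counts the i in 1..x-1 contributing 2^i.
        total += power * ((x + power - 1) >> (i + 1))
    return total


print([sum_fixed_iterations(x) for x in range(1, 11)])
# [0, 1, 3, 4, 8, 9, 11, 12, 20, 21]
```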
Although my goal is to implement this in a compiled language like C++ or Rust with real bit shifting and intrinsics, I've been prototyping in Python. Included below is my script that includes the implementation I described, z(x), and the code for generating the plot, tower(x).
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from math import pow, floor, log, ceil

def trailing_zeros(x):
    # Count the 0 bits after the last 1 in the binary string of x
    # (this helper was originally named leading_zeros).
    return len(bin(x).split('b')[-1].split('1')[-1])

def f(x):
    # Brute-force reference: sum 2^trailing_zeros(i) for i = 1..x-1.
    s = 0
    for i in range(1, x):
        s += pow(2, trailing_zeros(i))
    return s

# Fixed 32-iteration formula for the same sum.
def g(x): return sum([pow(2,i)*floor((x+pow(2,i)-1)/pow(2,i+1)) for i in range(0,32)])

def h(x):
    # Same formula as g, expanded with intermediates kept for debugging;
    # stops early once a power no longer contributes.
    s = 0
    extra = 0
    extra_s = 0
    for i in range(0,32):
        num = (x+pow(2,i)-1)
        den = pow(2,i+1)
        fraction = num/den
        floored = floor(num/den)
        power = pow(2,i)
        product = power*floored
        if product == 0:
            break
        s += product
        extra += (fraction - floored)
        extra_s += power*fraction
        #print(f"i={i} s={s} num={num} den={den} fraction={fraction} floored={floored} power={power} product={product} extra={extra} extra_s={extra_s}")
    return s

def z(x):
    # Loop-per-power implementation described above. It sums i = 1..x
    # inclusive, so the caller passes x-1 to get the desired value.
    upper_bound = floor(log(x,2)) if x > 0 else 0
    s = 0
    for i in range(upper_bound+1):
        num = (x - pow(2,i))
        den = pow(2,i+1)
        fraction = num/den
        floored = floor(fraction)
        added = pow(2,i)
        s += floored * added
        s += added
        print(f"i={i} s={s} upper_bound={upper_bound} num={num} den={den} floored={floored} added={added}")
    return s
    # return sum([floor((x - pow(2,i))/pow(2,i+1) + pow(2,i)) for i in range(floor(log(x, 2)))])

def tower(x):
    # Print the 'plot': one row per power of two, one column per i.
    table = [[" " for i in range(x)] for j in range(ceil(log(x,2)))]
    for i in range(1,x):
        p = trailing_zeros(i)
        table[p][i] = 2**p
    for row in table:
        for col in row:
            print(col,end='')
        print()

# h(9000)
for i in range(1,16):
    tower(i)
    print((i, f(i), g(i), h(i), z(i-1)))