Yes, we can do it in O(b log b) time (b being the number of bits) by developing that idea further, using an exponential search.
Note that cutting off each streak's top 1-bit also widens the gaps between streaks. Originally, streaks were separated by at least one 0-bit. After cutting off each streak's top 1-bit, streaks are separated by at least two 0-bits.
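A small illustration of the gap widening. The helper `show` and the 12-bit example value are made up for this sketch and are not part of the solution:

```python
def show(n, width=12):
    """Return n as a zero-padded binary string (hypothetical display helper)."""
    return format(n, f'0{width}b')

n = 0b011101101111   # streaks of length 3, 2 and 4, separated by single 0-bits
cut1 = n & (n >> 1)  # drops the top 1-bit of every streak

print(show(n))     # 011101101111
print(show(cut1))  # 001100100111
```

The streaks shrink to lengths 2, 1 and 3, and every gap is now two 0-bits wide.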
We can then do
    n &= n >> 2
to cut off the top two 1-bits of all remaining streaks, which also widens the gaps to at least four 0-bits.
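To see why the wider gaps matter, here is a sketch with made-up example values. The first part applies the two-bit cut to a number whose streaks are already separated by two 0-bits. The second part shows what could go wrong if a gap were only one 0-bit wide:

```python
# Streaks of length 5, 2 and 3, each separated by two 0-bits.
n = 0b0111110011001110
cut2 = n & (n >> 2)  # drops the top two 1-bits of every streak

print(format(n, '016b'))     # 0111110011001110
print(format(cut2, '016b'))  # 0001110000000010

# With a gap of only one 0-bit, a shift by 2 pairs up bits from
# two different streaks and leaves a bit that belongs to no long streak.
spurious = 0b101 & (0b101 >> 2)
print(spurious)  # 1
```

In the first part the streaks shrink to lengths 3, 0 and 1, and the remaining two are far more than four 0-bits apart.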
We continue cutting off 4, 8, 16, 32, etc. 1-bits from the start of each streak, as long as any 1-bit streaks remain.
Let's say that when trying to cut off 32, we find that no streak would be left. At this point we switch to reverse mode: try cutting off 16 instead, then 8, 4, 2, and finally 1, but only keep the cuts that still leave us a streak.
Since we only kept cuts that left some streak, we end up with a streak (or streaks) of length 1, so we add that to the total.
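The two phases can be traced on a concrete input. The following is only an instrumented sketch of the steps described above; the name `trace_cuts` and the recorded `(cut, kept)` pairs are additions for illustration:

```python
def trace_cuts(n):
    """Follow the forward and reverse phases, recording each attempted cut.

    Returns (longest_streak, steps), where steps is a list of (cut, kept).
    """
    if n == 0:
        return 0, []
    steps = []
    total = 0
    cut = 1
    # Forward phase: double the cut while some streak survives it.
    while True:
        m = n & (n >> cut)
        steps.append((cut, bool(m)))
        if not m:
            break
        n = m
        total += cut
        cut *= 2
    # Reverse phase: halve the cut, keeping only cuts that leave a streak.
    while cut := cut // 2:
        m = n & (n >> cut)
        steps.append((cut, bool(m)))
        if m:
            n = m
            total += cut
    return total + 1, steps

length, steps = trace_cuts((1 << 11) - 1)  # a single streak of eleven 1-bits
print(length)  # 11
for cut, kept in steps:
    print(cut, 'kept' if kept else 'rejected')
```

For the streak of length 11, the cuts 1, 2 and 4 are kept, 8 and then 4 are rejected, and 2 and 1 are kept. That removes 10 bits in total and leaves a streak of length 1.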
The code:
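def onebits_linearithmic(n):
    if n == 0:
        return 0
    total = 0  # number of 1-bits cut off each streak so far
    cut = 1
    # Exponential phase: keep doubling the cut while some streak survives it.
    while m := n & (n >> cut):
        n = m
        total += cut
        cut *= 2
    # Reverse phase: halve the cut, keeping only cuts that leave a streak.
    while cut := cut // 2:
        if m := n & (n >> cut):
            n = m
            total += cut
    # Only streaks of length 1 remain.
    return total + 1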
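For spot-checking results on small inputs, a naive bit-by-bit scan can serve as a reference. This is just a sketch; the name `longest_streak_naive` is hypothetical and the function is not one of the benchmarked solutions:

```python
def longest_streak_naive(n):
    """Scan n bit by bit and return the length of its longest run of 1-bits."""
    best = run = 0
    while n:
        if n & 1:
            run += 1
            best = max(best, run)
        else:
            run = 0
        n >>= 1
    return best

print(longest_streak_naive(0b0))            # 0
print(longest_streak_naive(0b1))            # 1
print(longest_streak_naive(0b110111101))    # 4
print(longest_streak_naive((1 << 11) - 1))  # 11
```

Any of the faster solutions should agree with it on such inputs.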
Benchmark with random 1,000,000-bit numbers:
0.60 ± 0.06 ms onebits_linearithmic
1.13 ± 0.08 ms onebits_quadratic
48.57 ± 1.25 ms onebits_linear
I included a linear-time solution using strings, but its hidden constant is much higher, so it's still much slower.
Next, random 100,000-bit numbers with a 50,000-bit streak of 1-bits:
0.13 ± 0.05 ms onebits_linearithmic
2.37 ± 0.60 ms onebits_linear
176.23 ± 7.82 ms onebits_quadratic
The quadratic solution indeed became much slower. The other two remain fast, so let's try them with random 1,000,000-bit numbers with a 500,000-bit streak of 1-bits:
1.36 ± 0.06 ms onebits_linearithmic
24.69 ± 0.84 ms onebits_linear
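The slowdown of the quadratic solution on long streaks can be explained by counting shift-and-AND operations, each of which costs time proportional to the number's size. The two counting functions below are instrumented sketches written for this illustration (their names are hypothetical), and the input is a pure 50,000-bit streak rather than the benchmark's random number with an embedded streak:

```python
def count_quadratic_steps(n):
    """Count the loop iterations of the one-bit-at-a-time approach."""
    steps = 0
    while n:
        n &= n >> 1
        steps += 1
    return steps

def count_linearithmic_steps(n):
    """Count the attempted cuts of the exponential-search approach."""
    if n == 0:
        return 0
    steps = 0
    cut = 1
    # Forward phase, counting the final failed attempt as well.
    while True:
        steps += 1
        m = n & (n >> cut)
        if not m:
            break
        n = m
        cut *= 2
    # Reverse phase: one attempt per halved cut.
    while cut := cut // 2:
        steps += 1
        if m := n & (n >> cut):
            n = m
    return steps

n = (1 << 50_000) - 1  # a single 50,000-bit streak
print(count_quadratic_steps(n))     # 50000
print(count_linearithmic_steps(n))  # 31
```

The first count grows linearly with the longest streak, the second only logarithmically.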
Full code (Attempt This Online!):
def onebits_quadratic(n):
    ctr = 0
    while n:
        n &= n >> 1
        ctr += 1
    return ctr

def onebits_linearithmic(n):
    if n == 0:
        return 0
    total = 0
    cut = 1
    while m := n & (n >> cut):
        n = m
        total += cut
        cut *= 2
    while cut := cut // 2:
        if m := n & (n >> cut):
            n = m
            total += cut
    return total + 1

def onebits_linear(n):
    return max(map(len, f'{n:b}'.split('0')))

funcs = onebits_quadratic, onebits_linearithmic, onebits_linear
import random
from timeit import repeat
from statistics import mean, stdev
# Correctness
for n in [*range(10000), *(random.getrandbits(1000) for _ in range(1000))]:
    expect = funcs[0](n)
    for f in funcs[1:]:
        if f(n) != expect:
            print('fail:', f.__name__)

def test(bits, number, unit, scale, long_streak, funcs):
    if not long_streak:
        print(f'random {bits:,}-bit numbers:')
    else:
        print(f'random {bits:,}-bit numbers with {bits//2:,}-bit streak of 1-bits:')

    times = {f: [] for f in funcs}
    def stats(f):
        ts = [t * scale for t in times[f][5:]]
        return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} {unit} '
    for _ in range(10):
        n = random.getrandbits(bits)
        if long_streak:
            n |= ((1 << (bits//2)) - 1) << (bits//4)
        for f in funcs:
            t = min(repeat(lambda: f(n), number=number)) / number
            times[f].append(t)
    for f in sorted(funcs, key=stats):
        print(stats(f), f.__name__)
    print()

test(1_000_000, 1, 'ms', 1e3, False, funcs)
test(100_000, 1, 'ms', 1e3, True, funcs)
test(1_000_000, 1, 'ms', 1e3, True, funcs[1:])