You're correct to suspect that there's a more efficient way.
Let's start with a slightly simpler subproblem. Absent some really clever
insights, we're going to need to be able to find the number of integers in
[n+1, 2n] that have exactly k bits set in their binary representation. To
keep things short, let's call such integers "weight-k" integers (for
motivation for this terminology, look up Hamming weight). We can immediately
simplify our counting problem: if we can count all weight-k integers in
[0, 2n] and we can count all weight-k integers in [0, n], we can subtract one
count from the other to get the number of weight-k integers in [n+1, 2n].
So an obvious subproblem is to count how many weight-k integers there are
in the interval [0, n], for given nonnegative integers k and n.
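The subtraction step can be sanity-checked by brute force before building anything faster. This is only a sketch: `weight`, `count_upto_slow` and `count_between_slow` are illustrative names, and the linear scan is far too slow for large n.

```python
def weight(m):
    """Number of 1 bits in the binary representation of m (m >= 0)."""
    return bin(m).count("1")


def count_upto_slow(n, k):
    """Brute-force count of weight-k integers in [0, n]."""
    return sum(weight(m) == k for m in range(n + 1))


def count_between_slow(n, k):
    """Weight-k integers in [n+1, 2n], as a difference of two prefix counts."""
    return count_upto_slow(2 * n, k) - count_upto_slow(n, k)


# Cross-check the difference against a direct scan of [n+1, 2n].
n, k = 10, 3
direct = sum(weight(m) == k for m in range(n + 1, 2 * n + 1))
assert count_between_slow(n, k) == direct
```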
A standard technique for a problem of this kind is to look for a way to break
it down into smaller subproblems of the same kind; this is one aspect of
what's often called dynamic programming. In this case, there's an easy way of
doing so: consider the even numbers in [0, n] and the odd numbers in [0, n]
separately. Every even number m in [0, n] has exactly the same weight as m/2
(because by dividing by two, all we do is remove a single zero bit).
Similarly, every odd number m has weight exactly one more than the weight of
(m-1)/2. With some thought about the appropriate base cases, this leads to
the following recursive algorithm (in this case implemented in Python, but it
should translate easily to any other mainstream language).

    def count_weights(n, k):
        """
        Return number of weight-k integers in [0, n] (for n >= 0, k >= 0)
        """
        if k == 0:
            return 1  # 0 is the only weight-0 value
        elif n == 0:
            return 0  # only considering 0, which doesn't have positive weight
        else:
            from_even = count_weights(n//2, k)
            from_odd = count_weights((n-1)//2, k-1)
            return from_even + from_odd

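The even/odd observation behind the recursion is easy to check directly. The snippet below is a small illustrative check (the helper name `weight` and the sample value n = 10 are my own choices): it confirms the two weight identities and shows which smaller ranges the even and odd numbers map onto.

```python
def weight(m):
    """Number of 1 bits in the binary representation of m (m >= 0)."""
    return bin(m).count("1")


for m in range(0, 1000, 2):   # even m: halving drops a trailing 0 bit
    assert weight(m) == weight(m // 2)
for m in range(1, 1000, 2):   # odd m: dropping the trailing 1 bit loses one
    assert weight(m) == weight((m - 1) // 2) + 1

# For n = 10, the evens 0, 2, ..., 10 halve onto [0, n//2] and the
# odds 1, 3, ..., 9 map onto [0, (n-1)//2].
n = 10
evens_halved = [m // 2 for m in range(0, n + 1, 2)]
odds_halved = [(m - 1) // 2 for m in range(1, n + 1, 2)]
assert evens_halved == list(range(n // 2 + 1))
assert odds_halved == list(range((n - 1) // 2 + 1))
```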
There's plenty of scope for mistakes here, so let's test our fancy recursive
algorithm against something less efficient but more direct (and, I hope, more
obviously correct):

    def weight(n):
        """
        Number of 1 bits in the binary representation of n (for n >= 0).
        """
        return bin(n).count('1')

    def count_weights_slow(n, k):
        """
        Return number of weight-k integers in [0, n] (for n >= 0, k >= 0)
        """
        return sum(weight(m) == k for m in range(n+1))

The results of comparing the two algorithms look convincing:

    >>> count_weights(100, 5)
    11
    >>> count_weights_slow(100, 5)
    11
    >>> all(count_weights(n, k) == count_weights_slow(n, k)
    ...     for n in range(1000) for k in range(10))
    True

However, our supposedly fast count_weights function doesn't scale well to
numbers of the size you need:

    >>> count_weights(2**64, 5)  # takes a few seconds on my machine
    7624512
    >>> count_weights(2**64, 6)  # minutes ...
    74974368
    >>> count_weights(2**64, 10)  # gave up waiting ...

But here's where a second key idea of dynamic programming comes in: memoize!
That is, keep a record of the results of previous calls, in case we need to use
them again. It turns out that the chain of recursive calls made will tend to
repeat lots of calls, so there's value in memoizing. In Python, this is
trivially easy to do, via the functools.lru_cache decorator. Here's our new
version of count_weights. All that's changed is the decorator line at the top
(and the import it needs):

    from functools import lru_cache

    @lru_cache(maxsize=None)
    def count_weights(n, k):
        """
        Return number of weight-k integers in [0, n] (for n >= 0, k >= 0)
        """
        if k == 0:
            return 1  # 0 is the only weight-0 value
        elif n == 0:
            return 0  # only considering 0, which doesn't have positive weight
        else:
            from_even = count_weights(n//2, k)
            from_odd = count_weights((n-1)//2, k-1)
            return from_even + from_odd

Now testing on those larger examples again, we get results much more quickly,
without any noticeable delay.

    >>> count_weights(2**64, 10)
    151473214816
    >>> count_weights(2**64, 32)
    1832624140942590534
    >>> count_weights(5853459801720308837, 27)
    356506415596813420

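To see why memoization helps so much, it can be instructive to look at the cache statistics that lru_cache exposes through cache_info(). The sketch below repeats the memoized recursion so that it runs on its own, and checks the result against math.comb (Python 3.8+), which I'm using here only as an independent cross-check.

```python
from functools import lru_cache
from math import comb


@lru_cache(maxsize=None)
def count_weights(n, k):
    """Number of weight-k integers in [0, n] (same recursion as above)."""
    if k == 0:
        return 1
    elif n == 0:
        return 0
    return count_weights(n // 2, k) + count_weights((n - 1) // 2, k - 1)


result = count_weights(2**64, 10)
info = count_weights.cache_info()
print(result, info.hits, info.misses, info.currsize)

# 2**64 itself has weight 1, so the weight-10 integers in [0, 2**64]
# are exactly the 64-bit patterns with ten 1 bits.
assert result == comb(64, 10)
```

At each recursion depth only two distinct values of n can occur (roughly n / 2**depth and one less), so the number of distinct (n, k) pairs stays small even though the unmemoized call tree is huge.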
So now we have an efficient way to count, we've got an inverse problem to
solve: given k and m, find an n such that
count_weights(2*n, k) - count_weights(n, k) == m. This one turns out to be
especially easy, since the quantity
count_weights(2*n, k) - count_weights(n, k) is non-decreasing in n (for
fixed k), and more specifically increases by either 0 or 1 every time n
increases by 1. I'll leave the proofs of those facts to you. The building
block count_weights(n, k) behaves the same way, as this demo for k = 3 shows:

    >>> for n in range(10, 30): print(n, count_weights(n, 3))
    ...
    10 1
    11 2
    12 2
    13 3
    14 4
    15 4
    16 4
    17 4
    18 4
    19 5
    20 5
    21 6
    22 7
    23 7
    24 7
    25 8
    26 9
    27 9
    28 10
    29 10

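The step-by-0-or-1 claim for the difference itself can also be checked by brute force. The sketch below uses a hypothetical helper `diff_slow` that scans [n+1, 2n] directly; the comment in the loop gives what I believe is the underlying reason, offered as a hint rather than a full proof.

```python
def weight(m):
    """Number of 1 bits in the binary representation of m (m >= 0)."""
    return bin(m).count("1")


def diff_slow(n, k):
    """Brute-force count of weight-k integers in [n+1, 2n]."""
    return sum(weight(m) == k for m in range(n + 1, 2 * n + 1))


for k in range(1, 6):
    for n in range(200):
        step = diff_slow(n + 1, k) - diff_slow(n, k)
        # Moving from n to n+1 drops n+1 and adds 2n+1 and 2n+2; since
        # 2n+2 has the same weight as n+1, only 2n+1 can change the count.
        assert step == int(weight(2 * n + 1) == k)
        assert step in (0, 1)
```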
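Since the difference never jumps by more than 1, it can't skip over any value: whenever it eventually reaches m, some n hits m exactly. (For k >= 2 the difference grows without bound, so that always happens.) There may be multiple solutions, so we'll aim to find the smallest one (though it would be equally easy to find the largest one). Bisection search gives us a crude but effective way to do this. Here's the code: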

    def solve(m, k):
        """
        Find the smallest n >= 0 such that [n+1, 2n] contains exactly
        m weight-k integers.

        Assumes that m >= 1 (for m = 0, the answer is trivially n = 0).
        """
        def big_enough(n):
            """
            Target function for our bisection search solver.
            """
            diff = count_weights(2*n, k) - count_weights(n, k)
            return diff >= m

        low = 0
        assert not big_enough(low)

        # Initial phase: expand interval to identify an upper bound.
        high = 1
        while not big_enough(high):
            high *= 2

        # Bisection phase.
        # Loop invariant: big_enough(high) is True and big_enough(low) is False
        while high - low > 1:
            mid = (high + low) // 2
            if big_enough(mid):
                high = mid
            else:
                low = mid
        return high

Testing the solution:

    >>> n = solve(5853459801720308837, 27)
    >>> n
    407324170440003813446

Let's double check that n:

    >>> count_weights(2*n, 27) - count_weights(n, 27)
    5853459801720308837

Looks good. And if we got our search right, this should be the smallest n
that works:

    >>> count_weights(2*(n-1), 27) - count_weights(n-1, 27)
    5853459801720308836

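For small inputs, a brute-force reference solver makes a handy cross-check for the bisection version. The function below, `solve_slow`, is a hypothetical helper of my own, not part of the approach above: it simply tries n = 0, 1, 2, ... in turn, so it's only usable when the answer is small, and it never returns if no solution exists.

```python
from itertools import count


def weight(m):
    """Number of 1 bits in the binary representation of m (m >= 0)."""
    return bin(m).count("1")


def solve_slow(m, k):
    """
    Smallest n >= 0 such that [n+1, 2n] holds exactly m weight-k integers.

    Linear scan; assumes a solution exists (otherwise it loops forever).
    """
    for n in count():
        found = sum(weight(j) == k for j in range(n + 1, 2 * n + 1))
        if found == m:
            return n


# [11, 20] contains 11, 13, 14 and 19, the first time four weight-3
# integers fit in the window.
assert solve_slow(4, 3) == 10
```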
There are plenty of other opportunities for optimizations and cleanups in the
above code, and other ways to tackle the problem, but I hope this gives you a
starting point.
The OP commented that they needed to do this in C, where memoization isn't immediately available without using an external library. Here's a variant of count_weights that doesn't need memoization. It's achieved by (a) tweaking the recursion in count_weights so that the same n is used in both recursive calls, and then (b) returning, for a given n, the values of count_weights(n, k) for all k for which the answer is nonzero. In effect, we're just moving the memoization into an explicit list.
Note: as written, the code below needs Python 3.5 or later, because of the * unpacking inside the list display.

    def count_all_weights(n):
        """
        Return frequencies of weights of all integers in [0, n],
        as a list. The kth entry in the list gives the count
        of weight-k integers in [0, n].

        Example
        -------
        >>> count_all_weights(16)
        [1, 5, 6, 4, 1]
        """
        if n == 0:
            return [1]
        else:
            wm = count_all_weights((n-1)//2)
            weights = [wm[0], *(wm[i]+wm[i+1] for i in range(len(wm)-1)), wm[-1]]
            if n % 2 == 0:
                weights[bin(n).count('1')] += 1
            return weights

An example call:

    >>> count_all_weights(7590)
    [1, 13, 78, 286, 714, 1278, 1679, 1624, 1139, 559, 182, 35, 3]

This function should be good enough even for larger n: count_all_weights(10**18) takes less than half a millisecond on my machine.
Now the bisection search will work as before, replacing the call to count_weights(n, k) with count_all_weights(n)[k] (and similarly for count_weights(2*n, k)). One caveat: the list only extends as far as the largest weight that occurs in [0, n], so an index k past its end should be treated as a count of 0 rather than looked up directly.
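Because the list returned by count_all_weights stops at the largest weight that actually occurs, indexing it with a large k during the early bisection steps (when n is still small) would raise an IndexError. One way to guard against that is a tiny lookup helper; `weight_count` is a hypothetical name, and the sample list is the count_all_weights(16) result from the docstring.

```python
def weight_count(weights, k):
    """Entry k of a weight-frequency list, treating missing entries as zero."""
    return weights[k] if 0 <= k < len(weights) else 0


frequencies = [1, 5, 6, 4, 1]   # count_all_weights(16), per the docstring
assert weight_count(frequencies, 2) == 6
assert weight_count(frequencies, 27) == 0   # no weight-27 integers in [0, 16]
```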
Finally, another possibility is to break up the interval [0, n] into a succession of smaller and smaller subintervals, where each subinterval has length a power of two. For example, we'd break the interval [0, 101] into [0, 63], [64, 95], [96, 99] and [100, 101]. The advantage of this is that we can easily compute how many weight-k integers there are in any one of these subintervals by counting combinations. For example, in [0, 63] we have all possible 6-bit combinations, so if we're after weight-3 integers, we know there must be exactly 6-choose-3 (i.e., 20) of them. And in [64, 95], we know each integer starts with a 1-bit, and then after excluding that 1-bit we have all possible 5-bit combinations, so again we know how many integers there are in this interval with any given weight.
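Here's a minimal sketch of that block decomposition used purely for counting. The function name `count_weights_by_blocks` is my own, and math.comb needs Python 3.8 or later. Each 1 bit of n+1 corresponds to one power-of-two block: the higher bits are fixed, and the lower bits range over all combinations.

```python
from math import comb


def count_weights_by_blocks(n, k):
    """Count weight-k integers in [0, n] using power-of-two blocks."""
    total = 0
    ones_so_far = 0      # 1 bits already fixed by the higher-order prefix
    limit = n + 1        # count over the half-open range [0, limit)
    for bit in reversed(range(limit.bit_length())):
        if (limit >> bit) & 1:
            # Block of length 2**bit: the lower `bit` bits are free, and
            # we still need k - ones_so_far of them to be set.
            need = k - ones_so_far
            if 0 <= need <= bit:
                total += comb(bit, need)
            ones_so_far += 1
    return total


# Compare with a direct scan for small inputs.
for n in range(300):
    for k in range(8):
        slow = sum(bin(m).count("1") == k for m in range(n + 1))
        assert count_weights_by_blocks(n, k) == slow
```

For n = 101 the loop visits exactly the blocks [0, 63], [64, 95], [96, 99] and [100, 101] from the example.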
Applying this idea, here's a complete, fast, all-in-one function that solves your original problem. It has no recursion and no memoization.

    def solve(m, k):
        """
        Given nonnegative integers m and k, find the smallest
        nonnegative integer n such that the closed interval
        [n+1, 2*n] contains exactly m weight-k integers.

        Note that for k small there may be no solution:
        if k == 0 then we have no solution unless m == 0,
        and if k == 1 we have no solution unless m is 0 or 1.
        """
        # Deal with edge cases.
        if k < 2 and k < m:
            raise ValueError("No solution")
        elif k == 0 or m == 0:
            return 0

        k -= 1

        # Find upper bound on n, and generate a subset of
        # Pascal's triangle as we go.
        rows = []
        high, row = 1, [1] + [0] * k
        while row[k] < m:
            rows.append((high, row))
            high, row = high * 2, [1, *(row[i]+row[i+1] for i in range(k))]

        # Bisect to find first n that works.
        low = mlow = weight = 0
        while rows:
            high, row = rows.pop()
            mmid = mlow + row[k - weight]
            if mmid < m:
                low, mlow, weight = low + high, mmid, weight + 1
        return low + 1