
Find the minimum N such that the sum of set bits of the numbers from 1 to N is at least k.

For example (SB(x) is the number of set bits in x):

k = 11: output N = 7, as SB(1) + SB(2) + ... + SB(7) = 12
k = 5: output N = 4, as SB(1) + SB(2) + ... + SB(4) = 5

I first thought of solving it by storing the running sum of set bits and then applying a binary search for the first sum that is at least k. But the problem is that 1 <= k <= 10^18, so obviously DP can't be used. How can this problem be solved? The time limit is 1.5 seconds.
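For reference, here is the prefix-sum idea from the question as a brute-force Python sketch (the function names are my own, not from the question). It confirms both examples, but it walks through every N one at a time, so it is only usable for small k, which is exactly why k up to 10^18 needs something smarter:

```python
def set_bits_sum(n: int) -> int:
    """Total number of 1 bits over 1..n, counted one number at a time (O(n))."""
    return sum(bin(i).count("1") for i in range(1, n + 1))


def min_n_brute_force(k: int) -> int:
    """Smallest N with set_bits_sum(N) >= k, scanning N upward (small k only)."""
    total, n = 0, 0
    while total < k:
        n += 1
        total += bin(n).count("1")  # add SB(n) to the running sum
    return n


print(min_n_brute_force(11), set_bits_sum(7))  # 7 12
print(min_n_brute_force(5), set_bits_sum(4))   # 4 5
```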

Brij Raj Kishore

3 Answers


Let's say your number is 13, which is 1101 in binary. Let's tabulate the numbers up to it and see if we can spot patterns. I'll insert some line breaks to help later. I'll also add 0, because it doesn't hurt (it has no set bits).

0000
0001
0010
0011
0100
0101
0110
0111

1000
1001
1010
1011

1100

1101

We can write all the groups under n like this:

             000
             001
             010
             011
             100
             101
             110
             111

1000 +       00
1000 +       01
1000 +       10
1000 +       11

1000 + 100 + |    (no digits, equal to 0 in sum)

Notice that we have a group of size 2^3 for bit 3 of n = 1101, a group of size 2^2 for bit 2, and a group of size 2^0 for the LSB (bit 0). Let's call the group size s = 2^b, where b ranges over the positions of the set bits in n (i.e. b = [3, 2, 0], s = [8, 4, 1]). Now look at the bit patterns of the rightmost summand in each group: there are b columns of bits, and in each column exactly s/2 bits are set; thus each group contributes 2^(b-1)*b set bits in its rightmost columns. In addition, every row carries a number of set bits equal to the ordinal number i of its group (0, 1, 2, ...), because each group adds one more set bit of n to the left-hand prefix. That gives a total of 2^b*i + 2^(b-1)*b set bits per group:

Group 0: 2^3*0 + 2^2*3 = 12
Group 1: 2^2*1 + 2^1*2 = 8
Group 2: 2^0*2 + 2^(-1)*0 = 2

This counts all the set bits up to n-1. To include n itself, we need one more bit for each bit set in n, which is exactly the number of groups we had. The total is thus 12 + 8 + 2 + 3 = 25.

Here it is in Ruby. Note that 2^x * y is identical to y << x.

def sum_bits_upto(n)
  set_bits = n.to_s(2).reverse.each_char.with_index.map { |c, b|
    b if c == "1"
  }.compact.reverse

  set_bits.each_with_index.sum { |b, i|
    (i << b) + (b << b - 1) + 1
  }
end
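For readers who prefer Python, here is a sketch of the same group formula (my port, not Amadan's code). One difference to note: Python raises an error on negative shift counts, so the b = 0 case is handled explicitly, whereas Ruby's `0 << -1` simply evaluates to 0. The port is cross-checked against a running brute-force total:

```python
def sum_bits_upto(n: int) -> int:
    """Python port of the Ruby sum_bits_upto above: total set bits over 1..n."""
    # Positions b of the set bits of n, from most significant to least significant.
    set_bits = [b for b in range(n.bit_length() - 1, -1, -1) if (n >> b) & 1]
    total = 0
    for i, b in enumerate(set_bits):
        # Group i has 2^b rows: i higher set bits per row (i << b),
        # b low columns that are each half set (b * 2^(b-1)), plus 1 bit for n itself.
        low_columns = (b << (b - 1)) if b > 0 else 0
        total += (i << b) + low_columns + 1
    return total


# Cross-check against a running brute-force total for small n.
running = 0
for n in range(1, 5000):
    running += bin(n).count("1")
    assert sum_bits_upto(n) == running

print(sum_bits_upto(13))  # 25
```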

Hopefully I haven't messed up my logic anywhere. BTW, my code says there are 29761222783429247000 set bits from 1 to 1_000_000_000_000_000_000, with just 24 iterations of the loop. That should be fast enough :)

EDIT: Apparently I have goldfish memory. The sum is monotonically increasing (each successive number adds a positive number of bits to the running total). That means we can use binary search, which should zero in on the target super quickly, even if we don't help it by storing interim results (this executes in 0.1s on my Mac):

max = 1_000_000_000_000_000_000_000_000_000
k = 1_000_000_000_000_000_000
n = (1..max).bsearch { |n|
  sum_bits_upto(n) >= k
}
# => 36397208481162321

There has to be a nice formula to calculate the theoretically possible max n to search for based on k, but I couldn't be bothered, so I just put in something big enough. bsearch is that good.
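On the bound: every number from 1 to k has at least one set bit, so the total up to N = k is already at least k, and an upper end of k is always big enough. Below is a Python sketch of the search with that bound (the names `sum_bits_upto` and `min_n` are my own; the counting function is the group formula above, rewritten with a running count of the higher set bits, which plays the role of the group index i):

```python
def sum_bits_upto(n: int) -> int:
    """Total set bits over 1..n, using the group formula from this answer."""
    total = 0
    higher = 0  # set bits of n above position b, i.e. the group index i
    for b in range(n.bit_length() - 1, -1, -1):
        if (n >> b) & 1:
            low = (b << (b - 1)) if b > 0 else 0  # b columns, each half set
            total += (higher << b) + low + 1      # +1 for the bit of n itself
            higher += 1
    return total


def min_n(k: int) -> int:
    """Smallest N >= 1 with sum_bits_upto(N) >= k, for k >= 1."""
    lo, hi = 1, k  # sum_bits_upto(k) >= k, since each of 1..k has a set bit
    while lo < hi:
        mid = (lo + hi) // 2
        if sum_bits_upto(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


print(min_n(11), min_n(5))  # 7 4
```

With k = 10^18 the loop halves the range roughly 60 times, and each step is one O(log N) count.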

EDIT2: a couple of tweaks to the bsearch condition; I messed it up at first.

Amadan

I'm posting my answer well after Amadan's because it brings a different point of view on calculating the number of bits set up to N. The problem itself is still solved with a binary search, which is appropriate since the set-bit count is inexpensive to calculate.


Let's look at N being a power of 2, such as 8:
dcba
----
 000
 001
 010
 011
 100
 101
 110
 111
1000 

In column a the bits alternate 1 and 0; in column b they do the same but in runs of two (2^1) numbers, and in column c in runs of four (2^2) numbers.

Either way, each column contains the same number of 1s, N/2. Thus the number of 1s up to a power of 2 is (the +1 is for the power of 2 itself):

log2(N) * N/2 + 1
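A quick Python check of log2(N) * N/2 + 1 against direct counting for N = 2^0 up to 2^10 (the helper name is my own):

```python
def ones_up_to_power_of_two(p: int) -> int:
    """Set bits over 1..2^p, using log2(N) * N/2 + 1 with N = 2^p."""
    n = 1 << p
    return p * n // 2 + 1


for p in range(0, 11):
    n = 1 << p
    direct = sum(bin(i).count("1") for i in range(1, n + 1))
    assert ones_up_to_power_of_two(p) == direct, p

print(ones_up_to_power_of_two(3))  # 13, i.e. the 12 ones in 1..7 plus the 1 in 8
```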

Any integer is a sum of powers of 2; for example, 13 is

1000 + 0100 + 0001

The number of 1s up to N is then the sum of the above formula over each power of 2 in N, plus, for each such power x = 2^P, x times the number of 1s to its left (since when counting up to that power, the higher bits of N, i.e. the powers greater than P, are set in all of those x numbers):

bitsets = P * x/2 + 1 + x * (number of bits set to the left of that power x)

For instance, for 13 (decimal):

1000 => 3 x 8/2 +1 = 13  + 0      (no 1 left)
0100 => 2 x 4/2 +1 =  5  + 4 x 1  (one bit on the left, the 8)
0001 => 0 x 1/2 +1 =  1  + 1 x 2  (the 8 and the 4)

So there are 25 1s from 1 up to 13.
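The per-power formula as a small Python sketch (`ones_up_to` is a name I made up; `left` plays the role of "number of bits set to the left of that power x"). It reproduces the 25 for 13 and agrees with direct counting for small N:

```python
def ones_up_to(n: int) -> int:
    """Set bits over 1..n: sum of P*x/2 + 1 + x*left over each power x = 2^P in n."""
    total = 0
    left = 0  # number of set bits of n above the current power
    for p in range(n.bit_length() - 1, -1, -1):
        x = 1 << p
        if n & x:
            total += p * x // 2 + 1 + x * left
            left += 1
    return total


# Compare with a running brute-force total.
running = 0
for n in range(1, 1000):
    running += bin(n).count("1")
    assert ones_up_to(n) == running

print(ones_up_to(13))  # 25 (= 13 + 9 + 3, as in the table)
```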

The calculation is O(log(N)), which is fast, even for 10^18 (about 60 iterations).

A binary search then works in O(log^2 N) overall (O(log N) steps, each an O(log N) count). The upper end of the search range can be the power of 2 just above 10^18, whose number of set bits can be calculated with the formula above.
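A sketch of that search in Python, assuming the per-power count above. Rather than hard-coding the power of 2 above 10^18, it picks the first power of 2 whose count p * 2^p / 2 + 1 already reaches k, which is also a valid upper end. All names here are my own:

```python
def ones_up_to(n: int) -> int:
    """Set bits over 1..n via the per-power formula."""
    total, left = 0, 0
    for p in range(n.bit_length() - 1, -1, -1):
        x = 1 << p
        if n & x:
            total += p * x // 2 + 1 + x * left
            left += 1
    return total


def min_n(k: int) -> int:
    """Smallest N >= 1 whose running set-bit total is at least k (k >= 1)."""
    # Upper end: the first power of 2 whose total (power-of-2 formula) reaches k.
    p = 0
    while p * (1 << p) // 2 + 1 < k:
        p += 1
    lo, hi = 1, 1 << p
    while lo < hi:
        mid = (lo + hi) // 2
        if ones_up_to(mid) >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo


print(min_n(11), min_n(5))  # 7 4
```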

Déjà vu

I liked the question, so I spent some time coding it. The Python code below works with bit positions, so its complexity is bounded by the number of bits in 10^18.

# Store the sum of 1-bits up to the largest number formed by i bits (2^i - 1).
# For example, for 1-bit numbers it is 1; for 2-bit numbers 01, 10, 11 it is 4;
# for 3-bit numbers 001, 010, 011, 100, 101, 110, 111 it is 12;
# and so on.

sumToNBits = []
prevSum = 0
for i in range(1, 65):
    prevSum = (1 << (i-1)) + 2*prevSum
    sumToNBits.append(prevSum)


# Find the largest number of the form 2^P - 1 whose sum of 1-bits is at most K;
# return the leftover K and P.
def findClosestPowerTwo(K):
    index = 1
    prevEntry = 0
    for entry in sumToNBits: 
        if (entry > K):
            return (K-prevEntry, index-1)
        prevEntry = entry
        index += 1

    return (K-prevEntry, index-1)

# After finding max 2 to power P, now increase number up to K1 = K - bits used by 2 to power P.
def findNextPowerTwo(K, onBits):
    index = 1
    prevTotalBits = 0
    totalBits = onBits * ((1 << index) - 1) + sumToNBits[index-1]

    while(K >= totalBits):
        prevTotalBits = totalBits
        index += 1
        totalBits = onBits * ((1 << index) - 1) + sumToNBits[index-1]

    return (K-prevTotalBits, index-1)



def findClosestNumber(K):
    (K, powerTwo) = findClosestPowerTwo(K)
    number = (1 << (powerTwo)) - 1

    # Align to 2 to power P
    if (K >= 1):
        K = K - 1
        number += 1
    onBits = 1

    while(K > 0):
        (K, powerTwo) = findNextPowerTwo(K, onBits)

        if (powerTwo == 0):
            return number+1

        number += ((1 << (powerTwo)) - 1)

        # Align to 2 to power P
        if (K >= (onBits + 1)):
            K = K - (onBits + 1)
            number += 1
        onBits += 1

    return number

num = 1
while num < 100:
    print(num, end = " ")
    print(findClosestNumber(num))
    num += 1
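A side note on the `sumToNBits` table: unrolling prevSum = 2^(i-1) + 2*prevSum gives the closed form i * 2^(i-1), which agrees with the log2(N) * N/2 formula from the previous answer for N = 2^i (without the +1, since the table stops at 2^i - 1). A quick Python check, rebuilding the table the same way (my own snippet, not part of this answer):

```python
# Rebuild the table exactly as in the answer above.
sumToNBits = []
prevSum = 0
for i in range(1, 65):
    prevSum = (1 << (i - 1)) + 2 * prevSum
    sumToNBits.append(prevSum)

# Closed form: the numbers 1 .. 2^i - 1 contain i * 2^(i-1) set bits in total.
for i, value in enumerate(sumToNBits, start=1):
    assert value == i * (1 << (i - 1))

# Brute-force spot check for small i.
for i in range(1, 11):
    assert sumToNBits[i - 1] == sum(bin(v).count("1") for v in range(1, 1 << i))

print(sumToNBits[:4])  # [1, 4, 12, 32]
```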
Ajay Srivastava