
Problem

Given N, return an M that satisfies the equation: N + M = 2 * (N ^ M), where ^ is bitwise XOR.

Constraints

1 <= Test Cases <= 10^5
1 <= N <= 10^18

I came across this problem in one of the hiring challenges.

By trial and error, I found a pattern: such an M exists between N/3 and 3N, and N + M is an even number. So I coded it up, but upon submission my solution passed only half of the test cases. This is not much of an optimisation, as the method's time complexity is the same as that of the brute-force solution.

I know that my solution is not the optimal solution.

Here's my solution:

def solve(n):
    m = n // 3
    end = 3 * n

    # n + m must be even, so m must have the same parity as n
    if (m & 1) != (n & 1):
        m += 1

    # Step by 2 so that m keeps the same parity as n
    while m <= end:
        if (n + m) == 2 * (n ^ m):
            return m

        m += 2

Could someone provide some hints, methods, or an algorithm to get an optimal solution? Thanks!

Ram
  • `N + M` is an even number – shdxiang Jul 12 '21 at 08:21
  • @shdxiang Yes it is. I even tried that but I think that is not a major optimization as we are only reducing the input by half. You can check my updated code. – Ram Jul 12 '21 at 08:41
  • Can you solve for the bottom (lowest) bit? Given that lowest bit, can you solve for the next lowest bit? Etc. – Paul Hankin Jul 12 '21 at 08:47
  • That looks like an equation that can be solved using the Lambert W function, and then rearrange for N = ... Sadly I'm not good enough in maths to help you more than that – Plopp Jul 12 '21 at 08:56

2 Answers

6

The bottom bit of m is determined (since n+m must be even). Given that bottom bit, the next bit is determined, and so on.

That observation leads to this O(log n) solution:

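def solve(n):
    b = 1  # value of the bit currently being decided
    m = 0
    while n + m != 2 * (n ^ m):
        mask = 2 * b - 1  # selects all bits up to and including bit b
        # If the two sides disagree in the masked bits, bit b of m must be set
        if ((n + m) & mask) != ((2 * (n ^ m)) & mask):
            m += b
        b *= 2
    return m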

Another way to implement this is to find the lowest bit in which n+m and 2*(n^m) differ, and set that bit in m. That results in this very compact code (using the walrus operator, which needs Python 3.8 or later, and some bit-twiddling tricks):

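def solve(n):
    m = 0
    # r has a 1 in every bit where n + m and 2 * (n ^ m) differ
    while r := (n + m) ^ (2 * (n ^ m)):
        m |= r & -r  # r & -r isolates the lowest set bit of r
    return m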
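A quick way to gain confidence in the bit-by-bit approach is to compare it against an exhaustive search on small inputs. The sketch below restates the compact `solve` so that it runs on its own; `brute_force` is a hypothetical helper added only for the comparison, and it relies on the question's empirical observation that M does not exceed 3N.

```python
def solve(n):
    """Bit-by-bit search: repeatedly fix the lowest bit where the two sides differ."""
    m = 0
    while r := (n + m) ^ (2 * (n ^ m)):
        m |= r & -r
    return m


def brute_force(n):
    """Hypothetical reference: try every m in [0, 3n]; only practical for small n."""
    for m in range(3 * n + 1):
        if n + m == 2 * (n ^ m):
            return m
    return None


if __name__ == "__main__":
    # Compare both approaches on small inputs.
    for n in range(1, 300):
        assert solve(n) == brute_force(n)
    print("bit-by-bit result matches brute force for n < 300")
```

Because each iteration settles at least one more bit of m, the loop should run about log2(n) times, which comfortably handles N up to 10^18.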
Paul Hankin
0

I haven't tested this solution, so it might not work, but here we go.

It is known that

N+M=(N^M)+(N&M)*2

but we are given N+M=2(N^M)

Combining the two equations, we get

2(N&M)=(N^M)
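The identity used here holds because XOR adds each bit pair without carrying, while AND marks the positions that produce a carry into the next bit. A small sketch to check the identity and the derived condition numerically; the helper name `satisfies` and the random sampling ranges are assumptions made for illustration.

```python
import random


def satisfies(n, m):
    """True when m solves N + M = 2 * (N ^ M)."""
    return n + m == 2 * (n ^ m)


def check_identities(trials=10_000, seed=0):
    """Check the sum identity and the derived condition on random pairs."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randrange(1, 10**18)
        m = rng.randrange(0, 3 * 10**18)
        # Sum = carry-less sum (XOR) plus the carries (AND shifted left by one).
        if n + m != (n ^ m) + 2 * (n & m):
            return False
        # The original equation and the derived condition always agree.
        if satisfies(n, m) != (2 * (n & m) == (n ^ m)):
            return False
    return True


if __name__ == "__main__":
    print(check_identities())
    # A known solution: N = 1, M = 3.
    print(satisfies(1, 3), 2 * (1 & 3) == (1 ^ 3))
```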

Here, multiplying by 2 just means left-shifting the value by one bit, so if we get numbers such as

1 0 1 0 =N^M

0 1 0 1 =N&M

the above equation is satisfied.

We know the last bit of M: since the RHS is always even, N + M must be even, so the last bit of M is the same as that of N ( __0 for even and __1 for odd).

Let us assume N is odd => _ _ _ _ 1

Hence we will have M as => _ _ _ _ 1

now let us calculate what xor of these numbers will be

xor=> _ _ _ _ 0
AND=> _ _ _ _ 1

We know that the AND must be left-shifted (multiplied by 2) to give the XOR, so the last bit of the AND becomes the XOR bit at the second-to-last position. In simple words:

xor=>      e d c b a (0/1)
AND=>(0/1) e d c b a

We can write code for this. Below is Java code:

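import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc=new Scanner(System.in);
        //long instead of int, since N can be as large as 10^18
        long n=sc.nextLong();
        //the leading "0" leaves room for one extra high bit in m
        String nBi="0"+Long.toBinaryString(n);
        StringBuilder mBi=new StringBuilder("");
        int target;
        if(n%2==0){
            mBi.append("0");
            target=0;
        }else{
            mBi.append("1");
            target=1;
        }
        int length=nBi.length();
        for(int i=length-2;i>=0;i--){
            //target for xor is saved in target variable
            if(target==0){
                //same bits give 0 in xor; the AND bit (next target) equals this bit of n
                mBi.append(nBi.charAt(i));
                target=nBi.charAt(i)-'0';
            }else{
                //different bits give 1 in xor; the AND bit (next target) is 0
                mBi.append(nBi.charAt(i)=='1'?'0':'1');
                target=0;
            }
        }
        //bits were appended lowest first, so reverse before parsing
        long m=Long.parseLong(mBi.reverse().toString(),2);
        System.out.println(m);
        System.out.println((n+m)==2*(n^m));
    }
}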
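The same bit-by-bit construction can also be written with plain integer operations instead of strings. The following is a minimal Python sketch of that idea, not a tested port of the Java code; the function name `solve_bits` is hypothetical, and `target` plays the same role as above (the XOR bit required at the current position, which equals the AND bit one position lower).

```python
def solve_bits(n):
    """Build m from the lowest bit upward using only integer operations."""
    m = 0
    target = 0  # required XOR bit at position i; bit 0 of the XOR must be 0
    i = 0
    # Keep going while n still has bits left or a carry-like target is pending.
    while (n >> i) or target:
        n_bit = (n >> i) & 1
        m_bit = n_bit ^ target      # chosen so that n_bit ^ m_bit == target
        m |= m_bit << i
        target = n_bit & m_bit      # this AND bit is the XOR bit required next
        i += 1
    return m


if __name__ == "__main__":
    for n in (1, 2, 5, 10**18):
        m = solve_bits(n)
        print(n, m, n + m == 2 * (n ^ m))
```

Python integers are arbitrary precision, so the extra high bit that m may need does not overflow here.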
toyota Supra
  • That sounds like a reasonable technique (I didn't test it either), but did you really have to implement it with strings like that.. Surely it is possible to implement this technique nicely? – harold Aug 08 '23 at 12:58
  • surely it can be done harold! there is always a better way to code a solution, I have used strings here because I was trying to push computations as far as possible, also we can append new bits at the start as well, but because of using strings, I was able to use inbuilt functions such as parseInt or convert to binary rather than doing math and adding some value in each iteration.@harold – Jagnath reddy Aug 08 '23 at 13:41