Keypad Sticky Note

The minions have some of Professor Boolean's secrets safely locked away. Or so they think. In fact, they are so confident, they even have a password hint sticky note stuck to the keypad of the lock.

The lock requires you to enter a pair of non-negative integers (a, b) into the keypad. Since the integers could be as big as 2 billion, you look to the sticky note for help.

The sticky note has two numbers written on it, but even the minions know enough not to put the passwords there. They have actually written down the sum (they have it labelled as s) and the bitwise exclusive or (xor, labelled as x) of the pair of password integers (a, b) instead. That way, they only need to remember one. If they have difficulty with subtraction, they can use the bitwise exclusive or.

i.e., we have that s = a+b and x = a^b (where ^ is the bitwise XOR operation).

With your automated hacking equipment, each attempt to input a guess takes a few milliseconds. Since you only have a little time before you are discovered, you want to know how long it might take before you are able to try all the combinations. Thanks to the sticky note, you now can eliminate certain combinations without even having to input them into the keypad, and you can find out exactly how long it might take to crack the lock - in the worst case scenario.

Write a function called answer(s, x) that finds the number of pairs (a, b) that have the target sum and xor.

For example, if s=10 and x=4, then the possible values for (a, b) are (3, 7) and (7, 3), so answer would return 2.

If s=5 and x=3, then there are no possible values, so answer would return 0.

s and x are at least 0 and at most 2 billion.

Languages

To provide a Python solution, edit solution.py. To provide a Java solution, edit solution.java.

Test cases

Inputs:
    (int) s = 10
    (int) x = 4
Output:
    (int) 2

Inputs:
    (int) s = 0
    (int) x = 0
Output:
    (int) 1

public static int answer(int s, int x) {
    List<Integer> num = new ArrayList<>();
    int a;
    int b;
    int sum;
    int finalans;

    for(int i = 0; i <=s; i++){
        for(int e = 0; e <= s; e++){
            sum = i + e;
            if(sum == s){
                if((i^e) == x){
                    if(!num.contains(i)){
                        num.add(i);
                    }
                    if(!num.contains(e)){
                        num.add(e);
                    }
                }
            }
        }
    }

    finalans = num.size();
    if((finalans%2) == 0){
        return finalans*2;
    } else if(!((finalans%2) == 0)){
        return finalans;
    }
    return 0;

}

My code works, but it takes too long when s and x become large. How can I make this program run faster?

amit
rkatakam

5 Answers

You can solve this by realizing that for an incoming state (xor digit, sum digit, incoming carry) there are a limited number of outgoing states (outgoing carry). You can address each state with an if condition and use recursion to calculate the total number of combinations. You can use memoization to make the recursion efficient. My solution below solves the problem in O(m) time, where m is the number of binary digits in your number datatype. Since the problem specifies that m = 32 (integers), this is technically an O(1) solution.

Let me know if you have any questions. I tried to add helpful comments in the code to explain the various cases.

public class SumAndXor {
    public static void main(String[] args) {
        int a = 3;
        int b = 7;

        int sum = a + b;
        int xor = a ^ b;

        System.out.println(answer(sum, xor));
    }

    private static final int NOT_SET = -1;

    // Driver
    public static int answer(int sum, int xor) {
        int numBitsPerInt = Integer.toBinaryString(Integer.MAX_VALUE).length() + 1;
        int[][] cache = new int[numBitsPerInt][2];

        for (int i = 0; i < numBitsPerInt; ++i) {
            cache[i][0] = NOT_SET;
            cache[i][1] = NOT_SET;
        }

        return answer(sum, xor, 0, 0, cache);
    }

    // Recursive helper
    public static int answer(int sum, int xor, int carry, int index, int[][] cache) {
        // Return memoized value if available
        if (cache[index][carry] != NOT_SET) {
            return cache[index][carry];
        }

        // Base case: nothing else to process
        if ((sum >> index) == 0 && (xor >> index) == 0 && carry == 0) {
            return 1;
        }

        // Get least significant bits
        int sumLSB = (sum >> index) & 1;
        int xorLSB = (xor >> index) & 1;

        // Recursion
        int result = 0;

        if (carry == 0) {
            if (xorLSB == 0 && sumLSB == 0) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 0 and the incoming carry is 0, both [0, 0] and [1, 1] are valid. We
                // recurse with a carry of 0 to represent [0, 0], and we recurse with a carry of
                // 1 to represent [1, 1].
                result = answer(sum, xor, 0, index + 1, cache) + answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 0 && sumLSB == 1) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 1 and the incoming carry is 0, neither [0, 0] nor [1, 1] is valid.
                result = 0;
            } else if (xorLSB == 1 && sumLSB == 0) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 0 and the incoming carry is 0, neither [0, 1] nor [1, 0] is valid.
                result = 0;
            } else if (xorLSB == 1 && sumLSB == 1) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 1 and the incoming carry is 0, both [0, 1] and [1, 0] are valid. We
                // recurse with a carry of 0 to represent [0, 1], and we recurse with a carry
                // of 0 to represent [1, 0].
                result = 2 * answer(sum, xor, 0, index + 1, cache);
            }
        } else {
            if (xorLSB == 0 && sumLSB == 0) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 0 and the incoming carry is 1, neither [0, 0] nor [1, 1] is valid.
                result = 0;
            } else if (xorLSB == 0 && sumLSB == 1) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 1 and the incoming carry is 1, both [0, 0] and [1, 1] are valid. We
                // recurse with a carry of 0 to represent [0, 0], and we recurse with a carry of
                // 1 to represent [1, 1].
                result = answer(sum, xor, 0, index + 1, cache) + answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 1 && sumLSB == 0) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 0 and the incoming carry is 1, both [0, 1] and [1, 0] are valid. We
                // recurse with a carry of 1 to represent [0, 1], and we recurse with a carry
                // of 1 to represent [1, 0].
                result = 2 * answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 1 && sumLSB == 1) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 1 and the incoming carry is 1, neither [0, 1] nor [1, 0] is valid.
                result = 0;
            }
        }

        cache[index][carry] = result;

        return result;
    }
}
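
The same carry-state idea can also be written without recursion. The following is a minimal sketch and not part of the original answer: it assumes both inputs are non-negative and below 2^31, and the names `CarryDp` and `count` are made up for illustration. After each bit position it keeps the number of ways that leave a carry of 0 and the number that leave a carry of 1.

```java
final class CarryDp {
    // Counts ordered pairs (a, b) of non-negative ints with a + b == sum and (a ^ b) == xor.
    // Assumes 0 <= sum, xor < 2^31, as in the problem statement.
    static long count(int sum, int xor) {
        long[] ways = {1, 0}; // ways[c]: bit assignments so far that leave a carry of c
        for (int bit = 0; bit < 31; bit++) {
            int sumBit = (sum >> bit) & 1;
            int xorBit = (xor >> bit) & 1;
            long[] next = {0, 0};
            for (int carry = 0; carry <= 1; carry++) {
                if (xorBit == 1) {
                    // Digits are [0, 1] or [1, 0]: the column total is 1 + carry.
                    int total = 1 + carry;
                    if ((total & 1) == sumBit) {
                        next[total >> 1] += 2 * ways[carry];
                    }
                } else if (carry == sumBit) {
                    // Digits are [0, 0] (carry out 0) or [1, 1] (carry out 1);
                    // both produce a sum digit equal to the incoming carry.
                    next[0] += ways[carry];
                    next[1] += ways[carry];
                }
            }
            ways = next;
        }
        return ways[0]; // a carry out of the top bit cannot match a sum below 2^31
    }
}
```

For the example from the question, `CarryDp.count(10, 4)` evaluates to 2 and `CarryDp.count(5, 3)` to 0. The return type is `long` only to keep the intermediate doubling comfortably away from overflow.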
John Kurlak

Google says it's taking too long because your algorithm runs in O(n^2), and Google wants it in O(lg n). If you ask me, this one was way too difficult for a level 3 challenge. I've had easier level 4's. The solution to this is nothing like you'd expect. In fact, you'll never even set values to (a, b) nor compare (a, b) to (S, x) in the correct answer. It's counter to logic until you see and understand the solution.

Anyways, it helps to plot the correct answers in a 2D graph or in an Excel spreadsheet using S for rows and x for columns (leaving zeroes blank). Then, look for patterns. The data points actually form a Sierpinski triangle (see http://en.wikipedia.org/wiki/Sierpinski_triangle).

You'll also notice that every data point (greater than zero) in a column is the same for all instances in that column, so given your x value, you automatically know what the final answer should be so long as the row corresponding to your S value intersects a data point in the triangle. You just need to determine if the S value (row) intersects with the triangle at column x. Make sense?

Even the values in the columns form a pattern from 0 to x: 1, 2, 2, 4, 2, 4, 4, 8, 2, 4, 4, 8, 4, 8, 8, 16... I'm sure you can figure it out.
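
For readers following along in Java, the listed column values appear to be 2 raised to the number of 1 bits in x. The sketch below is an illustration under that assumption, with hypothetical names (`ColumnPattern`, `columnValue`, `firstValues`); it relies only on the standard `Integer.bitCount` method.

```java
final class ColumnPattern {
    // Value shared by every non-zero entry in column x: 2^(number of 1 bits in x).
    // Assumes 0 <= x <= 2,000,000,000, so at most 30 bits are set and the result fits in an int.
    static int columnValue(int x) {
        return 1 << Integer.bitCount(x);
    }

    // Builds the first n column values (x = 0 .. n-1) as a comma-separated string.
    static String firstValues(int n) {
        StringBuilder sb = new StringBuilder();
        for (int x = 0; x < n; x++) {
            if (x > 0) {
                sb.append(", ");
            }
            sb.append(columnValue(x));
        }
        return sb.toString();
    }
}
```

Calling `ColumnPattern.firstValues(16)` reproduces the sequence quoted above: 1, 2, 2, 4, 2, 4, 4, 8, 2, 4, 4, 8, 4, 8, 8, 16.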

Here is the "final value given x" method along with most of the remaining code (in Python; Java is too wordy and complicated). You just need to write the triangle traversal algorithm (I'm not giving this away, but this is a solid push in the right direction):

import math

def final(x, t):
    if x > 0:
        if x % 2: # x is odd
            return final(x // 2, t * 2)
        else: # x is even
            return final(x // 2, t)
    else:
        return t

def mid(l, r):
    return (l + r) // 2

def sierpinski_traverse(s_mod_xms, x, lo, hi, e, l, r):
    # you can do this in 16 lines of code to end with...
    # (`intersect` is the result of the traversal that is deliberately left out here)
    if intersect:
        # always start with a t-value of 1 when first calling final in case x=0
        return final(x, 1)
    else:
        return 0

def answer(s, x):
    print(final(x, 1))

    if s < 0 or x < 0 or s > 2000000000 or x > 2000000000 or s < x or s % 2 != x % 2:
        return 0
    if x == 0:
        return 1

    x_modulus_size = 2 ** int(math.log(x, 2) + 2)
    s_mod_xms = s % x_modulus_size
    lo_root = x_modulus_size // 4
    hi_root = x_modulus_size // 2
    exp = x_modulus_size // 4    # exponent of 2 (e.g. 2 ** exp)

    return sierpinski_traverse(s_mod_xms, x, lo_root, hi_root, exp, exp, 2 * exp)


if __name__ == '__main__':
    answer(10, 4)
jppodo

Try changing num to a HashSet. You can also clean up your if/else at the end.

e.g.

public static int answer(int s, int x) {
    HashSet<Integer> num = new HashSet<>();
    int a;
    int b;
    int sum;
    int finalans;

    for(int i = 0; i <=s; i++){
        for(int e = 0; e <= s; e++){
            sum = i + e;
            if(sum == s){
                if((i^e) == x){
                    num.add(i);
                    num.add(e);
                }
            }
        }
    }

    finalans = num.size();
    if((finalans%2) == 0){
        return finalans*2;
    } else {
        return finalans;
    }        
}
BillRobertson42
  • This still exceeds the time limit on the Google Foobar challenge. It didn't really make a difference. – rkatakam May 04 '15 at 12:09

Most steps in your algorithm perform too much work:

  • You do a linear scan over all non-negative integers up to s. Since the problem is symmetric, scanning up to s/2 is enough.
  • You do a second linear scan to find for every a another integer b satisfying a + b = s. Simple algebra shows that there is just one such b, which is s - a, so no linear scan is necessary at all.
  • You do a third linear scan to check if you've already found a pair (a, b). If you loop to s/2 only, it will always hold that a ≤ b, and thus you won't suffer from double counting.

Finally, I can think of one simple optimization to save some work:

  • If s is even, then either a and b are both even, or both odd. Therefore, a ^ b is even in that case.
  • If s is odd, either a or b is odd, and thus a ^ b is odd.

You can add that check before performing any work:

public static int answer(int s, int x) {
    int result = 0;
    if (s % 2 == x % 2) {
        for (int a = 0; a <= s / 2; a++) {
            int b = s - a;
            if ((a ^ b) == x) {
                result += 2;
            }
        }
        // we might have double counted the pair (s/2, s/2)
        // decrement the count if needed
        if (s % 2 == 0 && ((s / 2) ^ (s / 2)) == x) {
            result--;
        }
    }
    return result;
}
Vincent van der Weele
  • It still takes much too long to compile and also it doesn't solve it in some cases. I had to modify this to make program not get an error by the XOR operand, then it worked, it just still took a long time. – rkatakam May 04 '15 at 12:02
  • Those were the exact modifications I made, but your algorithm seems to run quick enough in my IDE, but in the challenge, it gives the prompt that the "time exceeded" – rkatakam May 04 '15 at 13:16

To explain my previous answer further, look at the big picture...literally. The triangle traversal algorithm works like a binary search, except with three options instead of two ("trinary" search?). Look at the 3 biggest triangles inside the largest triangle that encompass both S and x. Then, select the triangle of those three that contains S and x. Then, look at the three biggest triangles in the newly selected triangle and pick the one that contains S and x. Repeat until you've arrived at a single point. If that point isn't zero, return the "final" value I specified. There are some if-else statements that will also speed this up in the event that you select a triangle and row S doesn't intersect a data point.
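
The question "does row S hit a data point in column x" can also be answered arithmetically. The sketch below is not the triangle traversal described above; it swaps in the identity a + b = (a ^ b) + 2 * (a & b), so that (S - x) / 2 must be a non-negative integer sharing no bits with x. The names `TriangleTest`, `hasSolution` and `countPairs` are hypothetical, and inputs are assumed to be at most 2 billion as the problem states.

```java
final class TriangleTest {
    // True when some pair (a, b) of non-negative ints has a + b == s and (a ^ b) == x.
    static boolean hasSolution(int s, int x) {
        if (s < 0 || x < 0 || s < x) {
            return false;
        }
        int diff = s - x;               // equals 2 * (a & b) for any valid pair
        if ((diff & 1) != 0) {
            return false;               // s and x must have the same parity
        }
        int shared = diff >> 1;         // the bits set in both a and b
        return (shared & x) == 0;       // a bit cannot be both shared and differing
    }

    // Number of ordered pairs: each 1 bit of x can go to either a or b.
    // Assumes x <= 2,000,000,000, so the shift cannot overflow an int.
    static int countPairs(int s, int x) {
        return hasSolution(s, x) ? 1 << Integer.bitCount(x) : 0;
    }
}
```

With the examples from the question, `TriangleTest.countPairs(10, 4)` gives 2, `countPairs(5, 3)` gives 0, and `countPairs(0, 0)` gives 1.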

jppodo