
I'm having trouble understanding why the following recursive code (function count()) gives the wrong number of calculations, while the version based on manually written nested for loops (function count2()) gives the right one, which is n! * 4 ^ (n-1).

(Never mind the output variable at this point. I'll use it later, if I can solve this puzzle first.)

I want to write a recursive function that generates the calculations for a list of arbitrary length, which is why hard-coding nested for loops is not good enough.

import itertools
import operator
# http://stackoverflow.com/questions/2983139/assign-operator-to-variable-in-python
ops = {
    0: operator.add,
    1: operator.sub,
    2: operator.mul,
    3: operator.truediv
}
comb = [4, 1, 2, 3]
perms = list()
# itertools.permutations is not subscriptable, so this is a mandatory step.
# See e.g. http://stackoverflow.com/questions/216972/in-python-what-does-it-mean-if-an-object-is-subscriptable-or-not
# for details.
for i in itertools.permutations(comb):
    perms.append(i)
output = list()
output2 = list()

# In theory, there are n! * 4 ^ (n-1) possibilities for each set.
# In practice, however, some of these are redundant, because multiplication and
# addition are indifferent to calculation order. That's not tested here;
# nor is the possibility of division by zero.

# Variable debug is there just to enable checking the calculation count;
# it serves no other purpose.
debug = list()
debug2 = list()

def count(i):
    for j in range(len(i)):
        for op in ops:
            if j+1 < len(i):
                res = ops[op](i[j], i[j+1])
                if j+2 < len(i):
                    ls = list(i[j+1:])
                    ls[0] = res
                    count(ls)
                else:
                    debug.append([len(i), i[j], ops[op], i[j+1], res])
                    if res == 10: output.append(res)

def count2(i):
    for j in range(len(i)):
        for op in ops:
            if j+1 < len(i):
                res = ops[op](i[j], i[j+1])
                for op2 in ops:
                    if j+2 < len(i):
                        res2 = ops[op2](res, i[j+2])
                        for op3 in ops:
                            if j+3 < len(i):
                                res3 = ops[op3](res2, i[j+3])
                                debug2.append(res3)
                                if res3 == 10: output2.append(res3)

for i in perms:
    count(i)
    count2(i)
print(len(debug)) # The result is 2400, which is wrong.
print(len(debug2)) # The result is 1536, which is correct.
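For reference, the expected total from the formula n! * 4 ^ (n-1) can be evaluated directly. This is a small sketch; `expected_count` is a hypothetical helper name, and note that in Python `^` is bitwise XOR, so the power has to be written with `**`:

```python
from math import factorial

def expected_count(n):
    # Hypothetical helper: n! orderings of the numbers, times
    # 4 operator choices for each of the n - 1 gaps between them.
    return factorial(n) * 4 ** (n - 1)

print(expected_count(4))  # 24 * 64 = 1536
```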

1 Answer


You are appending too many results in your recursive function. Let me elaborate using your example, considering only the original permutation. The call

count([4, 1, 2, 3])

should result in the following recursive calls:

count([5, 2, 3])  # +
count([3, 2, 3])  # -
count([4, 2, 3])  # *
count([4.0, 2, 3])  # /

And it should not append any results yet, right? However, since you loop over every index in your top-level call, it additionally makes all of the following recursive calls:

count([3, 3])   # these calls are premature
count([-1, 3])  # since they reduce 1, 2
count([2, 3])   # and do not consider
count([0.5, 3]) # the first element 4

and it also directly appends the results of all four [2, 3] calculations, which is equally premature!
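To see where the 2400 comes from, the number of `debug.append` calls made by the original `count()` for a list of length n can be counted with the same loop structure. This is a sketch that mirrors those loops rather than running them; `buggy_append_count` is a hypothetical name:

```python
from math import factorial

def buggy_append_count(n):
    """Hypothetical helper: how many times the question's count()
    calls debug.append for a list of length n."""
    total = 0
    for j in range(n - 1):       # every adjacent pair (j, j+1) gets reduced
        if j + 2 < n:
            # recursive call on [res] + i[j+2:], a list of length n - j - 1
            total += 4 * buggy_append_count(n - j - 1)
        else:
            # last pair: each of the 4 results is appended immediately
            total += 4
    return total

print(buggy_append_count(4))                 # 100
print(factorial(4) * buggy_append_count(4))  # 2400
```

For length 4 this gives 4 * 20 (reducing 4, 1) + 4 * 4 (reducing 1, 2) + 4 (reducing 2, 3) = 100 per permutation, and 24 permutations give 2400 instead of 24 * 64 = 1536.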

What you really want your recursive function to do is reduce only the first pair of elements (not EVERY pair of adjacent elements!) and then recursively count the calculations for each resulting list. Hence, your function can be simplified to:

def count(lst):
    # no mo' looping through the list. That's why we do recursion!
    for op in ops.values():
        if len(lst) > 1:
            res = op(*lst[:2])  # reduce first pair
            if len(lst) > 2: 
                ls_short = [res] + list(lst[2:])
                count(ls_short)
            else:
                debug.append(res)

...
> print(len(debug)) 
1536
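If you later want to reuse this for lists of arbitrary length without relying on a global list, one option is to have the function return its results instead of appending to `debug`. This is a sketch under the same left-to-right evaluation as above; `left_to_right_results` and `OPS` are hypothetical names, and, as in the question, division by zero is not handled (it assumes the input contains no zeros):

```python
import itertools
import operator

# Hypothetical constant: the four operators from the question's ops dict.
OPS = (operator.add, operator.sub, operator.mul, operator.truediv)

def left_to_right_results(lst):
    """Hypothetical helper: return every value obtained by always reducing
    the first pair with each operator in OPS, then recursing on the rest.
    Assumes no element is 0 (no ZeroDivisionError handling)."""
    if len(lst) == 1:
        return [lst[0]]
    results = []
    for op in OPS:
        reduced = [op(lst[0], lst[1])] + list(lst[2:])
        results.extend(left_to_right_results(reduced))
    return results

all_results = []
for perm in itertools.permutations([4, 1, 2, 3]):
    all_results.extend(left_to_right_results(perm))

print(len(all_results))  # 1536
```

Each list of length n yields 4 ** (n - 1) values, so summing over all n! permutations reproduces the n! * 4 ^ (n-1) count.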
user2390182