I am working on finding the root-to-leaf path with the maximum sum. My approach is as follows:

def max_sum(root):
    _max = 0
    find_max(root, _max, 0)
    return _max

def find_max(node, max_sum, current_sum):
    if not node:
        return 0
    current_sum += node.value
    if not node.left and not node.right:
        print(current_sum, max_sum, current_sum > max_sum)
        max_sum = max(max_sum, current_sum)
    if node.left:
        find_max(node.left, max_sum, current_sum)
    if node.right:
        find_max(node.right, max_sum, current_sum)
    current_sum -= node.value

class TreeNode():
    def __init__(self, _value):
        self.value = _value
        self.left, self.right, self.next = None, None, None

def main():
    root = TreeNode(1)
    root.left = TreeNode(7)
    root.right = TreeNode(9)
    root.left.left = TreeNode(4)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(2)
    root.right.right = TreeNode(7)

    print(max_sum(root))

    root = TreeNode(12)
    root.left = TreeNode(7)
    root.right = TreeNode(1)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(10)
    root.right.right = TreeNode(5)

    print(max_sum(root))

main()

with output:

12 0 True
13 0 True
12 0 True
17 0 True
0
23 0 True
23 0 True
18 0 True
0

Process finished with exit code 0

The expected output is 17 and 23.

I would like to understand why my approach fails to update max_sum. The comparison current_sum > max_sum prints True at every leaf, yet max_sum is never updated. Thanks for your help.

Qiang Super
  • I haven't looked at your code in detail, but having a function and another variable with the same name is a bad idea – turtle Jan 11 '21 at 22:28
  • In `max_sum`, `_max` is set to 0 and then never changed. Therefore, 0 is returned. – mkrieger1 Jan 11 '21 at 22:30
  • Thanks, @turtle. I am not sure which variable name conflicts with the function name? – Qiang Super Jan 11 '21 at 22:32
  • Thanks @mkrieger1. Would you mind explaining more? I thought my code was going to update `_max` in the `find_max` function. – Qiang Super Jan 11 '21 at 22:33
  • @QiangSuper: I mean `max_sum`. Also as the other person has pointed out, the issue is that the variable `max_sum` inside `find_max` is only local to that function. Updating its value doesn't change the value of `_max` – turtle Jan 11 '21 at 22:41
  • Thanks, @turtle. Would you mind sharing what I should do to fix this issue, such as making `_max` global? – Qiang Super Jan 11 '21 at 23:14

1 Answer

bugfix

Here's a way we could fix your find_max function -

def find_max(node, current_sum = 0):
  # empty tree
  if not node:
    return current_sum

  # branch
  elif node.left or node.right:
    next_sum = current_sum + node.value
    left = find_max(node.left, next_sum)
    right = find_max(node.right, next_sum)
    return max(left, right)
  
  # leaf
  else:
    return current_sum + node.value

t1 = TreeNode \
  ( 1
  , TreeNode(7, TreeNode(4), TreeNode(5))
  , TreeNode(9, TreeNode(2), TreeNode(7))
  )
  
t2 = TreeNode \
  ( 12
  , TreeNode(7, TreeNode(4), None)
  , TreeNode(1, TreeNode(10), TreeNode(5))
  )

print(find_max(t1))
print(find_max(t2))

17
23
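
Note that the trees above are built with TreeNode(value, left, right), while the TreeNode class in the question only accepts a value. A minimal sketch of a compatible class, assuming optional left and right constructor parameters (an assumption, not part of the question's code; the unused next attribute is left out), could look like this -

```python
class TreeNode:
    # Assumed constructor: `left` and `right` are optional, so existing
    # calls such as TreeNode(7) keep working.
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

def find_max(node, current_sum=0):
    if not node:                          # empty tree
        return current_sum
    elif node.left or node.right:         # branch
        next_sum = current_sum + node.value
        return max(find_max(node.left, next_sum),
                   find_max(node.right, next_sum))
    else:                                 # leaf
        return current_sum + node.value

t1 = TreeNode(1,
              TreeNode(7, TreeNode(4), TreeNode(5)),
              TreeNode(9, TreeNode(2), TreeNode(7)))

print(find_max(t1))  # 17
```

Trees built the way main() does it in the question, one attribute assignment at a time, work with this class too.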

seeing the process

We can visualise the computational process by tracing one of the examples, find_max(t2) -

             12
          /       \
         7         1
        / \       / \
       4   None  10  5

     find_max(12,0)
          /      \
         7        1
        / \      / \
       4  None  10  5

          find_max(12,0)
          /           \
max(find_max(7,12), find_max(1,12))
     / \                / \
    4  None           10   5

                                find_max(12,0)
                           /                         \
         max(find_max(7,12),                          find_max(1,12))
            /              \                          /             \
max(find_max(4,19), find_max(None,19))  max(find_max(10,13), find_max(5,13))

                        find_max(12,0)
                       /              \
     max(find_max(7,12),              find_max(1,12))
      /              \                /             \
 max(23,             19)         max(23,            18)

                        find_max(12,0)
                       /              \
     max(find_max(7,12),              find_max(1,12))
            |                                |
           23                               23

            find_max(12,0)
            /            \
     max(23,              23)

            find_max(12,0)
                 |
                23

23
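
To reproduce this trace ourselves, one option is to record the arguments of every call. The sketch below is a hypothetical instrumented variant: find_max_traced and its calls parameter are names introduced here for illustration, and TreeNode is assumed to accept optional children -

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

def find_max_traced(node, current_sum, calls):
    # Record (node value or None, current_sum) in call order.
    calls.append((node.value if node else None, current_sum))
    if not node:
        return current_sum
    elif node.left or node.right:
        next_sum = current_sum + node.value
        left = find_max_traced(node.left, next_sum, calls)
        right = find_max_traced(node.right, next_sum, calls)
        return max(left, right)
    else:
        return current_sum + node.value

t2 = TreeNode(12,
              TreeNode(7, TreeNode(4), None),
              TreeNode(1, TreeNode(10), TreeNode(5)))

calls = []
result = find_max_traced(t2, 0, calls)
for value, current_sum in calls:
    print(f"find_max({value},{current_sum})")
print(result)
```

Each recorded call receives its own current_sum; nothing is shared or reassigned between calls.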

refinements

However, I think we can improve on this. Just like we did in your previous question, we can use mathematical induction again -

  1. if the input tree t is empty, return the empty result
  2. (inductive) t is not empty. If either sub-problem t.left or t.right is present, add t.value to the accumulated result r and recur on each branch
  3. (inductive) t is not empty and both t.left and t.right are empty, so a leaf node has been reached; add t.value to the accumulated result r and yield the sum

def sum_branch (t, r = 0):
  if not t:
    return                                       # (1)
  elif t.left or t.right:
    yield from sum_branch(t.left, r + t.value)   # (2)
    yield from sum_branch(t.right, r + t.value)
  else:
    yield r + t.value                            # (3)

t1 = TreeNode \
  ( 1
  , TreeNode(7, TreeNode(4), TreeNode(5))
  , TreeNode(9, TreeNode(2), TreeNode(7))
  )
  
t2 = TreeNode \
  ( 12
  , TreeNode(7, TreeNode(4), None)
  , TreeNode(1, TreeNode(10), TreeNode(5))
  )

print(max(sum_branch(t1)))
print(max(sum_branch(t2)))

17
23
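
Because sum_branch is a generator, we can also look at every root-to-leaf sum rather than only the largest. The sketch below assumes a TreeNode with optional children. As an assumption about the desired behaviour, it also passes default=0 to max so that an empty tree gives 0 instead of raising ValueError -

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

def sum_branch(t, r=0):
    if not t:
        return                                      # (1) yields nothing
    elif t.left or t.right:
        yield from sum_branch(t.left, r + t.value)  # (2)
        yield from sum_branch(t.right, r + t.value)
    else:
        yield r + t.value                           # (3)

t2 = TreeNode(12,
              TreeNode(7, TreeNode(4), None),
              TreeNode(1, TreeNode(10), TreeNode(5)))

print(list(sum_branch(t2)))              # [23, 23, 18]
print(max(sum_branch(None), default=0))  # 0

# A missing child yields nothing, so it never competes with real leaves.
t3 = TreeNode(1, TreeNode(-5), None)
print(max(sum_branch(t3)))               # -4
```

The last example shows a difference from find_max in the bugfix section: there a missing child returns current_sum, so the same tree t3 would report 1 even though its only root-to-leaf path sums to -4.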

generics

Perhaps a more interesting way to solve this problem is to write a generic paths function first -

def paths (t, p = []):
  if not t:
    return                                     # (1)
  elif t.left or t.right:
    yield from paths(t.left, [*p, t.value])    # (2)
    yield from paths(t.right, [*p, t.value])
  else:
    yield [*p, t.value]                        # (3)

And then we can solve the max sum problem as a composition of generic functions max, sum, and paths -

print(max(sum(x) for x in paths(t1)))
print(max(sum(x) for x in paths(t2)))

17
23
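
Since paths yields the paths themselves, the same composition can return the winning path instead of just its sum. A small sketch, assuming a TreeNode with optional children; it relies only on the built-in max with key=sum, which returns the first maximal path encountered when sums tie -

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value, self.left, self.right = value, left, right

def paths(t, p=()):
    # A tuple default is used here to avoid a shared mutable default;
    # [*p, t.value] builds a new list either way.
    if not t:
        return
    elif t.left or t.right:
        yield from paths(t.left, [*p, t.value])
        yield from paths(t.right, [*p, t.value])
    else:
        yield [*p, t.value]

t1 = TreeNode(1,
              TreeNode(7, TreeNode(4), TreeNode(5)),
              TreeNode(9, TreeNode(2), TreeNode(7)))

t2 = TreeNode(12,
              TreeNode(7, TreeNode(4), None),
              TreeNode(1, TreeNode(10), TreeNode(5)))

print(max(paths(t1), key=sum))  # [1, 9, 7]
print(max(paths(t2), key=sum))  # [12, 7, 4]
```

In t2 two paths sum to 23, and max keeps the first one it sees, which is the left-most path.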
Mulan
  • Thanks for your help. Would you mind taking a look at my code? My question is why `max_sum` isn't updated by the statement `max_sum = max(max_sum, current_sum)`. – Qiang Super Jan 12 '21 at 14:35
  • The issue is with `find_max(node.left, max_sum, current_sum)` and `find_max(node.right, max_sum, current_sum)`, where we don't use the return value of these calls. Also, `max_sum` is an int, so reassigning it inside the function only rebinds the local name and the caller never sees the change. In your other question this technique worked because a list like `[]` can be mutated in place, and that change is visible to the caller. I updated the top of my answer with an edit of `find_max`. Does that help? – Mulan Jan 12 '21 at 15:03
  • Thanks for your help. I think I am pretty close to the final answer, but I am still a little confused. Suppose we are at a leaf of the balanced tree, and at this point `max_sum` is still 0 since it hasn't been updated. Since `max_sum` only passes its value to the function, will all these leaf nodes compare their `current_sum` with `max_sum = 0`? If so, how come the returned `max_sum` is still 0? – Qiang Super Jan 12 '21 at 15:17
  • I also have another question. In your approach, you return `max(left, right)`. So the logic of the recursion will be something like this: at the leaf, find out the `final_sum`, then back to the internal nodes, compare between the children nodes with the same parent node. Eventually, compare these parent nodes at the root, which is our result. – Qiang Super Jan 12 '21 at 15:20
  • Thank you for your time and patience. – Qiang Super Jan 12 '21 at 15:20
  • Qiang, you are very welcome. I added a section, _Seeing The Process_, that should help you trace the input to the output. A new `current_sum` is passed to each recursive call. And a final note about `max(left,right)`: we use Python's built-in `max` function here, but we could easily rewrite this as `return left if left > right else right`. – Mulan Jan 12 '21 at 18:46
  • Thank you for your explanation. Your algorithm is pretty clear to me now. I still have a conceptual question about my algorithm. Suppose we are at a leaf of the balanced tree, and at this point max_sum is still 0 since it hasn't been updated. Since max_sum only passes its value to the function, will all these leaf nodes compare their current_sum with max_sum = 0? If so, how come the returned max_sum is still 0? – Qiang Super Jan 12 '21 at 18:55
  • In your original code, `max_sum = ...` is only updated on **leaf** nodes (under the conditional `if not node.left and not node.right`). At this point, it is too late to update the `max_sum` that was passed to all the other recursive calls on earlier **branch** nodes, e.g. `if node.left: ...` and `if node.right: ...`, which all received `max_sum = 0`. This is why I highly recommend that you do not use mutation or variable reassignment with recursion, because it is the source of many confusing aspects. – Mulan Jan 12 '21 at 19:07
  • Thank you for your help. A good lesson to learn from you. – Qiang Super Jan 12 '21 at 19:41