I am working on a DFS solution to the "Count Paths for a Sum" problem. The problem statement is:

Given a binary tree and a number ‘S’, find all paths in the tree such that the sum of all the node values of each path equals ‘S’. Note that a path can start or end at any node, but every path must follow the direction from parent to child (top to bottom).

My approach is as follows:

def all_sum_path(root, target):
    global count
    count = 0
    find_sum_path(root, target, [])
    return count

def find_sum_path(root, target, allPath):
    global count
    if not root:
        return 0
    # add a space for current node
    allPath.append(0)
    # add current node values to all path
    allPath = [i+root.value for i in allPath]
    print(allPath)
    # check if current path == target
    for j in allPath:
        if j == target:
            count += 1
    # recursive
    find_sum_path(root.left, target, allPath)
    find_sum_path(root.right, target, allPath)
    # remove the current path
    print('after', allPath)
    allPath.pop()
    print('after pop', allPath)

class TreeNode():
    def __init__(self, _value):
        self.value = _value
        self.left, self.right, self.next = None, None, None

def main():
    root = TreeNode(12)
    root.left = TreeNode(7)
    root.right = TreeNode(1)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(10)
    root.right.right = TreeNode(5)

    print(all_sum_path(root, 11))

main()

which prints:

[1]
[8, 7]
[14, 13, 6]
after [14, 13, 6]
after pop [14, 13]
[13, 12, 5, 5]
after [13, 12, 5, 5]
after pop [13, 12, 5]
after [8, 7, 0, 0]
after pop [8, 7, 0]
[10, 9, 9]
[12, 11, 11, 2]
after [12, 11, 11, 2]
after pop [12, 11, 11]
[13, 12, 12, 3, 3]
after [13, 12, 12, 3, 3]
after pop [13, 12, 12, 3]
after [10, 9, 9, 0, 0]
after pop [10, 9, 9, 0]
after [1, 0, 0]
after pop [1, 0]
4

I think the issue is that I didn't manage to delete the rightmost element of the list. So I updated my code as follows: I delete the rightmost element of allPath and create a new list named newAllPath to hold the path sums that already include the current node's value.

def all_sum_path(root, target):
    global count
    count = 0
    find_sum_path(root, target, [])
    return count

def find_sum_path(root, target, allPath):
    global count
    if not root:
        return 0
    # add a space for current node
    allPath.append(0)
    # add current node values to all path
    newAllPath = [i+root.value for i in allPath]
    print(allPath, newAllPath)
    # check if current path == target
    for j in newAllPath:
        if j == target:
            count += 1
    # recursive
    find_sum_path(root.left, target, newAllPath)
    find_sum_path(root.right, target, newAllPath)
    # remove the current path
    print('after', allPath, newAllPath)
    allPath.pop()
    print('after pop', allPath, newAllPath)

class TreeNode():
    def __init__(self, _value):
        self.value = _value
        self.left, self.right, self.next = None, None, None

def main():
    root = TreeNode(1)
    root.left = TreeNode(7)
    root.right = TreeNode(9)
    root.left.left = TreeNode(6)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(2)
    root.right.right = TreeNode(3)

    print(all_sum_path(root, 12))

    root = TreeNode(12)
    root.left = TreeNode(7)
    root.right = TreeNode(1)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(10)
    root.right.right = TreeNode(5)

    print(all_sum_path(root, 11))

main()

which prints:

[0] [1]
[1, 0] [8, 7]
[8, 7, 0] [14, 13, 6]
after [8, 7, 0] [14, 13, 6]
after pop [8, 7] [14, 13, 6]
[8, 7, 0] [13, 12, 5]
after [8, 7, 0] [13, 12, 5]
after pop [8, 7] [13, 12, 5]
after [1, 0] [8, 7]
after pop [1] [8, 7]
[1, 0] [10, 9]
[10, 9, 0] [12, 11, 2]
after [10, 9, 0] [12, 11, 2]
after pop [10, 9] [12, 11, 2]
[10, 9, 0] [13, 12, 3]
after [10, 9, 0] [13, 12, 3]
after pop [10, 9] [13, 12, 3]
after [1, 0] [10, 9]
after pop [1] [10, 9]
after [0] [1]
after pop [] [1]
3

I am not sure why I can't delete the rightmost element in my first approach. However, in my second approach, once I remove the rightmost element of allPath, it also removes the element from newAllPath.
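
A minimal sketch of the name-binding behavior behind the first approach (mutate_then_rebind is a hypothetical helper written only for illustration): append() changes the list object the caller passed in, but allPath = [...] merely rebinds the local name to a brand-new list, so the later pop() removes from that new list and the caller's list keeps the extra 0. In the second approach the function never rebinds allPath, so pop() undoes its own append(0) on the list it received, which is the parent's newAllPath.

```python
def mutate_then_rebind(path, value):
    # Mirrors the first approach: append() mutates the caller's list object
    path.append(0)
    # This rebinds the local name to a NEW list; the caller's list is no longer touched
    path = [x + value for x in path]
    # pop() removes from the new local list, so the 0 appended above stays in the caller's list
    path.pop()
    return path


caller = [1, 2]
local = mutate_then_rebind(caller, 10)
print(caller)  # [1, 2, 0]: the appended 0 leaked into the caller's list
print(local)   # [11, 12]
```

This is the same pattern that leaves the trailing 0 entries visible in the "after pop" lines of the first output.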

Thanks for your help. I am very confused and have been stuck on this for the whole day.

Qiang Super

2 Answers

Unless your tree is more than 1000 nodes deep, you could use recursion to get simpler code:

def findSums(node,target):
    if not node : return
    yield from pathsFrom(node,target)                     # paths starting at this node
    for child in (node.left,node.right):                  # traverse tree DFS
        yield from findSums(child,target)                 # paths starting below this node

def pathsFrom(node,target):
    if not node : return
    if node.value == target: yield [node.value]           # target reached, return path
    for child in (node.left,node.right):
        for subPath in pathsFrom(child,target-node.value): # sub-path must start at child
            yield [node.value]+subPath                    # value + sub-path

for sp in findSums(root,11):
    print(sp)

# [7, 4]
# [1, 10]

To print your binary tree, see this: https://stackoverflow.com/a/49844237/5237560
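
If the tree can be deeper than CPython's default recursion limit (sys.getrecursionlimit() returns 1000 by default), the same search can be done with an explicit stack. This is a minimal sketch, assuming nodes with value, left and right attributes like the question's TreeNode (the constructor here also accepts children for brevity); find_sums_iterative is a hypothetical name, and unlike the generator above it returns a list.

```python
class TreeNode:
    # Same fields as the question's TreeNode, with optional children for brevity
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def find_sums_iterative(root, target):
    """Return every downward path summing to target, using an explicit stack."""
    results = []
    # Each entry holds a node and the values on the path from the root to it
    stack = [(root, [root.value])] if root else []
    while stack:
        node, path = stack.pop()
        # Every suffix of the root-to-node path is a downward path ending at node
        total = 0
        for start in range(len(path) - 1, -1, -1):
            total += path[start]
            if total == target:
                results.append(path[start:])
        # Push right first so the left subtree is explored first (DFS order)
        for child in (node.right, node.left):
            if child:
                stack.append((child, path + [child.value]))
    return results


root = TreeNode(12, TreeNode(7, TreeNode(4)), TreeNode(1, TreeNode(10), TreeNode(5)))
print(find_sums_iterative(root, 11))  # [[7, 4], [1, 10]]
```

Each node copies its root-to-node path, so this sketch trades extra memory for simplicity; it never touches the interpreter's call stack, however deep the tree is.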

Alain T.

functional principles

This is a non-trivial problem, and I'm not going to attempt to debug your program because it goes against the way recursion is intended to be used. To reiterate, recursion comes from the functional tradition, so using it in a functional style yields the best results. This means avoiding -

  • mutations like .append, += 1, .pop
  • reassignments like left = ..., right = ..., allPath = ...
  • globals like count
  • other side effects like print
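
As a sketch of what those principles look like when applied to the question's own running-sums idea (count_paths is a hypothetical name; it assumes nodes with value, left and right attributes): each call builds a new tuple under a new name instead of appending, rebinding and popping, and the count is returned rather than accumulated in a global.

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def count_paths(node, target, sums=()):
    """Count downward paths summing to target: no globals, mutation, reassignment, or print.

    sums holds, for each ancestor of node, the total of the path from that
    ancestor down to node's parent.
    """
    if node is None:
        return 0
    # Build a new tuple instead of appending to (and later popping from) a shared list
    here = tuple(s + node.value for s in sums) + (node.value,)
    # Matches ending at this node, plus matches found in each subtree
    return (here.count(target)
            + count_paths(node.left, target, here)
            + count_paths(node.right, target, here))


root = TreeNode(1, TreeNode(7, TreeNode(6), TreeNode(5)),
                TreeNode(9, TreeNode(2), TreeNode(3)))
print(count_paths(root, 12))  # 3
```

Because nothing is shared between calls, the left and right subtrees see exactly the same sums, and there is no cleanup step to get wrong.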

decomposition

It's a bad idea to try to wrap all the concerns of your task into a single function. There are numerous benefits to breaking the problem down into separate parts -

  • smaller functions are easier to read, write, test, and debug
  • single-purpose functions are easier to reuse

To start, we'll reuse the find_sum we already wrote in your previous Q&A, which finds every path that starts at t and sums to q -

def find_sum(t, q, path = []):
  if not t:
    return
  elif t.value == q:
    yield [*path, t.value]
  else:
    yield from find_sum(t.left, q - t.value, [*path, t.value])
    yield from find_sum(t.right, q - t.value, [*path, t.value])

Using find_sum, we can easily write all_sum by trying every node as a starting point -

def all_sum(t, q):
  for n in traverse(t):
    yield from find_sum(n, q)

Which requires us to write a generic traverse -

def traverse(t):
  if not t:
    return
  else:
    yield from traverse(t.left)
    yield t
    yield from traverse(t.right)

Let's see it work on a sample tree -

               12
             /    \
            /      \
           7        1
          / \      / \
         4   3    10  5
            /    /
           1    1

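Which we represent using a variant of your TreeNode whose constructor also accepts the left and right children -

class TreeNode:
  def __init__(self, value, left = None, right = None):
    self.value, self.left, self.right = value, left, right

t1 = TreeNode \
  ( 12
  , TreeNode(7, TreeNode(4), TreeNode(3, TreeNode(1)))
  , TreeNode(1, TreeNode(10, TreeNode(1)), TreeNode(5))
  )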

print(list(all_sum(t1, 11)))
# [[7, 4], [7, 3, 1], [10, 1], [1, 10]]

counting all sums

If the only goal is to count the matching paths, we can write count_all_sum as a simple wrapper around all_sum -

def count_all_sum(t, q):
  return len(list(all_sum(t, q)))
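
For reference, a sketch that puts the pieces of this answer into one runnable file. It assumes a TreeNode variant whose constructor accepts optional left and right children (the question's TreeNode only takes a value), uses an empty-tuple default for path to sidestep the mutable-default pitfall, and counts with sum(1 for _ in ...) instead of len(list(...)) so the paths are never collected into a list.

```python
class TreeNode:
    # Variant of the question's TreeNode whose constructor accepts children
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def find_sum(t, q, path=()):
    # Paths that start at t and sum to q; path is copied, never mutated
    if not t:
        return
    elif t.value == q:
        yield [*path, t.value]
    else:
        yield from find_sum(t.left, q - t.value, [*path, t.value])
        yield from find_sum(t.right, q - t.value, [*path, t.value])


def traverse(t):
    # In-order traversal yielding every node
    if not t:
        return
    yield from traverse(t.left)
    yield t
    yield from traverse(t.right)


def all_sum(t, q):
    # Try every node as a starting point
    for n in traverse(t):
        yield from find_sum(n, q)


def count_all_sum(t, q):
    # Count lazily instead of building the list of paths
    return sum(1 for _ in all_sum(t, q))


t1 = TreeNode(12,
              TreeNode(7, TreeNode(4), TreeNode(3, TreeNode(1))),
              TreeNode(1, TreeNode(10, TreeNode(1)), TreeNode(5)))

print(list(all_sum(t1, 11)))  # [[7, 4], [7, 3, 1], [10, 1], [1, 10]]
print(count_all_sum(t1, 11))  # 4
```
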

Mulan

Thank you for your help. I will try my best to apply your algorithm to this DFS problem. I come from a statistics background and am trying to become a DS at an IT company. – Qiang Super Jan 13 '21 at 03:07