
I am working on the above LeetCode problem in Python. I usually solve a problem in a Jupyter notebook and then paste the finished code into the LeetCode solution box, but I am having trouble with this one.

The problem is defined below:

Given the root of a Binary Search Tree (BST), convert it to a Greater Tree such that every key of the original BST is changed to the original key plus sum of all keys greater than the original key in BST.

As a reminder, a binary search tree is a tree that satisfies these constraints:

  • The left subtree of a node contains only nodes with keys less than the node's key.
  • The right subtree of a node contains only nodes with keys greater than the node's key.
  • Both the left and right subtrees must also be binary search trees.

A sample input and output for the problem are shown below:

Input: root = [4,1,6,0,2,5,7,null,null,null,3,null,null,null,8]
Output: [30,36,21,36,35,26,15,null,null,null,33,null,null,null,8]
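
To see where the sample output comes from, the definition can be checked by brute force on the list of keys alone, ignoring the tree structure. This is only a sanity check under that simplification, not the intended solution, and the names `keys` and `greater_sum` are made up for illustration:

```python
# Keys of the sample tree in level order (None marks a missing child).
data = [4, 1, 6, 0, 2, 5, 7, None, None, None, 3, None, None, None, 8]
keys = [v for v in data if v is not None]

def greater_sum(key, keys):
    """Return key plus the sum of all keys strictly greater than it."""
    return key + sum(k for k in keys if k > key)

# Apply the definition to every position, keeping the None placeholders.
expected = [None if v is None else greater_sum(v, keys) for v in data]
print(expected)
# [30, 36, 21, 36, 35, 26, 15, None, None, None, 33, None, None, None, 8]
```

For example, the root 4 becomes 4 + 5 + 6 + 7 + 8 = 30.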

The solution template is set up as follows:

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def bstToGst(self, root: TreeNode) -> TreeNode:

I am confused about how to approach this problem. Initially I thought I would loop through the list provided. After reading some sample responses in the discussion, however, I see that attributes such as root.right and root.left are used. How do I work with these in a Jupyter notebook? I have no experience with TreeNodes, so I want to solve the problem the right way and learn the fundamental concept instead of brute-forcing it some other way. All help is greatly appreciated.

Thanks

  • Use a [reverse in-order traversal](https://en.wikipedia.org/wiki/Tree_traversal#Reverse_in-order,_RNL) to read the nodes in order from largest to smallest. Maintain a global accumulator initialized to 0. When you read a node, add the accumulator's value to the node, then add the node's original value to the accumulator. – inspectorG4dget Feb 02 '21 at 22:30
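
The approach described in the comment above could be sketched recursively as follows. This is one possible reading of it, not the commenter's own code: it uses a `nonlocal` variable instead of a true global, and the names `to_greater_tree`, `visit`, and `running` are illustrative.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def to_greater_tree(root):
    """Rewrite each node's value in place as value + sum of all greater keys."""
    running = 0  # sum of the original keys visited so far

    def visit(node):
        nonlocal running
        if node is None:
            return
        visit(node.right)    # larger keys first
        running += node.val  # include this node's original value
        node.val = running   # original value + all greater keys
        visit(node.left)     # then smaller keys

    visit(root)
    return root

# Tiny example: 4 with children 1 and 6.
small = TreeNode(4, TreeNode(1), TreeNode(6))
to_greater_tree(small)
print(small.val, small.left.val, small.right.val)  # 10 11 6
```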

1 Answer


Define a generator that traverses the nodes in reverse order (largest key first). Then accumulate a running total while iterating over the nodes and assign it to each node:

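class Solution:
    def revNodes(self, node):
        # Yield nodes in reverse in-order (right, node, left): largest key first.
        if node is None:
            return
        yield from self.revNodes(node.right)
        yield node
        yield from self.revNodes(node.left)

    def bstToGst(self, root):
        total = 0  # running sum of the original keys visited so far
        for node in self.revNodes(root):
            total += node.val
            node.val = total
        return root  # the template's signature returns a TreeNode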

Test setup and output:

data  = [4,1,6,0,2,5,7,None,None,None,3,None,None,None,8]
nodes = [v if v is None else TreeNode(v) for v in data]
for i,node in enumerate(nodes):
    if not node: continue
    if 2*i+1<len(nodes): node.left  = nodes[2*i+1]
    if 2*i+2<len(nodes): node.right = nodes[2*i+2]    
root = nodes[0]


print(root) # BEFORE

      4
   __/ \_
  1      6
 / \    / \
0   2  5   7
     \      \
      3      8

Solution().bstToGst(root)

print([node.val if node else None for node in nodes])
[30, 36, 21, 36, 35, 26, 15, None, None, None, 33, None, None, None, 8]    

print(root) # AFTER

        30
     __/  \__
   36        21
  /  \      /  \
36    35  26    15
        \         \
         33        8
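
The setup code above links children by heap index (2*i+1 and 2*i+2). That works for this input because every node on the third level is present, but LeetCode's level-order lists generally omit the children of missing nodes, so the indices can drift. A hedged sketch of a queue-based builder for that more general case is shown below; `build_tree` is a hypothetical helper name, not part of LeetCode's template.

```python
from collections import deque

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def build_tree(values):
    """Build a tree from a LeetCode-style level-order list (None = missing child)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # Next value is the left child, the one after it is the right child.
        if values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root

# A right-leaning chain, where heap indexing would not apply.
chain = build_tree([1, None, 2, None, 3])
print(chain.val, chain.right.val, chain.right.right.val)  # 1 2 3
```

In a notebook, the result of `build_tree(...)` can be passed straight to `Solution().bstToGst(...)`.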

Note that, in order to print the tree, I had to add a __repr__() method to the TreeNode class:

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
        
    def __repr__(self):
        nodeInfo = lambda n:(str(n.val),n.left,n.right)
        return "\n".join(printBTree(self,nodeInfo,isTop=False))

The printBTree function is from another answer I provided in the past here

Alain T.