
I am trying to solve the LeetCode problem Maximum Depth of Binary Tree.

The problem is given as an exercise in the LeetCode tutorial on tail recursion (tail recursion - LeetCode).

Given a binary tree, find its maximum depth.

The maximum depth is the number of nodes along the longest path from the root node down to the farthest leaf node.

Note: A leaf is a node with no children.

Example:

Given binary tree [3,9,20,null,null,15,7],

    3
   / \
  9  20
    /  \
   15   7

return its depth = 3.
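LeetCode's list notation lists nodes level by level, with null marking a missing child. The sketch below shows how [3,9,20,null,null,15,7] maps onto the drawing above. It uses a hypothetical `build_tree` helper that is not part of the problem statement, and it assumes a `TreeNode` class with `val`, `left` and `right` attributes like the one LeetCode supplies. `None` stands in for null.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal stand-in for LeetCode's TreeNode (assumed shape: val, left, right)."""
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list (None = null)."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    pending = deque([root])  # nodes still waiting for their children
    i = 1
    while pending and i < len(values):
        parent = pending.popleft()
        # the next two entries in the list are this parent's left and right children
        if i < len(values) and values[i] is not None:
            parent.left = TreeNode(values[i])
            pending.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            pending.append(parent.right)
        i += 1
    return root


root = build_tree([3, 9, 20, None, None, 15, 7])
```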

A standard solution, which follows directly from the definition of a tree's levels:

class Solution:
    def maxDepth(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """ 
        if root is None: 
            return 0 
        else: 
            left_height = self.maxDepth(root.left) 
            right_height = self.maxDepth(root.right) 
            return max(left_height, right_height) + 1 

However, it is not tail recursive.

Tail recursion is recursion in which the recursive call is the final instruction in the function, and there should be only one recursive call in the function.
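Here is a concrete illustration of that definition using two illustrative functions. In the first one the recursive call is not the last step, because its result still has to be multiplied by `n`. In the second one the running product is carried in an extra accumulator parameter, so the recursive call is the final step.

```python
def factorial(n: int) -> int:
    """Not tail recursive: the multiplication runs after the recursive call returns."""
    if n == 0:
        return 1
    return n * factorial(n - 1)


def factorial_tail(n: int, acc: int = 1) -> int:
    """Tail recursive: the recursive call is the last thing the function does."""
    if n == 0:
        return acc
    return factorial_tail(n - 1, acc * n)
```

`maxDepth` above is harder to convert than `factorial`. It has two recursive calls, and `max(...) + 1` still has to run after both of them return.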

I read all other submissions and the discussions but did not find a tail recursive solution.

How could I solve the problem using tail recursion?

LeCodex
Alice

4 Answers


any recursive program can be made stack-safe

I have written a lot on the topic of recursion and I am sad when people misstate the facts. And no, this does not rely on daft techniques like sys.setrecursionlimit().

Calling a function in Python adds a stack frame. So instead of writing f(x) to call a function, we will write call(f, x). Now we have complete control over the evaluation strategy -

# btree.py

def depth(t):
  if not t:
    return 0
  else:
    return call \
      ( lambda left_height, right_height: 1 + max(left_height, right_height)
      , call(depth, t.left)
      , call(depth, t.right)
      )

It's effectively the exact same program. So what is call?

# tailrec.py

class call:
  def __init__(self, f, *v):
    self.f = f
    self.v = v

So call is a simple object with two properties: the function to call, f, and the values to call it with, v. That means depth now returns a call object instead of the number we need. Only one more adjustment is needed -

# btree.py

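from tailrec import loop, call

class node:                          # <- assumed minimal tree node (not shown originally)
  def __init__(self, data, left = None, right = None):
    self.data = data
    self.left = left
    self.right = right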

def depth(t):
  def aux(t):                        # <- auxiliary wrapper
    if not t:
      return 0
    else:
      return call \
        ( lambda l, r: 1 + max(l, r)
        , call(aux, t.left)
        , call(aux, t.right)
        )
  return loop(aux(t))                # <- call loop on result of aux

loop

Now all we need to do is write a sufficiently adept loop to evaluate our call expressions. The answer here is a direct translation of the evaluator I wrote in this Q&A (JavaScript). I won't repeat myself here, so if you want to understand how it works, I explain it step-by-step as we build loop in that post -

# tailrec.py

from functools import reduce

def identity(x):                     # default continuation: return the value unchanged
  return x

def loop(t, k = identity):
  def one(t, k):
    if isinstance(t, call):
      return call(many, t.v, lambda r: call(one, t.f(*r), k))
    else:
      return call(k, t)
  def many(ts, k):
    return call \
      ( reduce \
          ( lambda mr, e:
              lambda k: call(mr, lambda r: call(one, e, lambda v: call(k, [*r, v])))
          , ts
          , lambda k: call(k, [])
          )
      , k
      )
  return run(one(t, k))

Notice a pattern? loop is recursive in the same way depth is, but here too we recur using call expressions. loop then hands its output to run, where unmistakable iteration happens -

# tailrec.py

def run(t):
  while isinstance(t, call):
    t = t.f(*t.v)
  return t
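When the recursive call is already in tail position, `run` on its own is enough. The function returns a `call` that describes its next step instead of making that step, and `run` keeps applying the steps in a `while` loop. This minimal sketch repeats the `call` and `run` definitions so it can run on its own. `sum_to` is a hypothetical example function, not part of the answer.

```python
class call:
    """A deferred function application: f applied to the values v."""
    def __init__(self, f, *v):
        self.f = f
        self.v = v


def run(t):
    """Keep evaluating call objects until a plain value comes back."""
    while isinstance(t, call):
        t = t.f(*t.v)
    return t


def sum_to(n, acc=0):
    """Tail-recursive sum of 1..n that returns its next step instead of recursing."""
    if n == 0:
        return acc
    return call(sum_to, n - 1, acc + n)


print(run(sum_to(100_000)))  # 5000050000
```

`depth` needs the extra `loop` machinery because its recursive calls are not in tail position. The results of both subtree calls still feed into `1 + max(l, r)`.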

check your work

from btree import node, depth

#   3
#  / \
# 9  20
#   /  \
#  15   7

t = node(3, node(9), node(20, node(15), node(7)))

print(depth(t))
3

stack vs heap

You are no longer limited by Python's default recursion limit of 1000. We effectively hijacked Python's evaluation strategy and wrote our own replacement, loop. Instead of pushing function call frames onto the stack, we trade them for continuations on the heap. Now the only limit is your computer's memory.
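To check the claim, the sketch below collects the pieces above into one file. It runs `depth` on a hypothetical, deliberately degenerate tree: a left-leaning chain of 10,000 nodes, well past CPython's default limit of 1000 (see `sys.getrecursionlimit()`). The `node` class is an assumed minimal definition, because the answer does not show one.

```python
from functools import reduce


class call:
    """A deferred function application: f applied to the values v."""
    def __init__(self, f, *v):
        self.f = f
        self.v = v


def identity(x):
    return x


def run(t):
    """Trampoline: evaluate call objects iteratively instead of on the Python stack."""
    while isinstance(t, call):
        t = t.f(*t.v)
    return t


def loop(t, k=identity):
    """Evaluate a nested call expression in continuation-passing style."""
    def one(t, k):
        if isinstance(t, call):
            # evaluate the arguments first, then apply t.f to their results
            return call(many, t.v, lambda r: call(one, t.f(*r), k))
        return call(k, t)

    def many(ts, k):
        # evaluate each expression in ts in order, collecting the results in a list
        return call(
            reduce(
                lambda mr, e: lambda k: call(
                    mr, lambda r: call(one, e, lambda v: call(k, [*r, v]))
                ),
                ts,
                lambda k: call(k, []),
            ),
            k,
        )

    return run(one(t, k))


class node:
    """Assumed minimal binary tree node."""
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def depth(t):
    def aux(t):
        if not t:
            return 0
        return call(lambda l, r: 1 + max(l, r), call(aux, t.left), call(aux, t.right))
    return loop(aux(t))


# hypothetical degenerate "tree": a left-leaning chain of 10,000 nodes
deep = None
for i in range(10_000):
    deep = node(i, deep)

print(depth(deep))  # 10000
```

Passing the same chain to the plain recursive `maxDepth` from the question would raise `RecursionError` under the default limit.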

Mulan

Every recursive algorithm can be turned into a tail-recursive one. Sometimes it is just not straightforward, and you need a slightly different approach.

For a tail-recursive algorithm that computes the depth of a binary tree, you can traverse the tree while accumulating the list of subtrees still to be visited, together with their depth. Your list will therefore be a list of tuples (depth: Int, node: Tree), and a second accumulator will record the maximum depth.

Here is a general outline of the algorithm:

  • start with a list toVisit containing the tuple (1, rootNode) and maxDepth set to 0
  1. if the toVisit list is empty, return maxDepth
  2. pop the head from the list
  3. if the head is the EmptyTree, continue with the tail; maxDepth stays the same
  4. if the head is a Node, continue with the tail extended by its left and right subtrees (each with the depth incremented by one), and update maxDepth if the popped head's depth is greater than the current maxDepth
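Here is the same outline in the question's language, as a sketch. `Node` is a hypothetical stand-in for LeetCode's `TreeNode`, with `None` playing the role of EmptyTree. Every recursive call is in tail position. CPython does not eliminate tail calls, though, so this uses one stack frame per visited entry (nodes plus empty subtrees), and it will hit `RecursionError` on trees with more than a few hundred nodes.

```python
from typing import List, Optional, Tuple


class Node:
    """Hypothetical stand-in for a binary tree node (val, left, right)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth(root: Optional[Node]) -> int:
    def go(to_visit: List[Tuple[int, Optional[Node]]], best: int) -> int:
        if not to_visit:                    # step 1: nothing left to visit
            return best
        (depth, node), *tail = to_visit     # step 2: pop the head
        if node is None:                    # step 3: empty subtree, best unchanged
            return go(tail, best)
        # step 4: queue both children one level deeper and update the maximum
        return go(tail + [(depth + 1, node.left), (depth + 1, node.right)],
                  max(depth, best))         # tail call: nothing happens afterwards
    return go([(1, root)], 0)


tree = Node(3, Node(9), Node(20, Node(15), Node(7)))
print(max_depth(tree))  # 3
```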

Here is a Scala implementation:

import scala.annotation.tailrec

abstract class Tree[+A] {
  def head: A
  def left: Tree[A]
  def right: Tree[A]
  def depth: Int
  ...
}
case object EmptyTree extends Tree[Nothing] {...}

case class Node[+A](h: A, l: Tree[A], r: Tree[A]) extends Tree[A] {

  override def depth: Int = {

    @tailrec
    def depthAux(toVisit: List[(Int, Tree[A])], maxDepth: Int): Int = toVisit match {
      case Nil => maxDepth
      case head :: tail => {
        val depth = head._1
        val node = head._2
        if (node.isEmpty) depthAux(tail, maxDepth)
        else depthAux(toVisit = tail ++ List((depth + 1, node.left), (depth + 1, node.right)),
                      maxDepth = if (depth > maxDepth) depth else maxDepth)
      }
    }

    depthAux(List((1, this)), 0)
  }
 ...
}


And for those more interested in Haskell:

data Tree a = Empty | Node a (Tree a) (Tree a) deriving (Show)

depthAux :: [(Int, Tree a)] -> Int -> Int
depthAux [] maxDepth = maxDepth
depthAux ((depth, Empty):xs) maxDepth = depthAux xs maxDepth
depthAux ((depth, (Node h l r)):xs) maxDepth = 
    depthAux (xs ++ [(depth + 1, l), (depth + 1, r)]) (max depth maxDepth) 

depth :: Tree a -> Int
depth node = depthAux [(1, node)] 0
Matus Dubrava

You can't. It is easy to see that you cannot make both the LHS (left-subtree) and the RHS (right-subtree) recursive calls tail calls. You can make one of them a tail call, but not the other. Let's talk about that.


Let's open by stating bluntly that recursion is generally a bad idea in Python. Python is not optimized for recursive solutions, and even trivial optimizations (like tail call elimination) are not implemented. Don't do this in Python.
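Here is a quick way to check the claim about tail call elimination. The function below is tail recursive, yet CPython still pushes one frame per call. A depth well past the recursion limit (see `sys.getrecursionlimit()`) therefore raises `RecursionError`. `count_down` is just an illustrative name.

```python
import sys


def count_down(n: int) -> int:
    """Tail recursive: the recursive call is the very last thing that happens."""
    if n == 0:
        return 0
    return count_down(n - 1)


print(sys.getrecursionlimit())  # 1000 unless it has been changed

try:
    count_down(10 * sys.getrecursionlimit())
except RecursionError:
    print("RecursionError: CPython kept every tail-call frame")
```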

However, it can be a good language for illustrating concepts that might be harder to grasp in other languages (even if those might be better suited to the kinds of solutions you're looking for) so let's dive in.

As you seem to understand, recursion is a function calling itself. While the logic varies from function to function, every recursive function has two major sections:

  1. Base case

This is the trivial case, usually something like return 1 or another degenerate case.

  2. Recursive case

This is where the function decides it has to go deeper and recurses into itself.

For tail recursion, the important part is that in the recursive case the function doesn't have to do anything after it recurses. Languages that optimize for this can detect it and throw away the stack frame holding the old call's context as soon as the new call starts. Any context that is still needed is usually passed along through function parameters instead.
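Nothing remains to be done after a tail call, so the call can be replaced by rebinding the parameters and jumping back to the top of the function. That is essentially what tail call elimination does. Done by hand in Python, a tail-recursive gcd becomes a `while` loop. The gcd functions are an illustrative example, not from the answer.

```python
def gcd_recursive(a: int, b: int) -> int:
    """Tail recursive: the recursive call is in tail position."""
    if b == 0:
        return a
    return gcd_recursive(b, a % b)


def gcd_loop(a: int, b: int) -> int:
    """The same function after manual tail call elimination."""
    while b != 0:
        a, b = b, a % b  # rebind the parameters instead of calling again
    return a


print(gcd_recursive(48, 18), gcd_loop(48, 18))  # 6 6
```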

Imagine a sum function implemented this way:

from typing import List

def sum_iterative(some_iterable: List[int]) -> int:
    total = 0
    for num in some_iterable:
        total += num
    return total

def sum_recursive(some_iterable: List[int]) -> int:
    """This is a wrapper function that implements sum recursively."""

    def go(total: int, iterable: List[int]) -> int:
        """This actually does the recursion."""
        if not iterable:  # BASE CASE: the iterable is empty, so total is complete
            return total
        else:             # RECURSIVE CASE: the recursive call is the last thing we do
            head = iterable.pop(0)
            return go(total + head, iterable)

    return go(0, list(some_iterable))  # copy so the caller's list is not emptied

Do you see how I've had to define a helper function that takes extra arguments (here, the running total) that the user wouldn't naturally pass in? The same trick can help you with this problem.

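from typing import Optional

class TreeNode:  # minimal stand-in for LeetCode's TreeNode
    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right

def max_depth(root: Optional[TreeNode]) -> int:
    def go(maxdepth: int, curdepth: int, node: Optional[TreeNode]) -> int:
        if node is None:
            return maxdepth
        else:
            curdepth += 1
            lhs_max = go(max(maxdepth, curdepth), curdepth, node.left)
            # the call above is not a tail call, so its frame cannot be eliminated
            return go(max(lhs_max, curdepth), curdepth, node.right)
    return go(0, 0, root)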

For fun, here's a really ugly example in Haskell (because I felt like brushing up on my functional programming):

data TreeNode a = TreeNode { val   :: a
                           , left  :: Maybe (TreeNode a)
                           , right :: Maybe (TreeNode a)
                           }
treeDepth :: TreeNode a -> Int
treeDepth = go 0 0 . Just
  where go :: Int -> Int -> (Maybe (TreeNode a)) -> Int
        go maxDepth _        Nothing     = maxDepth
        go maxDepth curDepth (Just node) = let curDepth' = curDepth + 1 :: Int
                                               maxDepth' = max maxDepth curDepth' :: Int
                                               lhsMax    = go maxDepth' curDepth' (left node)
                                           in  go lhsMax curDepth' (right node)

root = TreeNode 3 (Just (TreeNode 9 Nothing Nothing)) (Just (TreeNode 20 (Just (TreeNode 15 Nothing Nothing)) (Just (TreeNode 7 Nothing Nothing)))) :: TreeNode Int

main :: IO ()
main = print $ treeDepth root
Adam Smith
  • My original code had a typo in `go` that was passing `root.left` and `root.right` to the recursion rather than `node.left` and `node.right` so it was recursing infinitely. Whoops! – Adam Smith Apr 18 '19 at 02:57

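This may be a bit late, but you could pass along a list of subtrees and, on each step, remove the root of every tree in the list, keeping only its children. Each recursive call handles one level of the tree, so counting the calls until the list is empty gives the depth.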

Here is an implementation in Haskell:

data Tree a 
    = Leaf a
    | Node a (Tree a) (Tree a)
    deriving Show

depth :: Tree a -> Integer
depth tree = recursion 0 [tree]
    where 
        recursion :: Integer -> [Tree a] -> Integer
        recursion n [] = n
        recursion n treeList = recursion (n+1) (concatMap f treeList)
            where
                f (Leaf _) = []
                f (Node _ left right) = [left, right]

root = Node 1 (Node 2 (Leaf 3) (Leaf 3)) (Leaf 7)

main :: IO ()
main = print $ depth root
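For comparison with the question's Python, here is the same level-by-level idea as a sketch. It assumes a hypothetical `TreeNode`-style class where a missing child is `None`, unlike the Haskell `Leaf` constructor, so each step keeps only the non-empty children. The recursive call is a tail call, but CPython will still use one stack frame per level.

```python
from typing import List, Optional


class TreeNode:
    """Hypothetical stand-in for LeetCode's TreeNode (val, left, right)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_depth(root: Optional[TreeNode]) -> int:
    def recursion(n: int, level: List[TreeNode]) -> int:
        if not level:                        # no trees left: n levels were removed
            return n
        # replace every tree in this level by its non-empty children
        next_level = [child
                      for node in level
                      for child in (node.left, node.right)
                      if child is not None]
        return recursion(n + 1, next_level)  # tail call
    return recursion(0, [] if root is None else [root])


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_depth(root))  # 3
```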
Pandermatt