any recursive program can be made stack-safe
I have written a lot on the topic of recursion and I am sad when people misstate the facts. And no, this does not rely on daft techniques like sys.setrecursionlimit().
Calling a function in Python adds a stack frame. So instead of writing f(x) to call a function, we will write call(f, x). Now we have complete control of the evaluation strategy -
# btree.py
def depth(t):
  if not t:
    return 0
  else:
    return call \
      ( lambda left_height, right_height: 1 + max(left_height, right_height)
      , call(depth, t.left)
      , call(depth, t.right)
      )
It's effectively the exact same program. So what is call?
# tailrec.py
class call:
  def __init__(self, f, *v):
    self.f = f
    self.v = v
So call is a simple object with two attributes: the function to call, f, and the values to call it with, v. That means depth is returning a call object instead of the number we need. Only one more adjustment is needed -
# btree.py
from tailrec import loop, call

def depth(t):
  def aux(t): # <- auxiliary wrapper
    if not t:
      return 0
    else:
      return call \
        ( lambda l, r: 1 + max(l, r)
        , call(aux, t.left)
        , call(aux, t.right)
        )
  return loop(aux(t)) # <- call loop on result of aux
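To see what aux actually produces before loop gets involved, here is a small self-contained sketch. The node class is a hypothetical stand-in (the answer imports node from btree.py without showing it); it assumes each node exposes .left and .right and that missing children are None.

```python
class call:
    def __init__(self, f, *v):
        self.f = f  # the function to call later
        self.v = v  # the arguments to call it with


class node:
    # hypothetical binary tree node: a value plus optional children
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def aux(t):
    if not t:
        return 0
    return call(lambda l, r: 1 + max(l, r), call(aux, t.left), call(aux, t.right))


t = node(3, node(9), node(20, node(15), node(7)))
expr = aux(t)

print(isinstance(expr, call))                        # True: a description, not a number
print(len(expr.v))                                   # 2: one pending call per subtree
print(expr.v[0].f is aux, expr.v[0].v[0] is t.left)  # True True
print(aux(None))                                     # 0: the base case is a plain value
```

Nothing has been evaluated yet: expr only describes the computation, and that description is what loop consumes.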
loop
Now all we need to do is write a sufficiently adept loop to evaluate our call expressions. The answer here is a direct translation of the evaluator I wrote in this Q&A (JavaScript). I won't repeat the explanation here; if you want to understand how it works, I explain it step by step as we build loop in that post -
# tailrec.py
from functools import reduce

def identity(x):
  return x

def loop(t, k = identity):
  # evaluate a single expression t, sending its value to continuation k
  def one(t, k):
    if isinstance(t, call):
      return call(many, t.v, lambda r: call(one, t.f(*r), k))
    else:
      return call(k, t)
  # evaluate each expression in ts in order, sending the list of values to k
  def many(ts, k):
    return call \
      ( reduce \
          ( lambda mr, e:
              lambda k: call(mr, lambda r: call(one, e, lambda v: call(k, [*r, v])))
          , ts
          , lambda k: call(k, [])
          )
      , k
      )
  return run(one(t, k))
Noticing a pattern? loop is recursive the same way depth is, but we recur using call expressions here too. Notice how loop sends its output to run, where unmistakable iteration happens -
# tailrec.py
def run(t):
  while isinstance(t, call):
    t = t.f(*t.v)
  return t
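On its own, run is an ordinary trampoline: it handles a call in tail position, where the result of one step is simply the next step. A minimal sketch, assuming a hypothetical tail-recursive countdown that sums 1..n by returning call(...) instead of recurring directly:

```python
class call:
    def __init__(self, f, *v):
        self.f = f
        self.v = v


def run(t):
    # keep invoking deferred calls until a plain value comes back
    while isinstance(t, call):
        t = t.f(*t.v)
    return t


def countdown(n, acc=0):
    # hypothetical example: the recursive step is returned, not performed
    if n == 0:
        return acc
    return call(countdown, n - 1, acc + n)


print(run(countdown(100_000)))  # 5000050000, far deeper than the default recursion limit
```

What run alone cannot handle is a call nested as an argument, like call(aux, t.left) inside depth, because run would pass the unevaluated call objects straight to the lambda. As I read loop, that is the gap one and many fill by turning argument evaluation into continuations before handing everything to run.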
check your work
from btree import node, depth

#   3
#  / \
# 9   20
#    /  \
#   15   7
t = node(3, node(9), node(20, node(15), node(7)))

print(depth(t))
3
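For a cross-check, here is a naive evaluator for the same call expressions. evaluate is a hypothetical helper that follows the same order as one and many in loop (evaluate the arguments, call f, then evaluate its result), but it leans on Python's own stack, so it only works for shallow trees. node is a hypothetical stand-in for the one btree.py is assumed to export.

```python
class call:
    def __init__(self, f, *v):
        self.f = f
        self.v = v


class node:
    # hypothetical binary tree node
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def evaluate(t):
    # NOT stack-safe: every nested call uses a real Python stack frame
    if isinstance(t, call):
        args = [evaluate(v) for v in t.v]  # evaluate every argument first
        return evaluate(t.f(*args))        # then call f and evaluate its result
    return t  # plain values evaluate to themselves


def aux(t):
    if not t:
        return 0
    return call(lambda l, r: 1 + max(l, r), call(aux, t.left), call(aux, t.right))


t = node(3, node(9), node(20, node(15), node(7)))
print(evaluate(aux(t)))  # 3
```

Both evaluators should agree on every input; the difference is only where the pending work lives.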
stack vs heap
You are no longer limited by Python's default recursion limit of 1000. We effectively hijacked Python's evaluation strategy and wrote our own replacement, loop. Instead of throwing function call frames on the stack, we trade them for continuations on the heap. Now the only limit is your computer's memory.
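A sketch to test that claim: it assembles call, loop and run from tailrec.py into one file (plus an identity function, which loop's default argument needs) and measures a left-leaning chain of 20,000 nodes, far past the default recursion limit. node and naive_depth are hypothetical helpers for this test, and the exact point where naive_depth fails depends on sys.getrecursionlimit().

```python
from functools import reduce


class call:
    def __init__(self, f, *v):
        self.f = f
        self.v = v


def identity(x):
    return x


def run(t):
    while isinstance(t, call):
        t = t.f(*t.v)
    return t


def loop(t, k=identity):
    def one(t, k):
        if isinstance(t, call):
            return call(many, t.v, lambda r: call(one, t.f(*r), k))
        return call(k, t)

    def many(ts, k):
        return call(
            reduce(
                lambda mr, e:
                    lambda k: call(mr, lambda r: call(one, e, lambda v: call(k, [*r, v]))),
                ts,
                lambda k: call(k, []),
            ),
            k,
        )

    return run(one(t, k))


class node:
    # hypothetical binary tree node
    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def depth(t):
    def aux(t):
        if not t:
            return 0
        return call(lambda l, r: 1 + max(l, r), call(aux, t.left), call(aux, t.right))
    return loop(aux(t))


def naive_depth(t):
    # ordinary recursion, for contrast
    return 0 if not t else 1 + max(naive_depth(t.left), naive_depth(t.right))


# build the chain iteratively so that only depth itself is being tested
N = 20_000
deep = None
for i in range(N):
    deep = node(i, deep)

print(depth(deep))  # 20000

try:
    naive_depth(deep)
except RecursionError:
    print("naive_depth: RecursionError")
```

Running time and memory still grow with the depth of the tree, since every pending step becomes a closure on the heap; what goes away is the fixed ceiling imposed by the interpreter's stack.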