I'm implementing a trampoline in Python in order to write stack-safe recursive functions (CPython does not perform tail-call optimization, TCO). It looks like this:
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar, cast
# Type variables used in the trampoline signatures.
# ``A`` is the (covariant) result type of a trampoline; ``B`` and ``C`` are
# the result types produced by continuations passed to ``and_then``/``map``.
# ``B`` and ``C`` must exist at import time because the method annotations
# that mention them are evaluated when the ``def`` statements run.
A = TypeVar('A', covariant=True)
B = TypeVar('B')
C = TypeVar('C')
class Trampoline(Generic[A], ABC):
    """
    Abstract base for trampolined computations. Useful for writing
    stack-safe recursive functions.
    """

    @abstractmethod
    def _resume(self) -> 'Trampoline[A]':
        """
        Advance this trampoline by one step of the interpreter loop.

        Return:
            the trampoline the interpreter loop should continue with
        """

    @abstractmethod
    def _handle_cont(
        self, cont: Callable[[A], 'Trampoline[B]']
    ) -> 'Trampoline[B]':
        """
        Combine this trampoline with a continuation passed to ``and_then``.

        Args:
            cont: continuation to apply to the value of this trampoline
        Return:
            trampoline representing this trampoline followed by ``cont``
        """

    @property
    def _is_done(self) -> bool:
        # True exactly when this node is a ``Done``, i.e. it holds a result.
        return isinstance(self, Done)

    def and_then(self, f: Callable[[A], 'Trampoline[B]']) -> 'Trampoline[B]':
        """
        Chain ``f`` after this trampoline (monadic bind).

        Nothing is evaluated here; the bind is recorded as an ``AndThen``
        node and carried out later by ``run``.

        Args:
            f: function producing the next trampoline from the value of
               this one
        Return:
            trampoline representing this trampoline followed by ``f``
        """
        return AndThen(self, f)

    def map(self, f: Callable[[A], B]) -> 'Trampoline[B]':
        """
        Transform the value wrapped by this trampoline with ``f``.

        Args:
            f: plain function to apply to the wrapped value
        Return:
            new trampoline wrapping the result of ``f``
        """
        def apply_and_wrap(a: A) -> 'Trampoline[B]':
            return Done(f(a))

        return self.and_then(apply_and_wrap)

    def run(self) -> A:
        """
        Interpret this structure of trampolines to produce a result.

        Repeatedly resumes the current trampoline in a loop until a
        ``Done`` is reached.

        Return:
            the value held by the final ``Done``
        """
        current = self
        while True:
            if current._is_done:
                return cast(Done[A], current).a
            current = current._resume()
class Done(Trampoline[A]):
    """
    Represents the result of a recursive computation.

    Args:
        a: the finished value
    """
    a: A

    def __init__(self, a: A) -> None:
        # An explicit constructor is required: the bare ``a: A`` annotation
        # alone does not make ``Done(value)`` callable with an argument.
        self.a = a

    def _resume(self) -> Trampoline[A]:
        # A finished computation has nothing left to do.
        return self

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        # The value is already available, so feed it to the continuation.
        return cont(self.a)
class Call(Trampoline[A]):
    """
    Represents a recursive call.

    Args:
        thunk: zero-argument function that produces the next trampoline
               when called
    """
    thunk: Callable[[], Trampoline[A]]

    def __init__(self, thunk: Callable[[], Trampoline[A]]) -> None:
        # An explicit constructor is required: the bare ``thunk`` annotation
        # alone does not make ``Call(thunk)`` callable with an argument.
        # Stored on the instance, so ``self.thunk()`` calls it without
        # binding ``self``.
        self.thunk = thunk

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        # Force the deferred call, then chain the continuation onto it.
        return self.thunk().and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[A]:
        # One interpreter step: perform the deferred call.
        return self.thunk()  # type: ignore
class AndThen(Generic[A, B], Trampoline[B]):
    """
    Represents monadic bind for trampolines as a class to avoid
    deep recursive calls to ``Trampoline.run`` during interpretation.

    Args:
        sub: trampoline whose value is fed to ``cont``
        cont: continuation producing the next trampoline from the value
              of ``sub``
    """
    sub: Trampoline[A]
    cont: Callable[[A], Trampoline[B]]

    def __init__(self,
                 sub: Trampoline[A],
                 cont: Callable[[A], Trampoline[B]]) -> None:
        # An explicit constructor is required: the bare annotations alone
        # do not make ``AndThen(sub, cont)`` callable with arguments.
        self.sub = sub
        self.cont = cont

    def _handle_cont(self,
                     cont: Callable[[B], Trampoline[C]]) -> Trampoline[C]:
        # Rebuild the bind via ``and_then`` so that the outer continuation
        # is attached through the re-associating ``and_then`` below.
        return self.sub.and_then(self.cont).and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[B]:
        # One interpreter step: let ``sub`` decide how to apply ``cont``.
        return self.sub._handle_cont(self.cont)  # type: ignore

    def and_then(  # type: ignore
        self, f: Callable[[B], Trampoline[C]]
    ) -> Trampoline[C]:
        """
        Chain ``f`` after this bind without wrapping it in another
        ``AndThen``.

        The new node keeps the same ``sub``; its continuation defers
        ``self.cont(x).and_then(f)`` inside a ``Call`` so the work is done
        by the interpreter loop rather than at construction time.

        Args:
            f: function producing the next trampoline from the value of
               this one
        Return:
            trampoline representing this bind followed by ``f``
        """
        return AndThen(
            self.sub,
            lambda x: Call(lambda: self.cont(x).and_then(f))  # type: ignore
        )
Now, I need a monadic sequence operator. My initial take looked like this:
from typing import Iterable
from functools import reduce
def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    """
    Combine trampolines into one trampoline producing a tuple of results.

    Folds over ``iterable`` from the left, starting from ``Done(())`` and
    adding one ``and_then``/``map`` pair per element. Results appear in
    iteration order.

    Args:
        iterable: trampolines to combine
    Return:
        trampoline that yields a tuple with the value of each element
    """
    def combine(
        result: Trampoline[Iterable[A]], ta: Trampoline[A]
    ) -> Trampoline[Iterable[A]]:
        def append_value(as_):
            # Extend the tuple accumulated so far with the value of ``ta``.
            return ta.map(lambda a: as_ + (a,))

        return result.and_then(append_value)

    accumulated = Done(())
    for trampoline in iterable:
        accumulated = combine(accumulated, trampoline)
    return accumulated
That works, but the overhead of all the function calls resulting from reducing a long list of trampolines in this way absolutely kills performance.
So instead I tried this:
def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    """
    Combine trampolines into one trampoline producing a tuple of results.

    Each element is interpreted with its own call to ``run`` from inside
    the thunk of the returned ``Call``. That nested ``run`` executes on top
    of whichever interpreter loop invoked the thunk.

    Args:
        iterable: trampolines to combine; consumed once, when ``sequence``
                  is called
    Return:
        trampoline that yields a tuple with the value of each element
    """
    # Materialize up front: a one-shot iterator (e.g. a generator) would
    # otherwise be exhausted by the first ``run`` of the returned
    # trampoline, making any later ``run`` silently produce ``()``.
    trampolines = tuple(iterable)

    def thunk() -> Trampoline[Iterable[A]]:
        return Done(tuple(t.run() for t in trampolines))

    return Call(thunk)
Now, my gut feeling is that the second version of `sequence` isn't stack safe because it calls `run`, which means that `run` will be calling `run` during interpretation (through `Call.thunk`, but nonetheless). However, I can't seem to produce a stack overflow no matter how I mix and match.
For example, I thought this should do it:
# 10000 independent sequences, each over two already-finished trampolines.
sequences = [sequence(Done(v) for v in range(2)) for _ in range(10000)]
t, ts = sequences[0], sequences[1:]


def combine(t1, t2):
    # Run ``t1``, discard its value, then continue with ``t2``.
    def discard_and_continue(_):
        return t2

    return t1.and_then(discard_and_continue)


# Chain all sequences one after another into a single long trampoline.
final = reduce(combine, ts, t)
final.run()  # My gut feeling says this should overflow the stack, but it doesn't
I've tried countless other examples, but no stack overflow. My gut feeling remains that this shouldn't work.
I need someone to either convince me that trampolining the interpreter loop in this way is actually stack safe, or show me an example where it overflows the stack.