I'm implementing a trampoline in Python so that I can write stack-safe recursive functions (CPython does not perform tail-call optimization). It looks like this:

from typing import Callable, Generic, TypeVar, cast
from abc import ABC, abstractmethod

A = TypeVar('A', covariant=True)
B = TypeVar('B')
C = TypeVar('C')


class Trampoline(Generic[A], ABC):
    """
    Base class for Trampolines. Useful for writing stack-safe
    recursive functions.
    """
    @abstractmethod
    def _resume(self) -> 'Trampoline[A]':
        """
        Let this trampoline resume the interpreter loop
        """
        pass

    @abstractmethod
    def _handle_cont(
        self, cont: Callable[[A], 'Trampoline[B]']
    ) -> 'Trampoline[B]':
        """
        Handle continuation function passed to `and_then`
        """
        pass

    @property
    def _is_done(self) -> bool:
        return isinstance(self, Done)

    def and_then(self, f: Callable[[A], 'Trampoline[B]']) -> 'Trampoline[B]':
        """
        Apply ``f`` to the value wrapped by this trampoline.

        Args:
            f: function to apply to the value in this trampoline
        Return:
            Result of applying ``f`` to the value wrapped by \
            this trampoline
        """
        return AndThen(self, f)

    def map(self, f: Callable[[A], B]) -> 'Trampoline[B]':
        """
        Map ``f`` over the value wrapped by this trampoline.

        Args:
            f: function to wrap over this trampoline
        Return:
            new trampoline wrapping the result of ``f``
        """
        return self.and_then(lambda a: Done(f(a)))

    def run(self) -> A:
        """
        Interpret a structure of trampolines to produce a result

        Return:
            result of interpreting this structure of \
            trampolines
        """
        trampoline = self
        while not trampoline._is_done:
            trampoline = trampoline._resume()

        return cast(Done[A], trampoline).a


class Done(Trampoline[A]):
    """
    Represents the result of a recursive computation.
    """
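    a: A

    # Constructor assumed here; the listing as posted only declares the field.
    def __init__(self, a: A) -> None:
        self.a = a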

    def _resume(self) -> Trampoline[A]:
        return self

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        return cont(self.a)


class Call(Trampoline[A]):
    """
    Represents a recursive call.
    """
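    thunk: Callable[[], Trampoline[A]]

    # Constructor assumed here; the listing as posted only declares the field.
    def __init__(self, thunk: Callable[[], Trampoline[A]]) -> None:
        self.thunk = thunk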

    def _handle_cont(self,
                     cont: Callable[[A], Trampoline[B]]) -> Trampoline[B]:
        return self.thunk().and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[A]:
        return self.thunk()  # type: ignore


class AndThen(Generic[A, B], Trampoline[B]):
    """
    Represents monadic bind for trampolines as a class to avoid
    deep recursive calls to ``Trampoline.run`` during interpretation.
    """
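    sub: Trampoline[A]
    cont: Callable[[A], Trampoline[B]]

    # Constructor assumed here; the listing as posted only declares the fields.
    def __init__(
        self, sub: Trampoline[A], cont: Callable[[A], Trampoline[B]]
    ) -> None:
        self.sub = sub
        self.cont = cont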

    def _handle_cont(self,
                     cont: Callable[[B], Trampoline[C]]) -> Trampoline[C]:
        return self.sub.and_then(self.cont).and_then(cont)  # type: ignore

    def _resume(self) -> Trampoline[B]:
        return self.sub._handle_cont(self.cont)  # type: ignore

    def and_then(  # type: ignore
        self, f: Callable[[A], Trampoline[B]]
    ) -> Trampoline[B]:
        return AndThen(
            self.sub,
            lambda x: Call(lambda: self.cont(x).and_then(f))  # type: ignore
        )

Now, I need a monadic sequence operator. My initial take looked like this:

from typing import Iterable

from functools import reduce


def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    def combine(result: Trampoline[Iterable[A]], ta: Trampoline[A]) -> Trampoline[Iterable[A]]:
        return result.and_then(lambda as_: ta.map(lambda a: as_ + (a,)))

    return reduce(combine, iterable, Done(()))

That works, but the overhead of all the function calls resulting from reducing a long list of trampolines in this way absolutely kills performance.

So instead I tried this:

def sequence(iterable: Iterable[Trampoline[A]]) -> Trampoline[Iterable[A]]:
    def thunk() -> Trampoline[Iterable[A]]:
        return Done(tuple([t.run() for t in iterable]))
    
    return Call(thunk)

Now, my gut feeling is that the second version of sequence isn't stack safe because it calls run, which means that run will be calling run during interpretation (through Call.thunk, but nonetheless). However, I can't seem to produce a stack overflow no matter how I mix and match.

For example, I thought this should do it:

t, *ts = [sequence(Done(v) for v in range(2)) for _ in range(10000)]

def combine(t1, t2):
    return t1.and_then(lambda _: t2)

final = reduce(combine, ts, t)
final.run()  # My gut feeling says this should overflow the stack, but it doesn't

I've tried countless other examples, but no stack overflow. My gut feeling remains that this shouldn't work.

I need someone to either convince me that trampolining the interpreter loop in this way is actually stack safe, or show me an example where it overflows the stack.
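
One way to test this suspicion is to count how deeply `run` calls nest while the example above is interpreted. The sketch below is a compressed, untyped restatement of the classes from the question. It uses dataclasses for the constructors (an assumption, since the listing does not show them) and adds a hypothetical `depth`/`max_depth` counter that is not part of the original code.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Any, Callable


class Trampoline:
    depth = 0      # hypothetical counter: run() calls currently on the stack
    max_depth = 0  # deepest nesting of run() observed so far

    def and_then(self, f):
        return AndThen(self, f)

    def run(self):
        Trampoline.depth += 1
        Trampoline.max_depth = max(Trampoline.max_depth, Trampoline.depth)
        try:
            trampoline = self
            while not isinstance(trampoline, Done):
                trampoline = trampoline._resume()
            return trampoline.a
        finally:
            Trampoline.depth -= 1


@dataclass
class Done(Trampoline):
    a: Any

    def _handle_cont(self, cont):
        return cont(self.a)


@dataclass
class Call(Trampoline):
    thunk: Callable[[], Trampoline]

    def _handle_cont(self, cont):
        return self.thunk().and_then(cont)

    def _resume(self):
        return self.thunk()


@dataclass
class AndThen(Trampoline):
    sub: Trampoline
    cont: Callable[[Any], Trampoline]

    def _handle_cont(self, cont):
        return self.sub.and_then(self.cont).and_then(cont)

    def _resume(self):
        return self.sub._handle_cont(self.cont)

    def and_then(self, f):
        # Re-associate to the right so left-nested binds do not pile up.
        return AndThen(self.sub, lambda x: Call(lambda: self.cont(x).and_then(f)))


def sequence(iterable):
    # Second version from the question: run() is called inside the thunk.
    return Call(lambda: Done(tuple([t.run() for t in iterable])))


# Same shape as the example in the question, with lists instead of generators.
t, *ts = [sequence([Done(v) for v in range(2)]) for _ in range(2000)]
final = reduce(lambda t1, t2: t1.and_then(lambda _: t2), ts, t)
result = final.run()
print(result, Trampoline.max_depth)
```

In this sketch the counter stays at 2 however long the `and_then` chain is: the outer `run` enters a `sequence` thunk, which calls `run` on `Done` values that return immediately. The nesting only grows when a `sequence` contains another `sequence`.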

  • Are you interested in another implementation? - [https://pypi.org/project/trampoline/](https://pypi.org/project/trampoline/) – wwii Aug 11 '20 at 18:33
  • I definitely like the API; the simplicity of using generators is nice! For my use case I need to be able to chain together trampolined functions (i.e. monadic bind, called `and_then` in my example) and trampoline combinators (e.g. `sequence` as described). I'm not sure how that would work with your solution? It's probably possible. Also, how does it work if you want to trampoline a generator function (i.e. where the yielded values don't indicate a recursive call but are the actual return values of the function)? – Sune Andreas Dybro Debel Aug 11 '20 at 19:16
  • It is not my solution, just wanted to point it out in case you had not seen it. – wwii Aug 11 '20 at 22:44
  • It makes good sense to put the interpreter on a trampoline. Using continuation passing style is an effective way to convert a recursive process to a linear one. I use this same technique in another (JS) [Q&A](https://stackoverflow.com/a/57743504/633183). – Mulan Aug 12 '20 at 04:28

1 Answer


To cause a stack overflow during interpretation, the recursion has to come from nesting sequence inside sequence:

sequence([sequence([sequence([sequence([...
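
A minimal sketch of that nesting, assuming a stripped-down trampoline with only `Done` and `Call`. The dataclass constructors and the helpers `nested` and `overflows` are hypothetical, added for illustration. Each level of `sequence` adds a few Python frames because its thunk calls `run` on the level below, so nesting as deep as the recursion limit should raise `RecursionError`.

```python
import sys
from dataclasses import dataclass
from typing import Any, Callable, Iterable


class Trampoline:
    def run(self) -> Any:
        trampoline = self
        while not isinstance(trampoline, Done):
            trampoline = trampoline._resume()
        return trampoline.a


@dataclass
class Done(Trampoline):
    a: Any


@dataclass
class Call(Trampoline):
    thunk: Callable[[], Trampoline]

    def _resume(self) -> Trampoline:
        return self.thunk()


def sequence(iterable: Iterable[Trampoline]) -> Trampoline:
    # Same shape as the question's second version: run() is called in the thunk.
    def thunk() -> Trampoline:
        return Done(tuple([t.run() for t in iterable]))
    return Call(thunk)


def nested(depth: int) -> Trampoline:
    # Builds sequence([sequence([... Done(0) ...])]) with a loop, so that
    # constructing the value does not itself recurse.
    trampoline: Trampoline = Done(0)
    for _ in range(depth):
        trampoline = sequence([trampoline])
    return trampoline


def overflows(depth: int) -> bool:
    # True if interpreting `depth` nested sequences exhausts the Python stack.
    try:
        nested(depth).run()
    except RecursionError:
        return True
    return False


shallow = overflows(10)
deep = overflows(sys.getrecursionlimit())
print(shallow, deep)
```

The stack use here depends on how deeply `sequence` calls are nested inside each other, not on how many trampolines are chained with `and_then`.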
  • I don't know Python, but if you are recursive within a monad, i.e. within bind/chain/then, then you need a trampoline monad or a continuation monad, where the former would be simpler than the latter. Normal trampolines depend on tail recursion. If you want to relax this restriction you can also implement a trampoline modulo cons, which allows you to place the recursive step within a constructor application. I wrote a lot about trampolines. You can find it on my profile. –  Aug 12 '20 at 14:18