21

I was wondering if there is some general method to convert a "normal" recursion with foo(...) + foo(...) as the last call to a tail-recursion.

For example (scala):

def pascal(c: Int, r: Int): Int = {
 if (c == 0 || c == r) 1
 else pascal(c - 1, r - 1) + pascal(c, r - 1)
}

A general solution for functional languages to convert a recursive function to a tail-call equivalent:

A simple way is to wrap the non-tail-recursive function in the Trampoline monad.

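def pascalM(c: Int, r: Int): Trampoline[Int] = {
 if (c == 0 || c == r) Trampoline.done(1)
 else for {
     // recurse into pascalM itself so every step stays inside the Trampoline
     a <- Trampoline.suspend(pascalM(c - 1, r - 1))
     b <- Trampoline.suspend(pascalM(c, r - 1))
   } yield a + b
}

// column 5 of row 10 (the column must not exceed the row)
val result = pascalM(5, 10).run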

So pascalM no longer recurses on the call stack. Instead, the Trampoline monad builds a nested structure describing the computation that needs to be done. Finally, run is a tail-recursive function that walks through this tree-like structure, interpreting it, and returns the value once it reaches the base case.

A paper from Rúnar Bjarnason on the subject of Trampolines: Stackless Scala With Free Monads
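The same idea can be tried without any extra library. The following is a minimal sketch that uses `scala.util.control.TailCalls` from the Scala standard library (2.11 or later, where `TailRec` has `map` and `flatMap`) instead of the Trampoline type above. The names `PascalTrampoline` and `pascalT` are illustrative only.

```scala
import scala.util.control.TailCalls._

object PascalTrampoline {
  // Each recursive call is wrapped in tailcall, so it is suspended on the
  // heap instead of growing the call stack.
  def pascalT(c: Int, r: Int): TailRec[Int] =
    if (c == 0 || c == r) done(1)
    else for {
      a <- tailcall(pascalT(c - 1, r - 1))
      b <- tailcall(pascalT(c, r - 1))
    } yield a + b
}

// result runs the suspended computation in a loop
val viaTailCalls = PascalTrampoline.pascalT(5, 10).result
```

This keeps the stack flat, but the number of steps is still exponential in the row, exactly as in the original doubly-recursive version.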

DennisVDB
  • Hello, great question, but I'd like to point out that your first recursive implementation of `pascal` is incomplete. If you do a `pascal(1,0)` then you'd get a `stackoverflow` exception, so you might want to add a condition `if(c<0 || r<0 || c>r) throw new IllegalArgumentException("Columns can never be bigger than lines")` or perhaps an `ArithmeticException`? – Imad Apr 22 '19 at 11:34

6 Answers

26

In cases where there is a simple modification to the value of a recursive call, that operation can be moved to the front of the recursive function. The classic example of this is Tail recursion modulo cons, where a simple recursive function in this form:

def recur[A](...):List[A] = {
  ...
  x :: recur(...)
}

which is not tail recursive, is transformed into

def recur[A](...): List[A] = {
   def consRecur(..., consA: A): List[A] = {
     consA :: ...
     ...
     consRecur(..., ...)
   }
   ...
   consRecur(...,...)
}

Alexlv's example is a variant of this.
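To make the schematic above concrete, here is a small sketch with a hypothetical function `doubleAll`. Scala's immutable lists do not allow the in-place trick a TRMC-optimising compiler uses, so this sketch swaps in an accumulator that is built in reverse and flipped at the end. The object and function names are assumptions for illustration.

```scala
import scala.annotation.tailrec

object ModuloCons {
  // Not tail recursive: the cons (::) runs after the recursive call returns.
  def doubleAll(xs: List[Int]): List[Int] = xs match {
    case Nil    => Nil
    case h :: t => (h * 2) :: doubleAll(t)
  }

  // Tail recursive: the cons is applied to the accumulator before recursing.
  def doubleAllTR(xs: List[Int]): List[Int] = {
    @tailrec
    def loop(rest: List[Int], acc: List[Int]): List[Int] = rest match {
      case Nil    => acc.reverse
      case h :: t => loop(t, (h * 2) :: acc)
    }
    loop(xs, Nil)
  }
}
```

The second version handles lists that are far too long for the first one, at the cost of one extra pass for the reverse.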

This is such a well known situation that some compilers (I know of Prolog and Scheme examples but Scalac does not do this) can detect simple cases and perform this optimisation automatically.

Problems combining multiple calls to recursive functions have no such simple solution. TRMC optimisation is useless, as you are simply moving the first recursive call to another non-tail position. The only way to reach a tail-recursive solution is to remove all but one of the recursive calls; how to do this is entirely context dependent but requires finding an entirely different approach to solving the problem.

As it happens, in some ways your example is similar to the classic Fibonacci sequence problem; in that case the naive but elegant doubly-recursive solution can be replaced by one which loops forward from the 0th number.

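// Naive version: two recursive calls, neither in tail position
def fib (n: Long): Long = n match {
  case 0 | 1 => n
  case _ => fib( n - 2) + fib( n - 1 )
}

// Tail-recursive version: loops forward from the 0th number
def fib (n: Long): Long = {
  // next is passed by value; a by-name parameter here would build a chain
  // of nested thunks that is re-evaluated on every step
  def loop(current: Long, next: Long, iteration: Long): Long = {
    if (n == iteration) 
      current
    else
      loop(next, current + next, iteration + 1)
  }
  loop(0, 1, 0)
}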

For the Fibonacci sequence, this is the most efficient approach (a streams based solution is just a different expression of this solution that can cache results for subsequent calls). Now, you can also solve your problem by looping forward from c0/r0 (well, c0/r2) and calculating each row in sequence, the difference being that you need to cache the entire previous row. So while this has a similarity to fib, it differs dramatically in the specifics and is also significantly less efficient than the fib loop, which only has to carry two numbers from one step to the next.

Here's an approach for your pascal triangle example which can calculate pascal(30,60) efficiently:

def pascal(column: Long, row: Long):Long = {
  type Point = (Long, Long)
  type Points = List[Point]
  type Triangle = Map[Point,Long]
  def above(p: Point) = (p._1, p._2 - 1)
  def aboveLeft(p: Point) = (p._1 - 1, p._2 - 1)
  def find(ps: Points, t: Triangle): Long = ps match {
    // Found the ultimate goal
    case (p :: Nil) if t contains p => t(p)
    // Found an intermediate point: pop the stack and carry on
    case (p :: rest) if t contains p => find(rest, t)
    // Hit a triangle edge, add it to the triangle
    case ((c, r) :: _) if (c == 0) || (c == r) => find(ps, t + ((c,r) -> 1))
    // Triangle contains (c - 1, r - 1)...
    case (p :: _) if t contains aboveLeft(p) => if (t contains above(p))
        // And it contains (c, r - 1)!  Add to the triangle
        find(ps, t + (p -> (t(aboveLeft(p)) + t(above(p)))))
      else
        // Does not contain(c, r -1).  So find that
        find(above(p) :: ps, t)
    // If we get here, we don't have (c - 1, r - 1).  Find that.
    case (p :: _) => find(aboveLeft(p) :: ps, t)
  }
  require(column >= 0 && row >= 0 && column <= row)
  (column, row) match {
    case (c, r) if (c == 0) || (c == r) => 1
    case p => find(List(p), Map())
  }
}

It's efficient, but I think it shows how ugly complex recursive solutions can become as you deform them to become tail recursive. At this point, it may be worth moving to a different model entirely. Continuations or monadic gymnastics might be better.

You want a generic way to transform your function. There isn't one. There are helpful approaches, that's all.

itsbruce
  • Nice answer. Rúnar's paper was especially informative, although it may contradict your final claim (depending on exactly what transforms you had in mind). His Trampoline transformation would yield a stack-friendly implementation, even though the exponential runtime would still be a problem. – Aaron Novstrup Sep 23 '13 at 19:19
9

I don't know how theoretical this question is, but a recursive implementation won't be efficient even with tail-recursion. Try computing pascal(30, 60), for example. I don't think you'll get a stack overflow, but be prepared to take a long coffee break.

Instead, consider using a Stream or memoization:

val pascal: Stream[Stream[Long]] = 
  (Stream(1L) 
    #:: (Stream from 1 map { i => 
      // compute row i
      (1L 
        #:: (pascal(i-1) // take the previous row
               sliding 2 // and add adjacent values pairwise
               collect { case Stream(a,b) => a + b }).toStream 
        ++ Stream(1L))
    }))
Community
Aaron Novstrup
  • 1
    I realize this doesn't directly answer your question, but I decided to post it as an answer rather than a comment anyway because you're likely to run into the same efficiency issue with any non-trivial recurrence of that form. – Aaron Novstrup Sep 22 '13 at 21:47
  • 3
    If we're doing alternative Pascal's Triangle implementations, how about `val pascal = Stream.iterate(Seq(1))(a=>(0+:a,a:+0).zipped.map(_+_))` – Luigi Plinge Sep 23 '13 at 10:24
  • @LuigiPlinge Beautiful! – Aaron Novstrup Sep 23 '13 at 15:47
  • @AaronNovstrup really nice stream example! – Loic Mar 19 '15 at 15:47
  • @LuigiPlinge I guess there is an issue in your code, I get negative values with it : pascal(60)(30) = -1515254800 – Loic Mar 19 '15 at 15:48
  • @Loic you're overflowing the limits of `Int`. Exercise for you: convert this to use `BigInt`. Hint: it's very easy :) – Luigi Plinge Mar 19 '15 at 18:24
  • @LuigiPlinge Is there a way to avoid memoization of all lines, to keep only the last line , to avoid memory issues for high numbers? Note : Now that I understand it (works fine with BigInt) I find your solution very elegant! Good job :) – Loic Mar 20 '15 at 08:54
  • @Loic Use `Iterator.iterate` instead. Unfortunately it doesn't have such a handy way to access elements (since you can only access elements once, being an iterator and all), so you have to `drop` however many you want then look at `next` to give your result. – Luigi Plinge Mar 20 '15 at 14:46
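The `Iterator.iterate` suggestion from the comments, combined with the `BigInt` hint, might look like the sketch below. Only the current row is kept alive, so earlier rows can be garbage collected. The names `PascalIterator` and `row` are made up for this illustration.

```scala
object PascalIterator {
  // Returns row n of Pascal's triangle (row 0 is Vector(1)).
  def row(n: Int): Vector[BigInt] = {
    require(n >= 0, "row must be non-negative")
    Iterator
      .iterate(Vector(BigInt(1))) { prev =>
        // pad with 0 on each side and add the two shifted copies pairwise
        (BigInt(0) +: prev).zip(prev :+ BigInt(0)).map { case (a, b) => a + b }
      }
      .drop(n) // skip the first n rows
      .next()  // the next row produced is row n
  }
}
```

For example, `PascalIterator.row(60)(30)` gives the value that overflowed `Int` in the comment above.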
5

The accumulator approach

  def pascal(c: Int, r: Int): Int = {

    def pascalAcc(acc:Int, leftover: List[(Int, Int)]):Int = {
      if (leftover.isEmpty) acc
      else {
        val (c1, r1) = leftover.head
        // Safe checks first, so invalid points such as (0, -1) are never counted as edges.
        if (c1 < 0 || r1 < 0 || c1 > r1) pascalAcc(acc, leftover.tail)
        // Edge: contributes 1 to the total.
        else if (c1 == 0 || c1 == r1) pascalAcc(acc + 1, leftover.tail)
        // Replace the point with the two points above it.
        else pascalAcc(acc, (c1 , r1 - 1) :: ((c1 - 1, r1 - 1) :: leftover.tail ))
      }
    }

    pascalAcc(0, List ((c,r) ))
  }

It does not overflow the stack, but as Aaron mentioned, it is not fast for large rows and columns.

Alex des Pelagos
4

Yes, it's possible. Usually it's done with the accumulator pattern: an internally defined function takes one additional argument, the so-called accumulator, which carries the partial result. Here is an example that counts the length of a list.

A normal recursive version would look like this:

def length[A](xs: List[A]): Int = if (xs.isEmpty) 0 else 1 + length(xs.tail)

That's not a tail-recursive version. To eliminate the final addition, we have to accumulate the value as we go, for example with the accumulator pattern:

def length[A](xs: List[A]) = {
  def inner(ys: List[A], acc: Int): Int = {
    if (ys.isEmpty) acc else inner(ys.tail, acc + 1)
  }
  inner(xs, 0)
}

It's a bit longer to code, but I think the idea is clear. Of course you can do it without the inner function, but in that case you have to provide the initial value of acc manually.
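One way to drop the inner function without making every caller pass the initial value is a default argument. This is only a sketch; the object name `LengthDemo` is an assumption, and the method sits inside an object so that it cannot be overridden, which `@tailrec` requires.

```scala
import scala.annotation.tailrec

object LengthDemo {
  // acc defaults to 0, so callers can simply write length(xs)
  @tailrec
  def length[A](xs: List[A], acc: Int = 0): Int =
    if (xs.isEmpty) acc else length(xs.tail, acc + 1)
}
```

The trade-off is that the accumulator becomes part of the public signature, so a caller could pass a misleading starting value.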

4lex1v
  • 2
    One major difference is that in the pascal example, you have to recurse twice. You could stick the result of the first one into an accumulator, but getting it in the first place won't be TCO'ed. How would one get around that? – yshavit Sep 22 '13 at 15:44
  • @yshavit can't check this solution, but maybe with two accumulators, return tuple from tailrec inner function and then sum? – 4lex1v Sep 22 '13 at 16:25
  • My gut is that it's not possible through an accumulator (without, as Luigi mentions below, simulating a call stack in a local variable). – yshavit Sep 22 '13 at 20:53
  • @yshavit yep, agree with you. But still my example with accumulator is the most used approach to writing tailrec functions in scala – 4lex1v Sep 22 '13 at 21:06
  • 3
    @AlexIv but that doesn't make it an answer to this particular question, which you do not appear to have read carefully enough. – itsbruce Sep 22 '13 at 22:28
  • @itsbruce it was about general pattern to convert recursive function to tailrecursive, i gave example with accumulator pattern. As for pascals triangle i don't know how to convert this algorithm into TR. – 4lex1v Sep 22 '13 at 22:31
  • 2
    It's a question with a test case which your answer appears to ignore. Without showing how you might solve the test case, you're not really answering the question. – itsbruce Sep 22 '13 at 23:12
3

I'm pretty sure it's not possible in the simple way you're looking for in the general case, but it would depend on how elaborate you permit the changes to be.

A tail-recursive function must be re-writable as a while-loop, but try implementing, for example, a Fractal Tree using while-loops. It's possible, but you need to use an array or collection to store the state for each point, which substitutes for the data otherwise stored in the call stack.
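Applied to the pascal function from the question, that idea might look like the sketch below: a while-loop whose `pending` list plays the role of the call stack. The names `PascalLoop`, `pending` and `total` are illustrative, and `BigInt` is used only to avoid overflow.

```scala
object PascalLoop {
  def pascal(c: Int, r: Int): BigInt = {
    require(c >= 0 && r >= 0 && c <= r, "need 0 <= c <= r")
    // pending holds the points still to be evaluated, like stack frames
    var pending: List[(Int, Int)] = List((c, r))
    var total = BigInt(0)
    while (pending.nonEmpty) {
      val (c1, r1) = pending.head
      pending = pending.tail
      if (c1 == 0 || c1 == r1) total += 1 // edge of the triangle
      else pending = (c1 - 1, r1 - 1) :: (c1, r1 - 1) :: pending
    }
    total
  }
}
```

The call stack stays flat, but the loop still visits every edge point the naive recursion would, so the running time remains exponential.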

It's also possible to use trampolining.

Luigi Plinge
  • Yes trampolines are the easiest way to force a non tail recursive function to be tail recursive. See http://www.scala-lang.org/api/current/index.html#scala.util.control.TailCalls$ and http://blog.richdougherty.com/2009/04/tail-calls-tailrec-and-trampolines.html – iain Sep 23 '13 at 13:37
2

It is indeed possible. The way I'd do this is to begin with List(1) and keep recursing till you get to the row you want. Worth noticing that you can optimize it: if c==0 or c==r the value is one, and to calculate let's say column 3 of the 100th row you still only need to calculate the first three elements of the previous rows. A working tail recursive solution would be this:

import scala.annotation.tailrec

def pascal(c: Int, r: Int): Int = {
  @tailrec
  def pascalAcc(c: Int, r: Int, acc: List[Int]): List[Int] = {
    if (r == 0) acc
    else pascalAcc(c, r - 1,
    // from let's say 1 3 3 1 builds 0 1 3 3 1 0 , takes only the
    // subset that matters (if asking for col c, no cols after c are
    // used) and uses sliding to build (0 1) (1 3) (3 3) etc.
      (0 +: acc :+ 0).take(c + 2)
         .sliding(2, 1).map { x => x.reduce(_ + _) }.toList)
  }
  if (c == 0 || c == r) 1
  else pascalAcc(c, r, List(1))(c)
}

The @tailrec annotation makes the compiler check that the function is actually tail recursive. It could probably be optimized further, since the rows are symmetric: if c > r/2, pascal(c, r) == pascal(r - c, r). But that is left to the reader ;)
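One possible take on that exercise is sketched below. It is a variation on the answer's code rather than a drop-in replacement: it picks the smaller of c and r - c as the column to compute, and it uses `BigInt` and `Vector`, which are my own substitutions. The name `PascalSymmetric` is hypothetical.

```scala
import scala.annotation.tailrec

object PascalSymmetric {
  def pascal(c: Int, r: Int): BigInt = {
    require(c >= 0 && r >= 0 && c <= r, "need 0 <= c <= r")
    // rows are symmetric, so compute whichever mirrored column is smaller
    val k = math.min(c, r - c)

    @tailrec
    def loop(rowsLeft: Int, acc: Vector[BigInt]): Vector[BigInt] =
      if (rowsLeft == 0) acc
      else loop(
        rowsLeft - 1,
        // pad with zeros, keep only the columns up to k, then add neighbours
        (BigInt(0) +: acc :+ BigInt(0)).take(k + 2).sliding(2).map(_.sum).toVector
      )

    loop(r, Vector(BigInt(1)))(k)
  }
}
```

With this change, a column near the right edge of a row costs as little as the mirrored column near the left edge.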

Roberto Congiu