I have a function in Scala, and I wonder whether it's possible to make it tail recursive:

def get_f(f: Int => Int, x: Int, y: Int): Int = x match {
  case 0 => y
  case _ => f(get_f(f, x - 1, y))
}
dk14
  • @pavel I don't get any errors, but I want to avoid stack overflows in the future by turning normal recursion into tail recursion. – Dec 06 '16 at 14:32

5 Answers


This function recursively applies f to its own result, x times, which is the same as applying f to y, x times. I'd also suggest using if/else instead of pattern matching here.

import scala.annotation.tailrec

@tailrec
def get_f(f: Int => Int, x: Int, y: Int): Int =
  if (x == 0) y
  else get_f(f, x - 1, f(y))  // apply f to the running value first, then recurse

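Add the @tailrec annotation (from scala.annotation.tailrec) so the compiler verifies that the function really is tail recursive; compilation fails if it isn't.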
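A quick way to check this answer is a minimal standalone sketch, assuming a Scala script or REPL where top-level defs are allowed. The large x value and the doubling function below are illustrative choices of mine, not from the answer:

```scala
import scala.annotation.tailrec

// The if/else version from this answer
@tailrec
def get_f(f: Int => Int, x: Int, y: Int): Int =
  if (x == 0) y                // no applications left: the running value is the result
  else get_f(f, x - 1, f(y))   // apply f now, then recurse; the call is in tail position

// A depth at which the original f(get_f(...)) shape would very likely overflow the default stack
println(get_f(_ + 1, 1000000, 0))  // 1000000
println(get_f(n => n * 2, 10, 1))  // 1024
```

Applying f before recursing gives the same result as the original, because every step applies the same f: both versions compute f applied x times to y.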

Murat Mustafin

It is possible, but the way you've constructed it means you'll have to use a trampolined style to make it work:

import scala.util.control.TailCalls._

def get_f(f: Int => Int, x: Int, y: Int): TailRec[Int] = x match {
  case 0 => done(y)
  case _ => tailcall(get_f(f, x - 1, y)).map(f)
}

val answer = get_f(_+1, 0, 24).result

You can read about TailRec here or for more advanced study, this paper.
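The usage line above passes x = 0, so f is never actually applied. A small sketch that exercises the trampoline with a larger x follows; it assumes Scala 2.11 or later (where TailRec provides map) and a script or REPL context, and the input values are only illustrative:

```scala
import scala.util.control.TailCalls._

// Same shape as the original: f wraps the recursive result, but tailcall defers the call
// and .result evaluates the resulting chain of TailRec steps in a loop instead of on the stack
def get_f(f: Int => Int, x: Int, y: Int): TailRec[Int] = x match {
  case 0 => done(y)
  case _ => tailcall(get_f(f, x - 1, y)).map(f)
}

println(get_f(_ + 1, 4, 0).result)       // 4
println(get_f(_ + 1, 100000, 0).result)  // 100000
```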

wheaties

Let's start by reducing the number of parameters in your non-tail-recursive version to make it clear what it actually does:

def get_f(f: Int => Int, x: Int, y: Int) = {
  def get_f_impl(x: Int): Int = x match {
    case 0 => y
    case _ => f(get_f_impl(x - 1))
  }
  get_f_impl(x)
}

What it actually does is apply f to the initial value y, x times. Once that is clear, you can make it tail recursive by carrying the intermediate result in an accumulator:

import scala.annotation.tailrec

def get_f(f: Int => Int, x: Int, y: Int) = {
  @tailrec def get_f_impl(acc: Int, x: Int): Int =
    if (x == 0) acc else get_f_impl(f(acc), x - 1)  // acc holds the result so far
  get_f_impl(y, x)
}

Checking in the REPL:

Your original implementation:

scala> get_f(_ + 1, 4, 0)
res6: Int = 4

Your implementation with fewer parameters:

scala> get_f(_ + 1, 4, 0)
res0: Int = 4

Tail-recursive implementation:

scala> get_f(_ + 1, 4, 0)
res3: Int = 4
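The REPL checks above only use _ + 1 with x = 4. A slightly broader sketch compares the original shape against the accumulator version over a range of inputs; get_f_naive and the function g are hypothetical names introduced just for this comparison:

```scala
import scala.annotation.tailrec

// Original, non-tail-recursive shape: f wraps the recursive call
def get_f_naive(f: Int => Int, x: Int, y: Int): Int = x match {
  case 0 => y
  case _ => f(get_f_naive(f, x - 1, y))
}

// Accumulator version from this answer
def get_f(f: Int => Int, x: Int, y: Int): Int = {
  @tailrec def get_f_impl(acc: Int, x: Int): Int =
    if (x == 0) acc else get_f_impl(f(acc), x - 1)
  get_f_impl(y, x)
}

val g: Int => Int = n => 2 * n + 1  // hypothetical test function, not from the question
for (x <- 0 to 10; y <- -3 to 3)
  assert(get_f_naive(g, x, y) == get_f(g, x, y), s"mismatch at x=$x, y=$y")

println(get_f(g, 3, 0))  // 0 -> 1 -> 3 -> 7, so this prints 7
```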

P.S. For more complex cases, trampolines might be a better fit: https://espinhogr.github.io/scala/2015/07/12/trampolines-in-scala.html

P.S.2 You can also try:

dk14
  • It's not really tail recursive, since `get_f_impl` calls `get_f`. – adamwy Dec 06 '16 at 14:46
  • Sorry, I meant get_f_impl. – dk14 Dec 06 '16 at 14:47
  • I didn't check the code as I don't have scala REPL right now, but the idea is that function just applies `f` n times – dk14 Dec 06 '16 at 14:48
  • I just checked it and it doesn't compile. `f(get_f_impl(...))` is not tail recursive. – adamwy Dec 06 '16 at 14:48
  • You checked the wrong one: the second example should compile. The first one is a simplification of the original non-tail-recursive one. – dk14 Dec 06 '16 at 14:49
  • I think it's wrong, because you call `get_f_impl` with the result of `f(x - 1)`, not `x - 1`. – adamwy Dec 06 '16 at 14:55
  • x is just a counter to apply a function `f` x times – dk14 Dec 06 '16 at 14:56
  • Yes, but you turn this counter into a value after transforming it using `f`. – adamwy Dec 06 '16 at 14:57
  • Corrected to pass accumulator too – dk14 Dec 06 '16 at 15:00
  • The `Function.chain` solution is not very efficient, since it requires allocating a list of size x. – adamwy Dec 06 '16 at 15:18
  • Your `Function.chain` solution can blow the stack because `compose` generates a function requiring 3 stack frames. Scala cannot do tail call elimination outside the context of a tail recursive function. – wheaties Dec 06 '16 at 15:22
  • wheaties, `Function.chain` will not blow the stack since it doesn't use `compose` internally, it's implemented using `foldRight` on input Seq: https://github.com/scala/scala/blob/v2.12.0/src/library/scala/Function.scala#L24 – adamwy Dec 06 '16 at 15:27
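The comments above discuss a Function.chain variant whose code is not shown here. Purely as an illustration of that general approach (a sketch of mine, not the answer's original code), using the standard library's Function.chain might look like this:

```scala
// Hypothetical sketch of the Function.chain approach discussed in the comments
def get_f(f: Int => Int, x: Int, y: Int): Int =
  // Seq.fill(x)(f) builds a collection holding x references to f;
  // Function.chain then applies them one after another, starting from y
  Function.chain(Seq.fill(x)(f))(y)

println(get_f(_ + 1, 4, 0))  // 4
```

As noted in the comments, Seq.fill(x)(f) allocates a collection of size x, which the @tailrec and foldLeft versions avoid.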

I'll add that you can achieve the same result by using foldLeft on Range, like this:

def get_f(f: Int => Int, x: Int, y: Int) =
  (0 until x).foldLeft(y)((acc, _) => f(acc))
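A sketch of why this holds up for large x: foldLeft runs as a loop over the Range rather than as nested calls, and a Range stores only its bounds instead of materializing its elements. The input values below are illustrative only:

```scala
// The foldLeft version from this answer
def get_f(f: Int => Int, x: Int, y: Int): Int =
  // (0 until x) only supplies x iterations; the element itself is ignored
  (0 until x).foldLeft(y)((acc, _) => f(acc))

println(get_f(_ + 1, 1000000, 0))  // 1000000
println(get_f(n => n * 2, 10, 1))  // 1024
```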
adamwy
  • That's not true, see: https://github.com/scala/scala/blob/v2.12.1/src/library/scala/collection/TraversableOnce.scala#L161 but `foldRight` creates a List of n elements internally (as it calls `foldLeft` on `reversed` iterator), so it's best to use `foldLeft`. – adamwy Dec 06 '16 at 16:35

Along the same lines as the previous answers:

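import scala.annotation.tailrec

def get_f2(f: Int => Int, x: Int, y: Int): Int = {
  // y doubles as the accumulator; x counts the remaining applications of f
  @tailrec
  def tail(y: Int, x: Int)(f: Int => Int): Int =
    x match {
      case 0 => y
      case _ => tail(f(y), x - 1)(f)
    }

  tail(y, x)(f)
}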
Emiliano Martinez