
I'm very new to Scala, so forgive my ignorance! I'm trying to iterate over pairs of integers that are bounded by a maximum. For example, if the maximum is 5, then the iteration should return:

(0, 0), (0, 1), ..., (0, 5), (1, 0), ..., (5, 5)

I've chosen to try and tail-recursively return this as a Stream:

    import scala.annotation.tailrec

    @tailrec
    def _pairs(i: Int, j: Int, maximum: Int): Stream[(Int, Int)] = {
        if (i == maximum && j == maximum) Stream.empty
        else if (j == maximum) (i, j) #:: _pairs(i + 1, 0, maximum)
        else (i, j) #:: _pairs(i, j + 1, maximum)
    }

Without the tailrec annotation the code works:

scala> _pairs(0, 0, 5).take(11)
res16: scala.collection.immutable.Stream[(Int, Int)] = Stream((0,0), ?)

scala> _pairs(0, 0, 5).take(11).toList
res17: List[(Int, Int)] = List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0), (1,1), (1,2), (1,3), (1,4))

But this isn't good enough for me. The compiler correctly points out that the recursive call on the last line of _pairs is not in tail position:

could not optimize @tailrec annotated method _pairs: it contains a recursive call not in tail position
    else (i, j) #:: _pairs(i, j + 1, maximum)
                ^

So, I have several questions:

  • directly addressing the implementation above, how does one tail-recursively return Stream[(Int, Int)]?
  • taking a step back, what is the most memory-efficient way to iterate over bounded sequences of integers? I don't want to iterate over Range because Range extends IndexedSeq, and I don't want the sequence to exist entirely in memory. Or am I wrong? If I iterate over Range.view do I avoid it coming into memory?

In Python (!), all I want is:

In [6]: def _pairs(maximum):
   ...:     for i in xrange(maximum+1):
   ...:         for j in xrange(maximum+1):
   ...:             yield (i, j)
   ...:             

In [7]: p = _pairs(5)

In [8]: [p.next() for i in xrange(11)]
Out[8]: 
[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 4)]

Thanks for your help! If you think I need to read references / API docs / anything else please tell me, because I'm keen to learn.

Will Ness
Asim Ihsan

2 Answers


This is not tail-recursion

Let's suppose you were making a list instead of a stream. (Let me use a simpler function to make my point.)

def foo(n: Int): List[Int] =
  if (n == 0)
    0 :: Nil
  else
    n :: foo(n - 1)

In the general case of this recursion, after foo(n - 1) returns, the function still has to do something with the list it got back: it has to prepend n to the beginning of that list. Because work remains after the recursive call, the function can't be tail recursive.

Without tail recursion, for some large value of n, you run out of stack space.

The usual list solution

The usual solution would be to pass a ListBuffer as a second parameter, and fill that.

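import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

def foo(n: Int): List[Int] = {
  // A recursive helper needs an explicit result type; @tailrec checks it compiles to a loop.
  @tailrec
  def fooInternal(n: Int, list: ListBuffer[Int]): List[Int] =
    if (n == 0) {
      list += 0 // same base case as the recursive foo: the list ends with 0
      list.toList
    } else {
      list += n // append n, then recurse; nothing is left to do afterwards
      fooInternal(n - 1, list)
    }
  fooInternal(n, new ListBuffer[Int]())
}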
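If you'd rather avoid the mutable buffer, a common alternative (not part of the original answer) is to thread an immutable List accumulator through a @tailrec helper. A minimal sketch, where `AccumulatorDemo` and `fooAcc` are hypothetical names: because the helper prepends while counting up from 0 to n, the accumulator already comes out in the same order as the recursive `foo`, so no final `reverse` is needed.

```scala
import scala.annotation.tailrec

object AccumulatorDemo {
  // Hypothetical name: returns n :: (n - 1) :: ... :: 0 :: Nil, like the recursive foo.
  def fooAcc(n: Int): List[Int] = {
    // The recursive call is the last thing loop does, so @tailrec turns it into a loop.
    @tailrec
    def loop(k: Int, acc: List[Int]): List[Int] =
      if (k > n) acc             // every value 0..n has been prepended
      else loop(k + 1, k :: acc) // prepend k; larger values end up at the front
    loop(0, Nil)
  }
}
```

Since the stack depth stays constant, something like `AccumulatorDemo.fooAcc(100000)` completes without a stack overflow.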

The pattern in your code is known as "tail recursion modulo cons". Some LISP and Prolog compilers optimize it automatically when they see it, since it's so common. Scala's compiler does not.

Streams don't need tail recursion

Streams don't need tail recursion to avoid running out of stack space, because they use a clever trick to avoid executing the recursive call to foo at the point where it appears in the code. The function call gets wrapped in a thunk, which is only called when you actually try to get that value from the stream. Only one call to foo is active at a time, so the calls never nest on the stack.
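Applied to the question's `_pairs`, that means the fix is simply to drop `@tailrec`: each recursive call sits in the lazy tail of `#::`, so it never nests on the stack. In the sketch below (wrapped in a hypothetical `PairsStream` object), the base case also emits `(maximum, maximum)` before ending the stream, so the output matches the `(0, 0)` to `(5, 5)` listing the question asks for. On Scala 2.13 and later, `Stream` is deprecated in favor of `LazyList`, so expect a deprecation warning there.

```scala
object PairsStream {
  // No @tailrec: each recursive call is deferred inside the tail of #::,
  // so it runs only when that part of the stream is demanded.
  def _pairs(i: Int, j: Int, maximum: Int): Stream[(Int, Int)] =
    if (i == maximum && j == maximum) (i, j) #:: Stream.empty   // last pair, then stop
    else if (j == maximum) (i, j) #:: _pairs(i + 1, 0, maximum) // end of row: next i
    else (i, j) #:: _pairs(i, j + 1, maximum)                   // same row: next j
}
```

Forcing a long stream, for example `PairsStream._pairs(0, 0, 300).length` (90,601 pairs), does not overflow the stack, because each tail is evaluated from the traversal loop rather than from inside another `_pairs` call.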

I've written a previous answer explaining how the #:: operator works here on Stackoverflow. Here's what happens when you call the following recursive stream function. (It is recursive in the mathematical sense, but it doesn't make a function call from within a function call the way you usually expect.)

def foo(n: Int): Stream[Int] =
  if (n == 0)
    0 #:: Stream.empty
  else
    n #:: foo(n - 1)

You call foo(10), and it returns a stream with one element already computed; the tail is a thunk that will call foo(9) the next time you need an element from the stream. foo(9) is not called right now. Rather, the call is bound to a lazy val inside the stream, and foo(10) returns immediately. When you finally do need the second value in the stream, foo(9) is called; it computes one element and sets the tail of the stream to be a thunk that will call foo(8). foo(9) returns immediately without calling foo(8). And so on...
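You can watch this happen by counting calls. In this sketch, `StreamCalls` and its `calls` counter are hypothetical instrumentation; `foo` is otherwise the function above. Building `foo(10)` runs `foo` exactly once, and forcing two tails (`s.tail.tail`) brings the total to three.

```scala
object StreamCalls {
  var calls = 0 // hypothetical instrumentation: how many times foo has run

  def foo(n: Int): Stream[Int] = {
    calls += 1
    if (n == 0) 0 #:: Stream.empty
    else n #:: foo(n - 1) // foo(n - 1) is wrapped in the lazy tail, not called here
  }
}
```

Forcing the whole stream with `toList` eventually runs `foo` for every value from 10 down to 0, one call at a time.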

This allows you to create infinite streams without running out of memory, for example:

def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)

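(Be careful what operations you call on this stream. An eager operation such as foreach never finishes, and toList will eventually fill up your whole heap. map, on the other hand, is lazy on a Stream and returns another Stream, so it is safe; take is a good way to work with an arbitrary finite prefix of the stream.)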
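A small sketch of the safe operations on the infinite `countUp` stream; `CountUpDemo` and `squares` are illustrative names. Mapping produces another infinite but lazy stream, and `take` or `drop` then forces only the part you ask for.

```scala
object CountUpDemo {
  // Infinite stream: the head is computed eagerly, each tail only on demand.
  def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)

  // map on a Stream is lazy as well, so defining this infinite stream is safe.
  val squares: Stream[Int] = countUp(1).map(x => x * x)
}
```

`CountUpDemo.squares.take(5).toList` forces just five elements, and `countUp(0).drop(100000).head` walks a long prefix iteratively without recursing.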

A simpler solution altogether

Instead of dealing with recursion and streams, why not just use Scala's for loop?

def pairs(maximum:Int) =
  for (i <- 0 to maximum;
       j <- 0 to maximum)
    yield (i, j)

This materializes the entire collection in memory, and returns an IndexedSeq[(Int, Int)].

If you need a Stream specifically, you can convert the first range into a Stream.

def pairs(maximum:Int) =
  for (i <- (0 to maximum).toStream;
       j <- 0 to maximum)
    yield (i, j)

This will return a Stream[(Int, Int)]. When you access a certain point in the sequence, it will be materialized into memory, and it will stick around as long as you still have a reference to any point in the stream before that element.
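The caching described above can be observed with a counter. In this sketch, `PairsForStream` and its `computed` counter are hypothetical instrumentation: traversing the same `Stream` value twice builds each of the 36 pairs only once, because elements that have already been forced are memoized.

```scala
object PairsForStream {
  var computed = 0 // hypothetical counter: how many pairs have been built

  // Only the outer range is a Stream; each inner row is built when it is reached.
  def pairs(maximum: Int): Stream[(Int, Int)] =
    for (i <- (0 to maximum).toStream;
         j <- 0 to maximum)
      yield { computed += 1; (i, j) }
}
```

The flip side is that every forced pair stays in memory for as long as something holds on to the head of the stream.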

You can get even better memory usage by converting both ranges into views.

def pairs(maximum:Int) =
  for (i <- (0 to maximum).view;
       j <- (0 to maximum).view)
    yield (i, j)

That returns a SeqView[(Int, Int),Seq[_]] that computes each element each time you need it, and doesn't store precomputed results.
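The difference from the `Stream` version shows up if you count how often the `yield` body runs. In this sketch, `PairsView` and its `computed` counter are hypothetical instrumentation. Building the view runs nothing, and each full traversal rebuilds all 36 pairs, so two traversals cost 72 evaluations. The exact static type of the result depends on the Scala version, so it is left to inference here.

```scala
object PairsView {
  var computed = 0 // hypothetical counter: how many pairs have been built

  // Result type is inferred: a SeqView in older Scala versions, a View in 2.13+.
  def pairs(maximum: Int) =
    for (i <- (0 to maximum).view;
         j <- (0 to maximum).view)
      yield { computed += 1; (i, j) }
}
```

Nothing is cached, so memory use stays small, but repeated traversals pay the computation cost again.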

You can also get an iterator (which you can only traverse once) the same way:

def pairs(maximum:Int) =
  for (i <- (0 to maximum).iterator;
       j <- (0 to maximum).iterator)
    yield (i, j)

That returns Iterator[(Int, Int)].
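This is the closest match to the Python generator in the question. In the sketch below (wrapped in a hypothetical `PairsIterator` object), each call to `next()` produces one pair, and once something like `toList` has consumed the iterator, `hasNext` is false and the iterator is spent.

```scala
object PairsIterator {
  // Single-pass and constant-space: pairs are produced one at a time.
  def pairs(maximum: Int): Iterator[(Int, Int)] =
    for (i <- (0 to maximum).iterator;
         j <- (0 to maximum).iterator)
      yield (i, j)
}
```

To traverse again, call `pairs` again to get a fresh iterator.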

Ken Bloom
  • Thank you for your answer! I understand why what I did isn't tail recursive, and I'd definitely prefer to use `for`. The problem I have is that `pairs`, as you've suggested, returns `IndexedSeq`. Hence the whole result will exist in memory when `pairs` is called. Could you please elaborate on how to use views to avoid this? – Asim Ihsan May 09 '12 at 23:44
  • And do you have more details and references about Streams and thunks? I'm very curious about how I'm not going to blow the stack by recursively calling a non-tail-call optimised function where I don't use coroutines. So much to learn! – Asim Ihsan May 09 '12 at 23:48
  • +1 for the nice answer. Just one more remark: You can actually safely call `map` on the `countUp` stream, because the result will be a `Stream` again. Only the `foreach` call will have eager evaluation. – Frank May 10 '12 at 06:28
  • Wow, I really had no idea how `Range` works. Checking out the source code, https://github.com/scala/scala/blob/master/src/library/scala/collection/immutable/Range.scala, it's clear that they're lazy. Hence both `(0 to 10)` and `(0 to 10000000)` have the same memory occupancy (three `Int`s). Hence `Range view` or `Range iterator` are delightful answers, where `Iterator` tells callers "you can traverse the result once", and `View` tells callers "you can treat this like a real collection". – Asim Ihsan May 10 '12 at 10:06
  • @AsimIhsan: that's correct. `Range.map`, however, materializes the whole collection, and that's what's going on in the `for` loop without calling `view` or `iterator` first. (Scala 2.7 used to perform `Range.map` lazily, but that behavior was found to be surprising and too confusing, so it was changed.) – Ken Bloom May 10 '12 at 12:46
  • @AsimIhsan: I just looked in the code. A `Range` has more like 7 `Int`s in its internal representation, but yeah it's still constant space. – Ken Bloom May 11 '12 at 01:02
  • @Ken Bloom Whoops! I can't read :). I misread the following sentence on this page http://docs.scala-lang.org/overviews/collections/concrete-immutable-collection-classes.html#ranges: "Ranges are represented in constant space, because they can be defined by just three numbers: their start, their end, and the stepping value." I mistook constant space for three `Int`s. Thanks for clearing this up, and your amazing answer. – Asim Ihsan May 11 '12 at 14:06
  • You just saved Scala for me, thanks, I had a hard time finding out about .iterator()... – dividebyzero Dec 14 '14 at 16:35
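The two facts from this comment thread, that a `Range` is stored in constant space but `Range.map` is eager, can be checked directly. In this sketch, `RangeDemo`, `big`, `square`, and the `calls` counter are hypothetical names used only for illustration.

```scala
object RangeDemo {
  // A billion-element Range: only its bounds and step are stored,
  // so length and indexing are answered arithmetically.
  val big: Range = 0 until 1000000000

  var calls = 0 // hypothetical counter for the mapping function
  def square(x: Int): Int = { calls += 1; x * x }
}
```

`(0 to 9).map(RangeDemo.square)` runs `square` ten times immediately, while `(0 to 9).view.map(RangeDemo.square)` runs it zero times until something like `toList` forces the view.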

Maybe an Iterator is better suited for you?

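class PairIterator(max: Int) extends Iterator[(Int, Int)] {
  private val side = max + 1  // each coordinate runs over 0 to max inclusive
  private var count = 0       // index of the next pair to hand out
  def hasNext: Boolean = count < side * side
  def next(): (Int, Int) = {
    if (!hasNext) throw new NoSuchElementException("no more pairs")
    count += 1
    ((count - 1) / side, (count - 1) % side) // row by row: (i, j)
  }
}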

val pi = new PairIterator (5)
pi.take (7).toList 
user unknown
  • By the way thanks for sharing this with me. I've been using Iterators for a lot of other problems and this is the only full example I can find! – Asim Ihsan May 11 '12 at 14:05