61

Is it possible to implement in Scala something equivalent to the Python yield statement where it remembers the local state of the function where it is used and "yields" the next value each time it is called?

I wanted to have something like this to convert a recursive function into an iterator. Sort of like this:

# this is python
def foo(i):
  yield i
  if i > 0:
    for j in foo(i - 1):
      yield j

for i in foo(5):
  print(i)

Except, foo may be more complex and recurs through some acyclic object graph.

Additional Edit: Let me add a more complex example (but still simple): I can write a simple recursive function printing things as it goes along:

// this is Scala
def printClass(clazz:Class[_], indent:String=""): Unit = {
  clazz match {
    case null =>
    case _ =>
      println(indent + clazz)
      printClass(clazz.getSuperclass, indent + "  ")
      for (c <- clazz.getInterfaces) {
        printClass(c, indent + "  ")
      }
  }
}

Ideally I would like to have a library that allows me to easily change a few statements and have it work as an Iterator:

// this is not Scala
def yieldClass(clazz:Class[_]): Iterator[Class[_]] = {
  clazz match {
    case null =>
    case _ =>
      sudoYield clazz
      for (c <- yieldClass(clazz.getSuperclass)) sudoYield c
      for (c <- clazz.getInterfaces; d <- yieldClass(c)) sudoYield d
  }
}

It does seem that continuations make this possible, but I just don't understand the shift/reset concept. Will continuations eventually make it into the main compiler, and would it be possible to extract the complexity into a library?

Edit 2: check Rich's answer in that other thread.

Community
huynhjl
  • It is difficult to come up with a tractable example that could not be implemented via standard techniques. For example, I think your `yieldClass` example could be implemented by just using `Iterator.++` cleverly. But, yes, I think `yieldClass` could be implemented in terms of shift/reset. I don't know when it will make it into the compiler without requiring a plugin. I think most complexity can be factored out into a "generator library". I think Rich Dougherty's blog is the best explanatory source of these critters. – Mitch Blevins Jan 26 '10 at 07:12
  • 1
    You are right about Iterator.++. http://gist.github.com/286682 works. I'll have to check Rich's blog. – huynhjl Jan 26 '10 at 08:44
  • This is a duplicate, though, curiously, I don't see the original listed on "related". – Daniel C. Sobral Jan 26 '10 at 12:45
  • 2
    Sorry for the OT, but I stared at your sudoYield for a few moments until I realized you probably meant pseudoYield. Pseudo = sham or pretend; sudo = super user do (a linux command). See this: http://dictionary.reference.com/browse/pseudo And this: https://xkcd.com/149/ – elifiner Jul 12 '15 at 13:59

5 Answers

34

While Python generators are cool, trying to duplicate them really isn't the best way to go about it in Scala. For instance, the following code does the equivalent of what you want:

def classStream(clazz: Class[_]): Stream[Class[_]] = clazz match {
  case null => Stream.empty
  case _ => (
    clazz 
    #:: classStream(clazz.getSuperclass) 
    #::: clazz.getInterfaces.toStream.flatMap(classStream) 
    #::: Stream.empty
  )
}

The stream is generated lazily, so apart from the head it won't process any of the elements until they are asked for, which you can verify by running this:

def classStream(clazz: Class[_]): Stream[Class[_]] = clazz match {
  case null => Stream.empty
  case _ => (
    clazz 
    #:: { println(clazz.toString+": super"); classStream(clazz.getSuperclass) } 
    #::: { println(clazz.toString+": interfaces"); clazz.getInterfaces.toStream.flatMap(classStream) } 
    #::: Stream.empty
  )
}

The result can be converted into an Iterator simply by calling .iterator on the resulting Stream:

def classIterator(clazz: Class[_]): Iterator[Class[_]] = classStream(clazz).iterator

The foo definition, using Stream, would be rendered thus:

scala> def foo(i: Int): Stream[Int] = i #:: (if (i > 0) foo(i - 1) else Stream.empty)
foo: (i: Int)Stream[Int]

scala> foo(5) foreach println
5
4
3
2
1
0
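
On Scala 2.13 and later, Stream is deprecated in favour of LazyList, which is also lazy in its head element. The sketch below restates the two definitions above under that assumption; the name classLazyList is introduced here for illustration and is not part of the original answer.

```scala
// Sketch assuming Scala 2.13+, where LazyList supersedes Stream.
// classLazyList is a hypothetical name chosen for this example.
def classLazyList(clazz: Class[_]): LazyList[Class[_]] =
  if (clazz == null) LazyList.empty
  else
    clazz #:: (
      classLazyList(clazz.getSuperclass) #:::
        clazz.getInterfaces.to(LazyList).flatMap(classLazyList)
    )

// Same shape as the Stream-based foo, with LazyList swapped in.
def foo(i: Int): LazyList[Int] =
  i #:: (if (i > 0) foo(i - 1) else LazyList.empty)

// Calling .iterator yields the Iterator the question asked for.
val classes: Iterator[Class[_]] = classLazyList(classOf[String]).iterator
```

Because an iterator obtained this way does not keep a reference to the head of the list, elements that have already been consumed can be garbage collected.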

Another alternative would be concatenating the various iterators, taking care to not pre-compute them. Here's an example, also with debugging messages to help trace the execution:

def yieldClass(clazz: Class[_]): Iterator[Class[_]] = clazz match {
  case null => println("empty"); Iterator.empty
  case _ =>
    def thisIterator = { println("self of "+clazz); Iterator(clazz) }
    def superIterator = { println("super of "+clazz); yieldClass(clazz.getSuperclass) }
    def interfacesIterator = { println("interfaces of "+clazz); clazz.getInterfaces.iterator flatMap yieldClass }
    thisIterator ++ superIterator ++ interfacesIterator
}

This is pretty close to your code. Instead of sudoYield, I have definitions, and then I just concatenate them as I wish.
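
The same concatenation idea also covers the foo example from the question. The following is a minimal sketch that relies on the argument of Iterator's ++ being passed by name, so the recursive call only runs once the left-hand iterator is exhausted. The name fooIterator is introduced here for illustration.

```scala
// Hypothetical Iterator-based rendering of the question's foo.
// The right-hand side of ++ is a by-name parameter, so fooIterator(i - 1)
// is not evaluated until the single-element iterator has been consumed.
def fooIterator(i: Int): Iterator[Int] =
  Iterator.single(i) ++ (if (i > 0) fooIterator(i - 1) else Iterator.empty)

// Prints the countdown from 5 to 0, one number per line.
fooIterator(5).foreach(println)
```

Since nothing is computed ahead of time, taking only the first few elements of a very long countdown does not walk the whole recursion.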

So, while this is a non-answer, I just think you are barking up the wrong tree here. Trying to write Python in Scala is bound to be unproductive. Work harder at the Scala idioms that accomplish the same goals.

Nathaniel Ford
Daniel C. Sobral
  • 2
    Thanks, I think the Stream solution is what I was looking for as it evaluates lazily. You are right, I don't want to write python in Scala, but since I never used streams before the solution did not occur to me. – huynhjl Jan 26 '10 at 15:09
  • 1
    By the way, where is the scaladoc for the #:: and #::: operators? I can't seem to see it on the scala.collection.immutable.Stream scaladoc. – huynhjl Jan 26 '10 at 15:26
  • @huynhjl Both the `Stream` and the `Iterator` solutions evaluate lazily. As for these operators, they are only present on Scala 2.8. They are not defined in an obvious place indeed, because it just wouldn't work. I can't recall -- or find -- where they are defined right now. You can replace them with `Stream.cons` and `Stream.concat` if you are using Scala 2.7. – Daniel C. Sobral Jan 26 '10 at 21:03
  • The downside for Stream (and similar recursive constructs) is that working with them in Scala easily leads to stack overflows -- this is what makes trampolines appealing. http://blog.richdougherty.com/2009/04/tail-calls-tailrec-and-trampolines.html – Yang Feb 07 '10 at 22:18
  • 1
    #:: and #::: are in the scaladoc for the Stream companion object (http://www.scala-lang.org/api/beta/scala/collection/immutable/Stream$.html), but are unfortunately not very well documented. #:: does Stream.cons, and #::: does Stream.concat. – Todd Owen Nov 06 '11 at 22:55
  • @Yang: because streams are lazy, they do not consume stack in the way that true recursion might. In this respect, they are very much like a python generator. – Todd Owen Nov 06 '11 at 22:59
  • @Daniel C. Sobral: One difference between the Python-yield solution and your stream-solution is that a Scala `Stream` never is lazy in its head element, it instead always stores the head in evaluated form. Your iterator solution also does that. But could the iterator solution be changed so that it doesn't compute the first element until it is requested by the client? – Lii Aug 06 '16 at 17:08
  • @Lii A `def` isn't computed until called, so the head won't be computed until you call that method. And there are other solutions to keep it lazy if you want to pass it around (`lazy val`, by name parameters, 0-arity functions). You can make it so simply by converting the return from `Stream` to `() => Stream`. – Daniel C. Sobral Aug 06 '16 at 19:54
  • I was hoping there were a better way. It's not pretty to pass a `() => Stream` around instead of a `Stream`. In the iterator case the best solution I could come up with is this: Make a wrapper to `yieldClass` which returns a special implementation of `Iterator`. That implementation takes a `() => Iterator`, which it calls on first access. – Lii Aug 06 '16 at 20:43
  • That solution doesn't work for `Stream` however, since extending `Stream` is deprecated. And I would much rather work with the reliably immutable `Stream` than the treacherous `Iterator`. – Lii Aug 06 '16 at 21:09
  • @Lii I'm not talking about extending anything. Just this: `def classStream(clazz: Class[_]): () => Stream[Class[_]] = () => clazz match {`. – Daniel C. Sobral Aug 07 '16 at 03:18
  • I understand that. I was considering extending `Stream` to avoid exactly that. Because I'd rather avoid using the type `() => Stream[Class[_]]` instead of `Stream[Class[_]]` all over my application. With `Iterator` I could avoid that, but not with `Stream`. – Lii Aug 07 '16 at 07:55
  • @Lii Well, then define your own `Stream` or pick one from Scalaz (or maybe Cats, I don't know if they have a stream). – Daniel C. Sobral Aug 07 '16 at 09:30
  • I think the big difference between Python generators and scala streams (or lazy lists) is that the latter memoize (and use memory to do so). So Streams often are not a good analogue to Python generators. – Stephane Bersier Jul 10 '23 at 14:16
  • That's only true if you keep a reference to the head of the stream. If you turned it into an `Iterator`, it won't happen at all. It won't happen in the `foo` example either. But, to be clear, I was never comparing Python generators to Scala streams. I'm saying that you can't do Python generators in Scala and should, instead, use something like streams. – Daniel C. Sobral Jul 11 '23 at 22:55
12

Another solution based on the continuations plugin, this time with a more or less encapsulated Generator type:

import scala.continuations._
import scala.continuations.ControlContext._

object Test {

  def loopWhile(cond: =>Boolean)(body: =>(Unit @suspendable)): Unit @suspendable = {
    if (cond) {
      body
      loopWhile(cond)(body)
    } else ()
  }

  abstract class Generator[T] {
    var producerCont : (Unit => Unit) = null
    var consumerCont : (T => Unit) = null

    protected def body : Unit @suspendable

    reset {
      body
    }

    def generate(t : T) : Unit @suspendable =
      shift {
        (k : Unit => Unit) => {
          producerCont = k
          if (consumerCont != null)
            consumerCont(t)
        }
      }

    def next : T @suspendable =
      shift {
        (k : T => Unit) => {
          consumerCont = k
          if (producerCont != null)
            producerCont()
        }
      }
  }

  def main(args: Array[String]) {
    val g = new Generator[Int] {
      def body = {
        var i = 0
        loopWhile(i < 10) {
          generate(i)
          i += 1
        }
      }
    }

    reset {
      loopWhile(true) {
        println("Generated: "+g.next)
      }
    }
  }
}
Miles Sabin
  • Thank you Miles. I'm looking forward to trying this. I'll need to spend some time to set up the continuation plug-in first... – huynhjl Jan 27 '10 at 15:04
  • I was able to compile and run your sample. It will probably take me some time and some documentation to be able to modify and understand it. – huynhjl Jan 28 '10 at 07:43
  • 1
    This was informative in learning of ways to do things with delimited continuations, but the downside to this particular solution is that the call site has to be CPS-transformed. Rich Dougherty and I present alternative solutions at http://stackoverflow.com/questions/2201882/implementing-yield-yield-return-using-scala-continuations/. – Yang Feb 07 '10 at 22:15
  • Yes, I agree that Rich's is a much nicer solution ... much more direct. Mine is actually derived from an encoding of symmetric coroutines using shift/reset and I think that shows through in the awkwardness you point out. – Miles Sabin Feb 08 '10 at 00:10
4

To do this in a general way, I think you need the continuations plugin.

A naive implementation (freehand, not compiled/checked):

def iterator = new {
  private[this] var done = false

  // Define your yielding state here
  // This generator yields: 3, 13, 0, 1, 3, 6, 26, 27
  private[this] var state: Unit=>Int = reset {
    var x = 3
    giveItUp(x)
    x += 10
    giveItUp(x)
    x = 0
    giveItUp(x)
    List(1,2,3).foreach { i => x += i; giveItUp(x) }
    x += 20
    giveItUp(x)
    x += 1
    done = true
    x
  }

  // Well, "yield" is a keyword, so how about giveItUp?
  private[this] def giveItUp(i: Int) = shift { k: (Unit=>Int) =>
    state = k
    i
  }

  def hasNext = !done
  def next = state()
}

What is happening is that any call to shift captures the control flow from where it is called to the end of the reset block that it is called in. This is passed as the k argument into the shift function.

So, in the example above, each giveItUp(x) returns the value of x (up to that point) and saves the rest of the computation in the state variable. It is driven from outside by the hasNext and next methods.

Go gentle, this is obviously a terrible way to implement this. But it is the best I could do late at night without a compiler handy.

Mitch Blevins
  • I think a library might be made if the shift/reset generated a stream, so each call would go back to the shift/reset. I think. Sort of. – Daniel C. Sobral Jan 26 '10 at 12:57
  • 1
    The blog is in the link in my answer above: http://blog.richdougherty.com/search/label/continuations – Mitch Blevins Jan 26 '10 at 16:10
  • I get an "error: type mismatch" where `found: scala.runtime.StringAdd @scala.continuations.uncps @scala.continuations.cps[Int,Int]` and `required: ? @scala.continuations.cps[?,(Unit) => Int]` on the line `private[this] var state: Unit=>Int = reset {`. – Yang Feb 06 '10 at 17:17
4

Scala's for-loop of the form for (e <- Producer) f(e) translates into a foreach call, not directly into iterator / next calls.

With foreach, we don't need to linearize the creation of objects and gather it in one place, as we would for an iterator's next. The consumer function f can be invoked multiple times, exactly where it is needed (i.e. where an object is created).

This makes implementing generator use cases simple and efficient with Traversable / foreach in Scala.


The initial Foo-example:

case class Countdown(start: Int) extends Traversable[Int] {
    def foreach[U](f: Int => U) {
        var j = start
        while (j >= 0) {f(j); j -= 1}
    }
}

for (i <- Countdown(5))  println(i)
// or equivalent:
Countdown(5) foreach println
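
Note that on Scala 2.13 and later, Traversable is only a deprecated alias for Iterable, whose abstract member is iterator rather than foreach, so the class above would need adapting there. Since a for-loop without yield only requires a foreach method, a plain class is enough. A minimal sketch under that assumption (SimpleCountdown is a made-up name for this example):

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical variant that extends no collection trait at all.
final class SimpleCountdown(start: Int) {
  def foreach[U](f: Int => U): Unit = {
    var j = start
    while (j >= 0) { f(j); j -= 1 }
  }
}

// The for-loop desugars to new SimpleCountdown(3).foreach(i => seen += i).
val seen = ListBuffer.empty[Int]
for (i <- new SimpleCountdown(3)) seen += i
```

The trade-off is that such a class offers none of the collection methods (map, filter and so on) unless they are written by hand.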

The initial printClass-example:

  // v1 (without indentation)

  case class ClassStructure(c: Class[_]) {
    def foreach[U](f: Class[_] => U) {
      if (c eq null) return
      f(c)
      ClassStructure(c.getSuperclass) foreach f
      c.getInterfaces foreach (ClassStructure(_) foreach f)
    }
  }

  for (c <- ClassStructure(<foo/>.getClass)) println(c)
  // or equivalent:
  ClassStructure(<foo/>.getClass) foreach println

Or with indentation:

  // v2 (with indentation)

  case class ClassWithIndent(c: Class[_], indent: String = "") {
    override def toString = indent + c
  }
  implicit def Class2WithIndent(c: Class[_]) = ClassWithIndent(c)

  case class ClassStructure(cwi: ClassWithIndent) {
    def foreach[U](f: ClassWithIndent => U) {
      if (cwi.c eq null) return
      f(cwi)
      ClassStructure(ClassWithIndent(cwi.c.getSuperclass, cwi.indent + "  ")) foreach f
      cwi.c.getInterfaces foreach (i => ClassStructure(ClassWithIndent(i, cwi.indent + "  ")) foreach f)
    }
  }

  for (c <- ClassStructure(<foo/>.getClass)) println(c)
  // or equivalent:
  ClassStructure(<foo/>.getClass) foreach println

Output:

class scala.xml.Elem
  class scala.xml.Node
    class scala.xml.NodeSeq
      class java.lang.Object
      interface scala.collection.immutable.Seq
        interface scala.collection.immutable.Iterable
          interface scala.collection.immutable.Traversable
            interface scala.collection.Traversable
              interface scala.collection.TraversableLike
                interface scala.collection.generic.HasNewBuilder
                interface scala.collection.generic.FilterMonadic
                interface scala.collection.TraversableOnce
                  interface scala.ScalaObject
                interface scala.ScalaObject
              interface scala.collection.generic.GenericTraversableTemplate
                interface scala.collection.generic.HasNewBuilder
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.collection.generic.GenericTraversableTemplate
              interface scala.collection.generic.HasNewBuilder
              interface scala.ScalaObject
            interface scala.collection.TraversableLike
              interface scala.collection.generic.HasNewBuilder
              interface scala.collection.generic.FilterMonadic
              interface scala.collection.TraversableOnce
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.Immutable
            interface scala.ScalaObject
          interface scala.collection.Iterable
            interface scala.collection.Traversable
              interface scala.collection.TraversableLike
                interface scala.collection.generic.HasNewBuilder
                interface scala.collection.generic.FilterMonadic
                interface scala.collection.TraversableOnce
                  interface scala.ScalaObject
                interface scala.ScalaObject
              interface scala.collection.generic.GenericTraversableTemplate
                interface scala.collection.generic.HasNewBuilder
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.collection.generic.GenericTraversableTemplate
              interface scala.collection.generic.HasNewBuilder
              interface scala.ScalaObject
            interface scala.collection.IterableLike
              interface scala.Equals
              interface scala.collection.TraversableLike
                interface scala.collection.generic.HasNewBuilder
                interface scala.collection.generic.FilterMonadic
                interface scala.collection.TraversableOnce
                  interface scala.ScalaObject
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.ScalaObject
          interface scala.collection.generic.GenericTraversableTemplate
            interface scala.collection.generic.HasNewBuilder
            interface scala.ScalaObject
          interface scala.collection.IterableLike
            interface scala.Equals
            interface scala.collection.TraversableLike
              interface scala.collection.generic.HasNewBuilder
              interface scala.collection.generic.FilterMonadic
              interface scala.collection.TraversableOnce
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.ScalaObject
          interface scala.ScalaObject
        interface scala.collection.Seq
          interface scala.PartialFunction
            interface scala.Function1
              interface scala.ScalaObject
            interface scala.ScalaObject
          interface scala.collection.Iterable
            interface scala.collection.Traversable
              interface scala.collection.TraversableLike
                interface scala.collection.generic.HasNewBuilder
                interface scala.collection.generic.FilterMonadic
                interface scala.collection.TraversableOnce
                  interface scala.ScalaObject
                interface scala.ScalaObject
              interface scala.collection.generic.GenericTraversableTemplate
                interface scala.collection.generic.HasNewBuilder
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.collection.generic.GenericTraversableTemplate
              interface scala.collection.generic.HasNewBuilder
              interface scala.ScalaObject
            interface scala.collection.IterableLike
              interface scala.Equals
              interface scala.collection.TraversableLike
                interface scala.collection.generic.HasNewBuilder
                interface scala.collection.generic.FilterMonadic
                interface scala.collection.TraversableOnce
                  interface scala.ScalaObject
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.ScalaObject
          interface scala.collection.generic.GenericTraversableTemplate
            interface scala.collection.generic.HasNewBuilder
            interface scala.ScalaObject
          interface scala.collection.SeqLike
            interface scala.collection.IterableLike
              interface scala.Equals
              interface scala.collection.TraversableLike
                interface scala.collection.generic.HasNewBuilder
                interface scala.collection.generic.FilterMonadic
                interface scala.collection.TraversableOnce
                  interface scala.ScalaObject
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.ScalaObject
          interface scala.ScalaObject
        interface scala.collection.generic.GenericTraversableTemplate
          interface scala.collection.generic.HasNewBuilder
          interface scala.ScalaObject
        interface scala.collection.SeqLike
          interface scala.collection.IterableLike
            interface scala.Equals
            interface scala.collection.TraversableLike
              interface scala.collection.generic.HasNewBuilder
              interface scala.collection.generic.FilterMonadic
              interface scala.collection.TraversableOnce
                interface scala.ScalaObject
              interface scala.ScalaObject
            interface scala.ScalaObject
          interface scala.ScalaObject
        interface scala.ScalaObject
      interface scala.collection.SeqLike
        interface scala.collection.IterableLike
          interface scala.Equals
          interface scala.collection.TraversableLike
            interface scala.collection.generic.HasNewBuilder
            interface scala.collection.generic.FilterMonadic
            interface scala.collection.TraversableOnce
              interface scala.ScalaObject
            interface scala.ScalaObject
          interface scala.ScalaObject
        interface scala.ScalaObject
      interface scala.xml.Equality
        interface scala.Equals
        interface scala.ScalaObject
      interface scala.ScalaObject
    interface scala.ScalaObject
  interface scala.ScalaObject
  interface java.io.Serializable
Eugen Labun
0

Dsl.scala is what you are looking for.

Suppose you want to create a random number generator. The generated numbers should be stored in a lazily evaluated infinite stream, which can be built with the help of our built-in domain-specific keyword Yield.

import com.thoughtworks.dsl.keys.Yield
def xorshiftRandomGenerator(seed: Int): Stream[Int] = {
  val tmp1 = seed ^ (seed << 13)
  val tmp2 = tmp1 ^ (tmp1 >>> 17)
  val tmp3 = tmp2 ^ (tmp2 << 5)
  !Yield(tmp3)
  xorshiftRandomGenerator(tmp3)
}

Other examples can be found in the Scaladoc.

Yang Bo