
Is there any way to define a stream with a backtracking algorithm in Scala?

For instance, the following backtracking algorithm prints all "binary" strings of a given size.

def binaries(s: String, n: Int): Unit = {
  if (s.size == n)
    println(s)
  else {
    binaries(s + '0', n)
    binaries(s + '1', n)
  }
}

I believe I can define a stream of "binary" strings of a given size using another, iterative algorithm. However, I wonder whether I can convert the backtracking algorithm above into a stream.
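
For reference, one possible iterative approach (an assumption about what "another iterative algorithm" could look like, not the asker's code) counts from 0 to 2^n - 1 and renders each number as an n-digit binary string. This sketch assumes Scala 2.13+ (`LazyList`) and n < 31 so that `1 << n` fits in an `Int`; `IterativeBinaries` is a hypothetical name.

```scala
// Assumes Scala 2.13+ and 0 <= n < 31.
object IterativeBinaries {
  // Enumerate 0 until 2^n lazily; bit b of i becomes one character,
  // most significant bit first, so the order matches the backtracking version.
  def binaries(n: Int): LazyList[String] =
    LazyList.range(0, 1 << n).map { i =>
      (n - 1 to 0 by -1).map(b => if (((i >> b) & 1) == 1) '1' else '0').mkString
    }
}
```

For n = 2 this yields "00", "01", "10", "11", the same order the recursive version prints them.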

Michael

2 Answers


This is pretty straightforward:

def binaries(s: String, n: Int): Stream[String] = 
  if (s.size == n) Stream(s) 
  else binaries(s + "0", n) append binaries(s + "1", n)

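Note the use of `append`: it takes its parameter by name, so the second recursive call is not evaluated until the stream actually reaches it. The standard `++` does not take its argument by name, which is why `append` has to be a separate method that is not part of the standard interface of other collections.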
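
On Scala 2.13+, where `Stream` is deprecated in favor of `LazyList`, the same idea can be sketched with `LazyList.lazyAppendedAll`, which also takes its argument by name. The `produced` counter is hypothetical instrumentation added only to observe how many leaves get evaluated.

```scala
// Assumes Scala 2.13+, where LazyList supersedes the deprecated Stream.
object LazyBinaries {
  // Hypothetical instrumentation: number of complete strings built so far.
  var produced = 0

  // lazyAppendedAll takes its argument by name (like Stream.append), so the
  // "1" branch is only explored once the "0" branch has been exhausted.
  def binaries(s: String, n: Int): LazyList[String] =
    if (s.length == n) { produced += 1; LazyList(s) }
    else binaries(s + "0", n).lazyAppendedAll(binaries(s + "1", n))
}
```

With this, `LazyBinaries.binaries("", 2).take(2).toList` builds only the two strings it returns. `LazyList` also offers `#:::` for lazy concatenation.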

Daniel C. Sobral
  • Thanks, this is exactly what I am looking for :) It looks more efficient (in terms of stack memory consumption) than the original backtracking version, doesn't it? – Michael Dec 25 '11 at 08:48
  • @Michael Perhaps. I'm not comfortable analyzing the efficiency of code that uses `Stream`. If I find it necessary to use them, I test the code in the REPL to make sure it doesn't overflow. – Daniel C. Sobral Dec 25 '11 at 13:58
  • @Michael, the stream version is less efficient in terms of stack frame usage. Each recursive call uses 4 stack frames, compared to only one in your version. In practice, this shouldn't be a problem for this example. With respect to heap memory usage, the Stream class is one of the least efficient Scala collections in terms of memory usage per reference stored, but it is usually OK if you don't accidentally hold onto the head of the stream... – huynhjl Dec 25 '11 at 17:25
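
If memoization and holding onto the head are a concern, one alternative (a swapped-in technique, not what the answer above uses) is an `Iterator`. Its `++` also takes its argument by name, so the backtracking stays lazy, but it caches nothing and can be traversed only once. `IteratorBinaries` is a hypothetical name.

```scala
object IteratorBinaries {
  // Iterator's ++ takes its argument by name: the "1" subtree is only
  // explored after the "0" subtree has been fully consumed.
  def binaries(s: String, n: Int): Iterator[String] =
    if (s.length == n) Iterator.single(s)
    else binaries(s + "0", n) ++ binaries(s + "1", n)
}
```

Because an `Iterator` is consumed as it is read, a second pass needs a fresh call to `binaries`.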

You can, but with `++` it won't evaluate lazily:

def bin(s: String, n: Int): Stream[String] = {
  if (s.length == n) { 
    println("got " + s) // for figuring out when it's evaluated
    Stream(s)
  } else {
    val s0 = s + '0'
    val s1 = s + '1'
    bin(s0, n) ++ bin(s1, n)
  }
}

For instance, when executing bin("", 2).take(2).foreach(println), you'll get the following output:

scala> bin("", 2).take(2).foreach(println)
got 00
got 01
got 10
got 11
00
01
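
The eagerness comes from `++` taking its argument by value: the right-hand recursive call runs as soon as the expression is built. A minimal sketch contrasting the two kinds of concatenation, assuming Scala 2.13+ `LazyList` (the names `ConcatStrictness`, `suffix` and `callsAfterBuilding` are hypothetical):

```scala
// Assumes Scala 2.13+.
object ConcatStrictness {
  private var calls = 0

  // Stands in for a recursive call such as bin(s1, n); counts invocations.
  private def suffix(): LazyList[Int] = { calls += 1; LazyList(2) }

  // Builds prefix + suffix, forces only the first element, and reports
  // how many times suffix() has run by then.
  def callsAfterBuilding(lazyConcat: Boolean): Int = {
    calls = 0
    val xs =
      if (lazyConcat) LazyList(1).lazyAppendedAll(suffix()) // by-name argument
      else LazyList(1) ++ suffix()                         // by-value argument
    xs.head
    calls
  }
}
```

Here `callsAfterBuilding(lazyConcat = false)` returns 1 and `callsAfterBuilding(lazyConcat = true)` returns 0.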

If you don't mind using a TraversableView, you can use the conversion described at https://stackoverflow.com/a/3816012/257449. The code then becomes:

def toTraversable[T](func: (T => Unit) => Unit) = new Traversable[T] {
  // foreach simply hands the callback to the push-style function
  def foreach[X](f: T => X) = func(f(_): Unit)
}

def bin2(str: String, n: Int) = {
  def recurse[U](s: String, f: String => U): Unit = {
    if (s.length == n) {
      println("got " + s) // for figuring out when it's evaluated
      f(s)
    } else {
      recurse(s + '0', f)
      recurse(s + '1', f)
    }
  }
  toTraversable[String](recurse(str, _)).view
}

Then

scala> bin2("", 2).take(2).foreach(println)
got 00
00
got 01
01
got 10
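
Since Scala 2.13, `Traversable` is only a deprecated alias for `Iterable`, so an adapter that defines nothing but `foreach` no longer compiles there. A minimal sketch of the same push-based idea without it keeps the original backtracking shape and aborts the traversal early with `scala.util.control.Breaks`. This is a different mechanism from the `TraversableView` conversion, and `PushBinaries` and `firstK` are hypothetical names.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.control.Breaks.{break, breakable}

object PushBinaries {
  // The question's backtracking algorithm, handing each string to a callback.
  def recurse(s: String, n: Int, f: String => Unit): Unit =
    if (s.length == n) f(s)
    else {
      recurse(s + '0', n, f)
      recurse(s + '1', n, f)
    }

  // Hypothetical helper: collect the first k strings of length n, then stop
  // the traversal with a control-flow exception instead of finishing it.
  def firstK(n: Int, k: Int): List[String] =
    if (k <= 0) Nil
    else {
      val buf = ListBuffer.empty[String]
      breakable {
        recurse("", n, { s =>
          buf += s
          if (buf.size >= k) break()
        })
      }
      buf.toList
    }
}
```

For example, `PushBinaries.firstK(2, 2)` returns `List("00", "01")` without visiting "10" or "11".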

huynhjl