98

What's the best way to terminate a fold early? As a simplified example, imagine I want to sum up the numbers in an Iterable, but if I encounter something I'm not expecting (say an odd number) I might want to terminate. This is a first approximation:

def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
  nums.foldLeft (Some(0): Option[Int]) {
    case (Some(s), n) if n % 2 == 0 => Some(s + n)
    case _ => None
  }
}

However, this solution is pretty ugly (as in, if I did a .foreach and a return -- it'd be much cleaner and clearer) and worst of all, it traverses the entire iterable even if it encounters a non-even number.

So what would be the best way to write a fold like this, that terminates early? Should I just go and write this recursively, or is there a more accepted way?

Heptic

11 Answers

71

My first choice would usually be to use recursion. It is only moderately less compact, is potentially faster (certainly no slower), and in early termination can make the logic more clear. In this case you need nested defs which is a little awkward:

def sumEvenNumbers(nums: Iterable[Int]) = {
  def sumEven(it: Iterator[Int], n: Int): Option[Int] = {
    if (it.hasNext) {
      val x = it.next
      if ((x % 2) == 0) sumEven(it, n+x) else None
    }
    else Some(n)
  }
  sumEven(nums.iterator, 0)
}

My second choice would be to use return, as it keeps everything else intact and you only need to wrap the fold in a def so you have something to return from--in this case, you already have a method, so:

def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
  Some(nums.foldLeft(0){ (n,x) =>
    if ((x % 2) != 0) return None
    n+x
  })
}

which in this particular case is a lot more compact than recursion (though we got especially unlucky with recursion since we had to do an iterable/iterator transformation). The jumpy control flow is something to avoid when all else is equal, but here it's not. No harm in using it in cases where it's valuable.

If I was doing this often and wanted it within the middle of a method somewhere (so I couldn't just use return), I would probably use exception-handling to generate non-local control flow. That is, after all, what it is good at, and error handling is not the only time it's useful. The only trick is to avoid generating a stack trace (which is really slow), and that's easy because the trait NoStackTrace and its child trait ControlThrowable already do that for you. Scala already uses this internally (in fact, that's how it implements the return from inside the fold!). Let's make our own (can't be nested, though one could fix that):

import scala.util.control.ControlThrowable
case class Returned[A](value: A) extends ControlThrowable {}
def shortcut[A](a: => A) = try { a } catch { case Returned(v) => v }

def sumEvenNumbers(nums: Iterable[Int]) = shortcut{
  Option(nums.foldLeft(0){ (n,x) =>
    if ((x % 2) != 0) throw Returned(None)
    n+x
  })
}

Here of course using return is better, but note that you could put shortcut anywhere, not just wrapping an entire method.

Next in line for me would be to re-implement fold (either myself or to find a library that does it) so that it could signal early termination. The two natural ways of doing this are to not propagate the value but an Option containing the value, where None signifies termination; or to use a second indicator function that signals completion. The Scalaz lazy fold shown by Kim Stebel already covers the first case, so I'll show the second (with a mutable implementation):

def foldOrFail[A,B](it: Iterable[A])(zero: B)(fail: A => Boolean)(f: (B,A) => B): Option[B] = {
  val ii = it.iterator
  var b = zero
  while (ii.hasNext) {
    val x = ii.next
    if (fail(x)) return None
    b = f(b,x)
  }
  Some(b)
}

def sumEvenNumbers(nums: Iterable[Int]) = foldOrFail(nums)(0)(_ % 2 != 0)(_ + _)

(Whether you implement the termination by recursion, return, laziness, etc. is up to you.)
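
For comparison, the first variant mentioned above (the step function returns an Option, with None signalling termination) could be sketched without Scalaz roughly as follows. The name foldWhileDefined is hypothetical and not part of any library; this is one possible strict implementation, not the Scalaz one.

```scala
// Hypothetical helper (not in the standard library): the step function
// returns Some(next) to continue folding and None to stop early.
def foldWhileDefined[A, B](it: Iterable[A])(zero: B)(f: (B, A) => Option[B]): Option[B] = {
  @annotation.tailrec
  def loop(ii: Iterator[A], acc: B): Option[B] =
    if (!ii.hasNext) Some(acc)
    else f(acc, ii.next()) match {
      case Some(next) => loop(ii, next)
      case None       => None // stop here; the remaining elements are never visited
    }
  loop(it.iterator, zero)
}

def sumEvenNumbers(nums: Iterable[Int]): Option[Int] =
  foldWhileDefined(nums)(0)((sum, n) => if (n % 2 == 0) Some(sum + n) else None)
```

Compared with foldOrFail, the failure test and the combining step live in a single function, which is convenient when the decision to stop depends on the accumulator as well as on the element.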

I think that covers the main reasonable variants; there are some other options also, but I'm not sure why one would use them in this case. (Iterator itself would work well if it had a findOrPrevious, but it doesn't, and the extra work it takes to do that by hand makes it a silly option to use here.)

Rex Kerr
  • The `foldOrFail` is exactly what I had come up with when thinking about the question. No reason not to use a mutable iterator and a while loop in the implementation IMO, when all nicely encapsulated. Using `iterator` along with recursion doesn't make sense. – 0__ Oct 15 '12 at 15:37
  • @Rex Kerr, thanks for your answer I tweaked a version for my own use that uses Either... (I'm going to post it as an answer) – Core Mar 22 '13 at 20:45
  • Probably one of the cons of _return_-based solution, is that it takes a while to realize which function it applies to: `sumEvenNumbers` or fold's `op` – Ivan Balashov Feb 10 '15 at 16:43
  • 1
    @IvanBalashov - Well, it takes a while _once_ to learn what Scala's rules are for `return` (i.e., it returns from innermost explicit method you find it in), but after that it ought not take very long. The rule is pretty clear, and the `def` gives away where the enclosing method is. – Rex Kerr Feb 10 '15 at 22:59
  • 1
    I like your foldOrFail but personally I would have made the return type `B` not `Option[B]` because then it behaves like fold where the return type is the same as the zero accumulator's type. Then simply replace all the Option returns with b and pass in None as the zero. After all the question wanted a fold that can terminate early, rather than fail. – Karl Jun 12 '18 at 23:29
28

The scenario you describe (exit upon some unwanted condition) seems like a good use case for the takeWhile method. It behaves like filter, except that it stops at the first element that doesn't meet the condition.

For example:

val list = List(2,4,6,8,6,4,2,5,3,2)
list.takeWhile(_ % 2 == 0) //result is List(2,4,6,8,6,4,2)

This will work just fine for Iterators/Iterables too. The solution I suggest for your "sum of even numbers, but break on odd" is:

list.iterator.takeWhile(_ % 2 == 0).foldLeft(...)

And just to prove that it's not wasting your time once it hits an odd number...

scala> val list = List(2,4,5,6,8)
list: List[Int] = List(2, 4, 5, 6, 8)

scala> def condition(i: Int) = {
     |   println("processing " + i)
     |   i % 2 == 0
     | }
condition: (i: Int)Boolean

scala> list.iterator.takeWhile(condition _).sum
processing 2
processing 4
processing 5
res4: Int = 6
Dylan
15

You can do what you want in a functional style using the lazy version of foldRight in scalaz. For a more in depth explanation, see this blog post. While this solution uses a Stream, you can convert an Iterable into a Stream efficiently with iterable.toStream.

import scalaz._
import Scalaz._

val str = Stream(2,1,2,2,2,2,2,2,2)
var i = 0 //only here for testing
val r = str.foldr(Some(0):Option[Int])((n,s) => {
  println(i)
  i+=1
  if (n % 2 == 0) s.map(n+) else None
})

This only prints

0
1

which clearly shows that the anonymous function is only called twice (i.e. until it encounters the odd number). That is due to the definition of foldr, whose signature (in case of Stream) is def foldr[B](b: B)(f: (Int, => B) => B)(implicit r: scalaz.Foldable[Stream]): B. Note that the anonymous function takes a by-name parameter as its second argument, so it need not be evaluated.

Btw, you can still write this with the OP's pattern matching solution, but I find if/else and map more elegant.

Kim Stebel
7

Well, Scala does allow non-local returns. There are differing opinions on whether or not this is good style.

scala> def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
     |   nums.foldLeft (Some(0): Option[Int]) {
     |     case (None, _) => return None
     |     case (Some(s), n) if n % 2 == 0 => Some(s + n)
     |     case (Some(_), _) => None
     |   }
     | }
sumEvenNumbers: (nums: Iterable[Int])Option[Int]

scala> sumEvenNumbers(2 to 10)
res8: Option[Int] = None

scala> sumEvenNumbers(2 to 10 by 2)
res9: Option[Int] = Some(30)

EDIT:

In this particular case, as @Arjan suggested, you can also do:

def sumEvenNumbers(nums: Iterable[Int]): Option[Int] = {
  nums.foldLeft (Some(0): Option[Int]) {
    case (Some(s), n) if n % 2 == 0 => Some(s + n)
    case _ => return None
  }
}
missingfaktor
7

Cats has a method called foldM which does short-circuiting (for Vector, List, Stream, ...).

It works as follows:

def sumEvenNumbers(nums: Stream[Int]): Option[Long] = {
  import cats.implicits._
  nums.foldM(0L) {
    case (acc, c) if c % 2 == 0 => Some(acc + c)
    case _ => None
  }
}

If it finds an odd element, it returns None without processing the rest; otherwise it returns the sum of the even entries.

If you want to keep the sum accumulated up to the point where an odd entry is found, use an Either[Long, Long] instead.

Didac Montero
7

You can use foldM from the cats library (as suggested by @Didac), but I suggest using Either instead of Option if you want to get the actual sum out.

bifoldMap is used to extract the result from Either.

import cats.implicits._

def sumEven(nums: Stream[Int]): Either[Int, Int] = {
  nums.foldM(0) {
    case (acc, n) if n % 2 == 0 => Either.right(acc + n)
    case (acc, n) => {
      println(s"Stopping on number: $n")
      Either.left(acc)
    }
  }
}

examples:

println("Result: " + sumEven(Stream(2, 2, 3, 11)).bifoldMap(identity, identity))
> Stopping on number: 3
> Result: 4

println("Result: " + sumEven(Stream(2, 7, 2, 3)).bifoldMap(identity, identity))
> Stopping on number: 7
> Result: 2
rozky
  • Came here to post a similar answer, because this is the most convenient yet still FP way to do it in my opinion. I'm surprised that nobody votes for this. So, grab my +1. (I prefer `(acc + n).asRight` instead of `Either.right(acc + n)` but anyway) – abdolence Apr 21 '20 at 16:15
  • rather than `bifoldMap` just `fold(L => C, R => C): C` will work on `Either[L, R]`, and then you don't need a `Monoid[C]` – Ben Hutchison Nov 09 '20 at 01:24
1

@Rex Kerr your answer helped me, but I needed to tweak it to use Either

  
  def foldOrFail[A,B,C,D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
    val ii= it.iterator
    var b= initial
    while (ii.hasNext) {
      val x= ii.next
      map(x) match {
        case Left(error) => return Left(error)
        case Right(d) => b= merge(b, d)
      }
    }
    Right(b)
  }
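
A possible way to call this version, as a sketch: the definition is repeated from above so the snippet runs on its own, and the error message format and the name sumEvenNumbers are illustrative choices rather than part of the original answer. The explicit type arguments are there because the function-typed parameters come first, so the compiler most likely has nothing to infer the lambda parameter types from.

```scala
// Definition repeated from the answer above so that the snippet is self-contained.
def foldOrFail[A, B, C, D](map: B => Either[D, C])(merge: (A, C) => A)(initial: A)(it: Iterable[B]): Either[D, A] = {
  val ii = it.iterator
  var b = initial
  while (ii.hasNext) {
    val x = ii.next()
    map(x) match {
      case Left(error) => return Left(error) // stop at the first failure
      case Right(d)    => b = merge(b, d)
    }
  }
  Right(b)
}

// Illustrative usage: A = running sum, B = input element, C = mapped value, D = error.
def sumEvenNumbers(nums: Iterable[Int]): Either[String, Int] =
  foldOrFail[Int, Int, Int, String](n =>
    if (n % 2 == 0) Right(n) else Left(s"odd number: $n")
  )(_ + _)(0)(nums)
```

The Left carries whatever information describes why the fold stopped, which is the main advantage over the Option-based version.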
Core
1

You could try using a temporary var and using takeWhile. Here is a version.

  var continue = true

  // sample stream of 2's and then a stream of 3's.

  val evenSum = (Stream.fill(10)(2) ++ Stream.fill(10)(3)).takeWhile(_ => continue)
    .foldLeft(Option[Int](0)){

    case (result,i) if i%2 != 0 =>
          continue = false;
          // return whatever is appropriate either the accumulated sum or None.
          result
    case (optionSum,i) => optionSum.map( _ + i)

  }

The evenSum should be Some(20) in this case.

seagull1089
0

A more beautiful solution would be using span:

val (l, r) = numbers.span(_ % 2 == 0)
if(r.isEmpty) Some(l.sum)
else None

... but it traverses the list two times if all the numbers are even

Arjan
  • 2
    I like the lateral thinking exemplified by your solution, but it only solves the specific example picked in the question rather than dealing with the general question of how to terminate a fold early. – iainmcgin Oct 15 '12 at 09:57
  • i wanted to show how to do the reverse, not terminating a fold early but only folding (in this case sum) over the values we want to fold over – Arjan Oct 15 '12 at 10:25
0

You can throw a well-chosen exception upon encountering your termination criterion, handling it in the calling code.
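
A minimal sketch of that idea, assuming a marker exception named OddNumberFound (a hypothetical name chosen for this example). Mixing in NoStackTrace from the standard library avoids the cost of capturing a stack trace.

```scala
import scala.util.control.NoStackTrace

// Hypothetical marker exception; NoStackTrace skips the costly stack-trace capture.
case object OddNumberFound extends RuntimeException with NoStackTrace

def sumEvenNumbers(nums: Iterable[Int]): Option[Int] =
  try {
    Some(nums.foldLeft(0) { (sum, n) =>
      if (n % 2 != 0) throw OddNumberFound // abandon the fold immediately
      sum + n
    })
  } catch {
    case OddNumberFound => None // handled in the calling code
  }
```

Because the catch clause matches only this specific exception, unrelated failures inside the fold still propagate normally.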

waldrumpus
0

Just for "academic" reasons (:

import scala.io.Source

var headers = Source.fromFile(file).getLines().next().split(",")
var closeHeaderIdx = headers.takeWhile { s => !"Close".equals(s) }.foldLeft(0)((i, S) => i+1)

It does twice the work it should, but it is a nice one-liner. If "Close" is not found, it returns headers.size.

Another (better) option is this one (note that indexOf returns -1, not headers.size, when "Close" is not found):

var headers = Source.fromFile(file).getLines().next().split(",").toList
var closeHeaderIdx = headers.indexOf("Close")
ozma