10

How can I emulate the following behavior in Scala, i.e. keep folding while some condition on the accumulator is met?

def foldLeftWhile[B](z: B, p: B => Boolean)(op: (B, A) => B): B

For example

scala> val seq = Seq(1, 2, 3, 4)
seq: Seq[Int] = List(1, 2, 3, 4)
scala> seq.foldLeftWhile(0, _ < 3) { (acc, e) => acc + e }
res0: Int = 1
scala> seq.foldLeftWhile(0, _ < 7) { (acc, e) => acc + e }
res1: Int = 6

UPDATES:

Based on @Dima's answer, I realized that my original intention was somewhat side-effectful. So I aligned it with takeWhile, i.e. there is no advancement if the predicate does not match, and I added some more examples to make it clearer. (Note: this will not work with Iterators.)

ntviet18
  • Well, I guess it's not in the standard library because if the condition comes from the accumulator itself, then it can be factored in directly, e.g. `(acc, e) => if (acc < 3) acc+e else acc`. You could even create a higher order function to create the "final" accumulator from its base and condition. I understand this might be less efficient than early-breaking from the fold, but otherwise it looks equivalent. – GPI Dec 11 '18 at 16:48
  • @GPI, agreed. But I think it is not just about efficiency. As we discussed below, the predicate could become true again later, which would likely give an unexpected result – ntviet18 Dec 11 '18 at 23:12

6 Answers

8

First, note that your example seems wrong. If I understand correctly what you describe, the result should be 1 (the last value on which the predicate _ < 3 was satisfied), not 6.

The simplest way to do this is using a return statement, which is heavily frowned upon in Scala, but I thought I'd mention it for the sake of completeness.

def foldLeftWhile[A, B](seq: Seq[A], z: B, p: B => Boolean)(op: (B, A) => B): B = seq.foldLeft(z) { case (b, a) =>
   val result = op(b, a)
   // Non-local return: exits foldLeftWhile itself, not just the lambda
   if (!p(result)) return b
   result
}

Since we want to avoid using return, scanLeft might be a possibility:

seq.toStream.scanLeft(z)(op).takeWhile(p).last

This is a little wasteful, because it accumulates all (matching) results. You could use iterator instead of toStream to avoid that, but Iterator does not have .last for some reason, so, you'd have to scan through it an extra time explicitly:

 seq.iterator.scanLeft(z)(op).takeWhile(p).foldLeft(z) { case (_, b) => b }
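
To make the scanLeft approach concrete, here is a small walkthrough on the question's own input. It is only an illustrative sketch: the object name `ScanLeftWalkthrough` and the value names are made up for this example. It also shows why `lastOption.getOrElse(z)` is the safer ending when the predicate may already fail on the initial value, since `takeWhile` then yields nothing and `.last` would throw.

```scala
object ScanLeftWalkthrough {
  val seq = List(1, 2, 3, 4)

  // Every intermediate accumulator, starting with the initial value 0
  val running = seq.scanLeft(0)(_ + _)   // List(0, 1, 3, 6, 10)

  // Keep the prefix of accumulators that satisfy the predicate
  val kept = running.takeWhile(_ < 3)    // List(0, 1)

  // The last surviving accumulator is the fold result
  val result = kept.last                 // 1

  // If even the initial value fails the predicate, the prefix is empty,
  // so fall back to the initial value instead of calling .last
  val safe = seq.scanLeft(0)(_ + _).takeWhile(_ < 0).lastOption.getOrElse(0) // 0
}
```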
Dima
  • Thank you, I think your first solution is a very good idea, but I'm not sure it will work every time, e.g. `Seq(1, 2, 2, 1).foldLeftWhile(0, _ < 5) { case (acc, e) => acc + e }` might return 4 instead of 3 – ntviet18 Dec 11 '18 at 20:15
  • I upvoted for your second solution. That's very concise :) – ntviet18 Dec 11 '18 at 20:22
  • I made a quick check, and it turns out your first solution also works like magic. Could you please explain why the return statement ignores the closure of foldLeft? – ntviet18 Dec 11 '18 at 20:28
  • @ntviet18 yeah, that's the magic of `return` in Scala. It basically ignores all lambdas around it, and just returns all the way up to the level above the nearest method. It does it by throwing a special exception, which is caught at the end of the method. This is why it is generally considered "code smell" - too much magic, no referential transparency, and hard to reason about in general – Dima Dec 11 '18 at 21:00
  • Thank you for the explanation – ntviet18 Dec 11 '18 at 21:07
4

It is pretty straightforward to define what you want in Scala. You can define an implicit class which will add your function to any TraversableOnce (that includes Seq).

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft(init)((acc, next) => if (where(acc)) op(acc, next) else acc)
  }
}
Seq(1,2,3,4).foldLeftWhile(0)(_ < 3)((acc, e) => acc + e)

Update, since the question was modified:

implicit class FoldLeftWhile[A](trav: TraversableOnce[A]) {
  def foldLeftWhile[B](init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    trav.foldLeft((init, false))((a, b) => if (a._2) a else {
      val r = op(a._1, b)
      // Reuse r rather than calling op a second time
      if (where(r)) (r, false) else (a._1, true)
    })._1
  }
}

Note that I split your (z: B, p: B => Boolean) into two separate parameter lists. That's just a personal Scala style preference.
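
The two versions above differ in when the predicate is checked, which changes the result. The following side-by-side sketch may help; the names `FoldVariants`, `checkBefore` and `checkAfter` are hypothetical and are written as plain methods rather than an implicit class only to keep the example self-contained.

```scala
object FoldVariants {
  // Variant 1: test the current accumulator before each step.
  // The accumulator that first fails the predicate is what gets returned.
  def checkBefore[A, B](xs: Seq[A], init: B)(where: B => Boolean)(op: (B, A) => B): B =
    xs.foldLeft(init)((acc, next) => if (where(acc)) op(acc, next) else acc)

  // Variant 2: test the candidate result and keep the last value that passed.
  // The Boolean flag makes the stop permanent once the predicate has failed.
  def checkAfter[A, B](xs: Seq[A], init: B)(where: B => Boolean)(op: (B, A) => B): B = {
    val (result, _) = xs.foldLeft((init, false)) { case ((acc, stopped), next) =>
      if (stopped) (acc, stopped)
      else {
        val r = op(acc, next)
        if (where(r)) (r, false) else (acc, true)
      }
    }
    result
  }

  val before = checkBefore(Seq(1, 2, 3, 4), 0)(_ < 3)(_ + _) // 3
  val after  = checkAfter(Seq(1, 2, 3, 4), 0)(_ < 3)(_ + _)  // 1
}
```

Only the second variant matches the takeWhile-style semantics of the updated question, where the first example is expected to yield 1.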

Jack Davidson
  • I think, this violates semantics of "while" for predicates that are not "monotonic": `Seq(1,2,3,4).foldLeftWhile(0)(_ % 2 != 0) { case (b, a) => b }` returns `3`, but it should be `1` – Dima Dec 11 '18 at 19:01
  • Upvoted because it seems to answer my original question. Thanks @Dima – ntviet18 Dec 11 '18 at 20:38
  • @Dima I think you meant '(a, b) => b', since '(b, a) => b' always returns your starting value. Whether you apply my correction or not, that actually returns 0, since 0 % 2 != 0 returns false immediately, and your folding function never actually runs. Did you mean to pass a starting value other than 0? If you take my correction and pass 1, you get 2. Is this not what you would want or expect? – Jack Davidson Dec 11 '18 at 21:27
  • @JackDavidson yeah, I messed up with that example. The point is if the predicate turns false, your function keeps going ... so, if then it turns true again, it'll start accumulating again. Because `acc` is the only parameter, it doesn't matter as long as the predicate is pure, but if it also depends on some outside state, that would break the semantics ... – Dima Dec 11 '18 at 21:36
  • @dima yes, that would be correct. if the predicate is not a pure function, there is no guarantee the function wouldn't resume processing inputs. I've updated with an alternative that reflects the new question and should permanently stop the first time the function returns false – Jack Davidson Dec 11 '18 at 23:23
2

What about this:

def foldLeftWhile[A, B](z: B, xs: Seq[A], p: B => Boolean)(op: (B, A) => B): B = {
  def go(acc: B, l: Seq[A]): B = l match {
    case h +: t =>
        val nacc = op(acc, h)
        // Apply op once per element; stop at the last accumulator that satisfied p
        if (p(nacc)) go(nacc, t) else acc
    case _ => acc
  }
  go(z, xs)
}

val a = Seq(1,2,3,4,5,6)
val r = foldLeftWhile(0, a, (x: Int) => x <= 3)(_ + _)
println(s"$r")

Iterate recursively on the collection while the predicate is true, and then return the accumulator.

You can try it on ScalaFiddle.

Alejandro Alcalde
  • This makes an extra iteration: `p(result)` is always _false_ (except if it stayed true till the end) – Dima Dec 11 '18 at 18:59
  • Thank you, this answers my original question. But I agree with @Dima that we should not advance further. So, I upvoted your answer. – ntviet18 Dec 11 '18 at 20:20
  • @Dima I do not understand, could you explain why it is doing an extra iteration? I do not see it. Thanks! – Alejandro Alcalde Dec 12 '18 at 08:46
  • @ElBaulP last time `p(acc)` is true, you call `go` even though it is now false. You have to compute the result first, check the condition, _then_ if it is true, pass it to `go`. – Dima Dec 12 '18 at 11:51
1

After a while I received a lot of good answers, so I combined them into this single post.

a very concise solution by @Dima

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    seq.toStream.scanLeft(z)(op).takeWhile(p).lastOption.getOrElse(z)
  }
}

by @ElBaulP (I modified a little bit to match comment by @Dima)

import scala.annotation.tailrec

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    @tailrec
    def foldLeftInternal(acc: B, seq: Seq[A]): B = seq match {
      // +: matches any Seq; :: would only match List
      case x +: rest =>
        val newAcc = op(acc, x)
        if (p(newAcc))
          foldLeftInternal(newAcc, rest)
        else
          acc
      case _ => acc
    }

    foldLeftInternal(z, seq)
  }
}

Answer by me (involving side effects)

implicit class FoldLeftWhile[A](seq: Seq[A]) {

  def foldLeftWhile[B](z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    var accumulator = z
    seq
      .map { e =>
        accumulator = op(accumulator, e)
        accumulator -> e
      }
      .takeWhile { case (acc, _) =>
        p(acc)
      }
      .lastOption
      .map { case (acc, _) =>
        acc
      }
      .getOrElse(z)
  }
}
ntviet18
1

First example: predicate on each element

First, you can use an inner tail-recursive function

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = {
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case head :: tail if f(head) => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

Or short version

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: A => Boolean)(op: (A, B) => B): B = seq match {
    case head :: tail if f(head) => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}

Then use it

val a = List(1, 2, 3, 4, 5, 6).foldLeftWhile(0, _ < 3)(_ + _) //a == 3

Second example: predicate on the accumulator value:

implicit class TravExt[A](seq: TraversableOnce[A]) {
  def foldLeftWhile[B](z: B, f: B => Boolean)(op: (A, B) => B): B = { // f tests the accumulator, so it takes a B
    @tailrec
    def rec(trav: TraversableOnce[A], z: B): B = trav match {
      case _ if !f(z) => z
      case head :: tail => rec(tail, op(head, z))
      case _ => z
    }
    rec(seq, z)
  }
}

Or short version

implicit class TravExt[A](seq: TraversableOnce[A]) {
  @tailrec
  final def foldLeftWhile[B](z: B, f: B => Boolean)(op: (A, B) => B): B = seq match { // f tests the accumulator, so it takes a B
    case _ if !f(z) => z
    case head :: tail => tail.foldLeftWhile(op(head, z), f)(op)
    case _ => z
  }
}
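
One caveat about the `head :: tail` patterns in the snippets above: `::` is the cons cell of `List`, so the pattern only matches when the underlying collection is a `List`. For any other collection the match silently falls through to `case _ => z`. The sketch below illustrates this with two hypothetical helpers (`headViaCons` and `headViaPlusColon` are names invented for this example) and shows `+:` as an extractor that works for any `Seq`.

```scala
object ConsPatternDemo {
  // Matches only when xs is a List at runtime
  def headViaCons(xs: Seq[Int]): Option[Int] = xs match {
    case head :: _ => Some(head)
    case _         => None
  }

  // +: deconstructs any Seq (List, Vector, ...)
  def headViaPlusColon(xs: Seq[Int]): Option[Int] = xs match {
    case head +: _ => Some(head)
    case _         => None
  }

  val fromList   = headViaCons(List(1, 2, 3))        // Some(1)
  val fromVector = headViaCons(Vector(1, 2, 3))      // None: a Vector is not a cons cell
  val anySeq     = headViaPlusColon(Vector(1, 2, 3)) // Some(1)
}
```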
Ivan Aristov
  • Your solution is very nice. If we extend `AnyVal` to make it a value class, I think it would be quite efficient :) – ntviet18 Dec 11 '18 at 22:57
  • On second thought, you should apply the predicate to the accumulator, not to the elements – ntviet18 Dec 11 '18 at 23:02
-1

Simply use a branch condition on the accumulator:

seq.foldLeft(0) { (acc, e) => if (acc < 3) acc + e else acc }

However, this still visits every element of the sequence.
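
If visiting every element is a concern, one possible alternative is an explicit loop over an iterator that stops at the first result failing the predicate. This is only a sketch under the question's takeWhile-style semantics (return the last accumulator that satisfied the predicate); `FoldWhileLoop` and `foldLeftWhileLoop` are hypothetical names, not part of the standard library.

```scala
object FoldWhileLoop {
  // Hypothetical helper: folds left, stopping as soon as a new result fails p.
  def foldLeftWhileLoop[A, B](seq: Iterable[A], z: B)(p: B => Boolean)(op: (B, A) => B): B = {
    val it = seq.iterator
    var acc = z
    var running = true
    while (running && it.hasNext) {
      val candidate = op(acc, it.next())
      // Keep the candidate only if it still satisfies the predicate
      if (p(candidate)) acc = candidate else running = false
    }
    acc
  }

  val small = foldLeftWhileLoop(Seq(1, 2, 3, 4), 0)(_ < 3)(_ + _) // 1
  val large = foldLeftWhileLoop(Seq(1, 2, 3, 4), 0)(_ < 7)(_ + _) // 6
}
```

The mutable state stays local to the method, so callers still see a pure function as long as `p` and `op` are pure.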

Nicolas Cailloux