
I wanted to memoize this:

def fib(n: Int): BigInt = if (n <= 1) BigInt(n) else fib(n-1) + fib(n-2)
println(fib(100)) // times out

So I wrote the following, which surprisingly compiles and works (I am surprised because fib references itself in its own declaration):

import scala.collection.mutable

case class Memo[A, B](f: A => B) extends (A => B) {
  private val cache = mutable.Map.empty[A, B]
  def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
}

val fib: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 
}

println(fib(100))     // prints 100th fibonacci number instantly

But when I try to declare fib inside of a def, I get a compiler error:

def foo(n: Int) = {
  val fib: Memo[Int, BigInt] = Memo {
    case 0 => 0
    case 1 => 1
    case n => fib(n-1) + fib(n-2) 
  }
  fib(n)
} 

The above fails to compile with `error: forward reference extends over definition of value fib`, pointing at `case n => fib(n-1) + fib(n-2)`.

Why does declaring the val fib inside a def fail, while declaring it outside, in class/object scope, works?

To clarify why I might want to declare the recursive memoized function in def scope, here is my solution to the subset sum problem:

/**
 * Subset sum algorithm - can we achieve sum t using elements from s?
 *
 * @param s set of integers
 * @param t target
 * @return true iff there exists a subset of s that sums to t
 */
def subsetSum(s: Seq[Int], t: Int): Boolean = {
  val max = s.scanLeft(0)((sum, i) => (sum + i) max sum)  // max(i) = largest sum achievable from first i elements
  val min = s.scanLeft(0)((sum, i) => (sum + i) min sum)  // min(i) = smallest sum achievable from first i elements

  val dp: Memo[(Int, Int), Boolean] = Memo {         // dp(i, x) = can we achieve x using the first i elements?
    case (_, 0) => true        // 0 can always be achieved using the empty set
    case (0, _) => false       // with the empty set, a non-zero sum cannot be achieved
    case (i, x) if min(i) <= x && x <= max(i) => dp(i-1, x - s(i-1)) || dp(i-1, x)  // try with/without s(i-1)
    case _ => false            // outside range otherwise
  }

  dp(s.length, t)
}
pathikrit
  • See my [blog post](http://michid.wordpress.com/2009/02/23/function_mem/) for another variant for memoization of recursive functions. – michid May 01 '13 at 21:10
  • Before I post anything to SO, I Google it, and your blog post was the first result :) I agree it is the "right" way to do this, using the Y-combinator. But I think using my style and exploiting `lazy val` looks cleaner than having two definitions (the recursive one and the Y-combined one) for each function. Look how clean this [looks](https://github.com/pathikrit/scalgos/blob/master/src/main/scala/com/github/pathikrit/scalgos/Combinatorics.scala#L67). – pathikrit May 01 '13 at 21:39
  • I was confused by some of the terseness of the syntax in your problem above (specifically the case class's use of "extends (A => B)"). I posted a question about it: http://stackoverflow.com/questions/19548103/in-scala-what-does-extends-a-b-on-a-case-class-mean – chaotic3quilibrium Oct 23 '13 at 20:23
  • Use this pattern with caution because of the concurrency issues of a mutable `Map`: http://stackoverflow.com/questions/6806123/does-using-val-with-hashtable-in-scala-resolve-concurrency-issues/6807324#6807324 – lcn Dec 08 '13 at 18:54
  • The question asked in the body and the accepted answer has nothing to do with the title of this question. Could you change the title? – user239558 Mar 15 '15 at 23:03
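michid's comment points to a Y-combinator-style alternative. The following is a minimal sketch of that idea, not the blog post's code: a hypothetical helper `memoFix` takes an "open" recursive function that receives its own memoized self-reference as an argument, so the definition never names itself and needs no `lazy val`, in either class or local scope.

```scala
import scala.collection.mutable

object MemoFix {
  // Hypothetical helper: ties the recursive knot for an "open" function f,
  // which gets the memoized function itself (rec) as its first argument.
  def memoFix[A, B](f: (A => B) => A => B): A => B = {
    val cache = mutable.Map.empty[A, B]
    // Every recursive call goes back through the cache.
    def self(a: A): B = cache.getOrElseUpdate(a, f(self)(a))
    self
  }

  // The body recurses through `rec`, never through its own name.
  val fib: Int => BigInt = memoFix[Int, BigInt] { rec => n =>
    if (n <= 1) BigInt(n) else rec(n - 1) + rec(n - 2)
  }
}
```

This is the trade-off pathikrit describes: the recursive body is written against `rec` rather than against `fib`.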

4 Answers


I found a better way to memoize using Scala:

import scala.collection.mutable

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {
  override def apply(key: I) = getOrElseUpdate(key, f(key))
}

Now you can write fibonacci as follows:

lazy val fib: Int => BigInt = memoize {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2)
}

Here's one with multiple arguments (the choose function):

lazy val c: ((Int, Int)) => BigInt = memoize {
  case (_, 0) => 1
  case (n, r) if r > n/2 => c(n, n - r)
  case (n, r) => c(n - 1, r - 1) + c(n - 1, r)
}
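Calling `c(n, n - r)` on a value of type `((Int, Int)) => BigInt` relies on the compiler auto-tupling the two arguments. If you would rather expose an ordinary two-argument function, one option is the standard library's `Function.untupled`. This is a sketch: it swaps in a closure-based `memoize` (a `mutable.HashMap` captured by a lambda) instead of subclassing `HashMap`, and the names `Choose` and `choose` are illustrative.

```scala
import scala.collection.mutable

object Choose {
  // Closure-based variant: the cache lives in a captured HashMap.
  def memoize[I, O](f: I => O): I => O = {
    val cache = mutable.HashMap.empty[I, O]
    (key: I) => cache.getOrElseUpdate(key, f(key))
  }

  // Same recurrence as above, with explicit tuples instead of auto-tupling.
  lazy val c: ((Int, Int)) => BigInt = memoize[(Int, Int), BigInt] {
    case (_, 0)              => BigInt(1)
    case (n, r) if r > n / 2 => c((n, n - r))
    case (n, r)              => c((n - 1, r - 1)) + c((n - 1, r))
  }

  // Function.untupled turns ((Int, Int)) => BigInt into (Int, Int) => BigInt.
  val choose: (Int, Int) => BigInt = Function.untupled(c)
}
```

`Choose.choose(10, 3)` then reads like a normal two-argument call while sharing the memo table with `c`.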

And here's the subset sum problem:

// is there a subset of s which has sum = t
def isSubsetSumAchievable(s: Vector[Int], t: Int) = {
  // f is (i, j) => Boolean i.e. can the first i elements of s add up to j
  lazy val f: ((Int, Int)) => Boolean = memoize {
    case (_, 0) => true        // 0 can always be achieved using empty list
    case (0, _) => false       // we can never achieve non-zero if we have empty list
    case (i, j) => 
      val k = i - 1            // try the kth element
      f(k, j - s(k)) || f(k, j)
  }
  f(s.length, t)
}

EDIT: As discussed in the comments below, here is a thread-safe version:

def memoize[I, O](f: I => O): I => O = new mutable.HashMap[I, O]() {self =>
  override def apply(key: I) = self.synchronized(getOrElseUpdate(key, f(key)))
}
pathikrit
  • I don't think this (or most implementations I've seen based on `mutable.Map`) is thread-safe? But it looks like nice syntax, if used in a single-threaded context. – Gary Coady May 01 '16 at 15:21
  • I'm not sure if the mutable HashMap implementation can actually crash and/or corrupt data in some way, or if the main issue is only missing updates; missing updates would probably be acceptable for most use cases. – Gary Coady May 01 '16 at 15:33
  • @Gary Coady: It's trivial to replace `HashMap` with `TrieMap` if you want concurrency – pathikrit May 04 '16 at 22:13
  • Sure, it's just something a user should be aware of, and sometimes solutions are copy/pasted from SO without considering issues like this ;-) – Gary Coady May 05 '16 at 15:19
  • I wonder if you can dead-lock even on a TrieMap. After all, the map is "recursively" accessed inside the `getOrElseUpdate` method. – VasiliNovikov May 07 '16 at 22:53
  • @VasyaNovikov: We can then just make the lock coarser by surrounding the `getOrElseUpdate` with `self.synchronized {getOrElseUpdate}` – pathikrit Aug 26 '16 at 16:13
  • `TrieMap` is `final`, so can't be subclassed like the above. Here's what I put together in order to use `TrieMap`: `def memoize[A, B](f: A => B): (A => B) = { val cache = collection.concurrent.TrieMap[A, B](); (a: A) => cache.getOrElseUpdate(a, f(a)) }`. – Jeff Klukas Jan 23 '17 at 19:44
  • @JeffKlukas: What is wrong with the `self.synchronized` version? – pathikrit Jan 24 '17 at 14:56
  • @pathikrit: I don't see anything wrong with `self.synchronized` version using mutable.HashMap. My comment here is mostly a clarification on the discussion of `TrieMap` in the comments above, since it turns out that it's not possible to simply sub in `TrieMap` to the given code. – Jeff Klukas Jan 25 '17 at 15:15
  • This does not work for me; it reports `not found: value getOrElseUpdate`. – luochen1990 Aug 06 '18 at 14:59
  • `memoize` is generic, which is nice, but as functions defined with `val`s must be monomorphic, this solution won't work for memoizing generic functions. Is there a workaround for this? – Paul Carey Nov 06 '20 at 11:00
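On the deadlock question in the comments: JVM monitors are reentrant, so when `f` calls back into the memoized function on the same thread, `synchronized` re-enters the lock that thread already holds instead of blocking. Below is a minimal closure-based sketch of the `self.synchronized` idea that avoids subclassing `HashMap` (the object name `SyncMemo` is illustrative).

```scala
import scala.collection.mutable

object SyncMemo {
  def memoize[I, O](f: I => O): I => O = {
    val cache = mutable.HashMap.empty[I, O]
    // One lock guards the map. Recursive calls from f happen on the same
    // thread, which already owns the monitor, so they re-enter it.
    (key: I) => cache.synchronized(cache.getOrElseUpdate(key, f(key)))
  }

  lazy val fib: Int => BigInt = memoize[Int, BigInt] {
    case 0 => BigInt(0)
    case 1 => BigInt(1)
    case n => fib(n - 1) + fib(n - 2)
  }
}
```

The lock is held for the whole recursive computation, so concurrent callers are serialized. The `TrieMap` closure from Jeff Klukas's comment avoids the lock; the trade-off is that two threads racing on the same key may both compute `f`, with only one result stored.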

Class/trait level val compiles to a combination of a method and a private variable. Hence a recursive definition is allowed.

Local vals on the other hand are just regular variables, and thus recursive definition is not allowed.

By the way, even if the def you defined worked, it wouldn't do what you expect. On every invocation of foo a new function object fib will be created and it will have its own backing map. What you should be doing instead is this (if you really want a def to be your public interface):

private val fib: Memo[Int, BigInt] = Memo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-1) + fib(n-2) 
}

def foo(n: Int) = {
  fib(n)
} 
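To make the per-call cost concrete: with `fib` hoisted out of `foo`, the cache survives between calls. This sketch reuses the question's `Memo` class and adds a hypothetical `evaluations` counter (not in the answer) that counts how often the uncached body runs.

```scala
import scala.collection.mutable

object HoistedFib {
  // Same shape as the question's Memo class.
  case class Memo[A, B](f: A => B) extends (A => B) {
    private val cache = mutable.Map.empty[A, B]
    def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  }

  // Hypothetical counter: incremented only when a key is not yet cached.
  var evaluations = 0

  private val fib: Memo[Int, BigInt] = Memo[Int, BigInt] { n =>
    evaluations += 1
    if (n <= 1) BigInt(n) else fib(n - 1) + fib(n - 2)
  }

  def foo(n: Int): BigInt = fib(n)
}
```

The first `foo(30)` runs the body 31 times (once per key from 0 to 30); a second `foo(30)` runs it zero more times, and `foo(31)` only once more.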
missingfaktor
  • The 'foo' and 'fib' are just a simplification; in my case `foo` is the subset-sum problem and fib is the recursive memoization on the input set, and thus I cannot simply extract my memoized function outside the method. Can you explain what you mean by the "class level val compiles to combination of a method and a private variable" part? What other differences should I be aware of between class and method `val`s? – pathikrit Apr 27 '13 at 22:50
  • i) What prevents you from extracting it outside of the method? ii) When you write `val x = N` at a class/trait level, what you get is `def x = _x` and a `private val _x = N`. You should find this explanation in any Scala book. I can't recall off the top of my head any other differences between field `val`s and local `val`s. – missingfaktor Apr 28 '13 at 06:44
  • A work around you can use even in the local scope: Make `fib` a `lazy val`. Then you should be able to recur on it in local scope as well. – missingfaktor Apr 28 '13 at 06:45
  • If it uses mutable state in a val, does that mean it is not thread-safe? – ses Apr 21 '14 at 20:13
  • @ses, unless that mutable piece of state has thread-safety guarantees. (You can be mutable AND thread-safe. It's just... more difficult.) – missingfaktor May 03 '14 at 19:38
  • More upvotes available if you can show how to make this generic over n-ary functions. – user48956 Apr 29 '16 at 16:47
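missingfaktor's `lazy val` workaround applied to the question's `foo`: a minimal sketch reusing the question's `Memo` class (wrapped in an illustrative `LocalLazyFib` object). As the answer points out, each call to `foo` still builds a fresh cache.

```scala
import scala.collection.mutable

object LocalLazyFib {
  case class Memo[A, B](f: A => B) extends (A => B) {
    private val cache = mutable.Map.empty[A, B]
    def apply(x: A): B = cache.getOrElseUpdate(x, f(x))
  }

  def foo(n: Int): BigInt = {
    // A local lazy val may refer to itself; a plain local val here is
    // rejected with "forward reference extends over definition of value fib".
    lazy val fib: Memo[Int, BigInt] = Memo[Int, BigInt] {
      case 0 => BigInt(0)
      case 1 => BigInt(1)
      case k => fib(k - 1) + fib(k - 2)
    }
    fib(n)
  }
}
```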

Scalaz has a solution for that, why not reuse it?

import scalaz.Memo
lazy val fib: Int => BigInt = Memo.mutableHashMapMemo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-2) + fib(n-1)
}

You can read more about memoization in Scalaz.

kxmh42

A mutable HashMap isn't thread-safe. Also, defining case statements separately for the base conditions seems like unnecessary special handling; instead, a Map can be preloaded with the initial values and passed to the memoizer. The memoizer accepts a memo (an immutable Map) and a formula, and returns a recursive function.

The memoizer's signature would be:

def memoize[I,O](memo: Map[I, O], formula: (I => O, I) => O): I => O

Now, given the following Fibonacci formula,

def fib(f: Int => Int, n: Int) = f(n-1) + f(n-2)

fibonacci with the memoizer can be defined as

val fibonacci = memoize( Map(0 -> 0, 1 -> 1), fib)

where the context-agnostic, general-purpose memoizer is defined as

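def memoize[I, O](map: Map[I, O], formula: (I => O, I) => O): I => O = {
  var memo = map                      // immutable Map, replaced on every insert
  def recur(n: I): O = {
    if (memo contains n) {
      memo(n)                         // preloaded base case or cached result
    } else {
      val result = formula(recur, n)  // formula recurses through recur, so sub-results are cached too
      memo += (n -> result)
      result
    }
  }
  recur
}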

Similarly, for factorial, a formula is

def fac(f: Int => Int, n: Int): Int = n * f(n-1)

and factorial with the memoizer is

val factorial = memoize( Map(0 -> 1, 1 -> 1), fac)
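The `Int` formulas above overflow past `fib(46)` and `12!`. A sketch of the same memoizer with `BigInt` formulas (the names `MapMemoizer`, `fibF` and `facF` are illustrative):

```scala
object MapMemoizer {
  // The memoizer from this answer, unchanged apart from formatting.
  def memoize[I, O](map: Map[I, O], formula: (I => O, I) => O): I => O = {
    var memo = map
    def recur(n: I): O =
      if (memo contains n) memo(n)
      else {
        val result = formula(recur, n)
        memo += (n -> result)
        result
      }
    recur
  }

  // BigInt versions of the fib and fac formulas.
  def fibF(f: Int => BigInt, n: Int): BigInt = f(n - 1) + f(n - 2)
  def facF(f: Int => BigInt, n: Int): BigInt = BigInt(n) * f(n - 1)

  val fibonacci: Int => BigInt =
    memoize[Int, BigInt](Map(0 -> BigInt(0), 1 -> BigInt(1)), fibF)
  val factorial: Int => BigInt =
    memoize[Int, BigInt](Map(0 -> BigInt(1), 1 -> BigInt(1)), facF)
}
```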

Inspiration: Memoization, Chapter 4 of JavaScript: The Good Parts by Douglas Crockford

Boolean
  • "defining case statements separately for base conditions seems unnecessary special handling": really? Actually fib is one of the rare examples that has simple base cases. How would you solve the knapsack problem (https://github.com/pathikrit/scalgos/blob/master/src/main/scala/com/github/pathikrit/scalgos/DynamicProgramming.scala#L103) using this? – pathikrit Aug 27 '17 at 16:40
  • In the case of Fibonacci, or anywhere values are known upfront, they should be preloaded in the map. It makes the formula function read closer to its mathematical definition, IMO. If the formula requires comparisons (case statements or if...else blocks), such as in solving the knapsack problem, it is perfectly fine to use case statements. – Boolean Aug 28 '17 at 16:34