I found an article, "Solving the 0-1 knapsack problem using continuation-passing style with memoization in F#", about the knapsack problem implemented in F#. As I'm learning the language, I found it really interesting and tried to investigate it a bit. Here's the code I crafted:

open System
open System.IO 
open System.Collections.Generic

let parseToTuple (line : string) =
    let parsedLine =
        line.Split(' ')
        |> Array.filter (not << String.IsNullOrWhiteSpace)
        |> Array.map Int32.Parse
    (parsedLine.[0], parsedLine.[1])

let memoize f =
    let cache = Dictionary<_, _>()
    fun x ->
        if cache.ContainsKey(x) then
            cache.[x]
        else
            let res = f x
            cache.[x] <- res
            res

type Item =
    {
        Value : int
        Size  : int
    }  

type ContinuationBuilder() = 
    member b.Bind(x, f) = fun k -> x (fun x -> f x k)
    member b.Return x = fun k ->  k x
    member b.ReturnFrom x = x

let cont = ContinuationBuilder()

let set1 =
    [
        (4, 11)
        (8, 4)
        (10, 5)
        (15, 8)
        (4, 3)
    ]

let set2 =
    [
        (50, 341045); (1906, 4912); (41516, 99732); (23527, 56554); (559, 1818); (45136, 108372); (2625, 6750); (492, 1484)
        (1086, 3072); (5516, 13532); (4875, 12050); (7570, 18440); (4436, 10972); (620, 1940); (50897, 122094); (2129, 5558)
        (4265, 10630); (706, 2112); (2721, 6942); (16494, 39888); (29688, 71276); (3383, 8466); (2181, 5662); (96601, 231302)
        (1795, 4690); (7512, 18324); (1242, 3384); (2889, 7278); (2133, 5566); (103, 706); (4446, 10992); (11326, 27552)
        (3024, 7548); (217, 934); (13269, 32038); (281, 1062); (77174, 184848); (952, 2604); (15572, 37644); (566, 1832)
        (4103, 10306); (313, 1126); (14393, 34886); (1313, 3526); (348, 1196); (419, 1338); (246, 992); (445, 1390)
        (23552, 56804); (23552, 56804); (67, 634)
    ]

[<EntryPoint>] 
let main args =
    // the first tuple is the header (N, K); the remaining tuples are the items (Value, Size)
    let header, items = set1
                        |> function
                           | h::t -> h, t
                           | _    -> raise (Exception("Wrong data format"))

    let N, K = header
    printfn "N = %d, K = %d" N K
    let items = List.map (fun x -> {Value = fst x ; Size = snd x}) items |> Array.ofList

    let rec combinations =
        let innerSolver key =
            cont
                {
                    match key with
                    | (i, k) when i = 0 || k = 0        -> return 0
                    | (i, k) when items.[i-1].Size > k  -> return! combinations (i-1, k)
                    | (i, k)                            -> let item = items.[i-1]
                                                           let! v1 = combinations (i-1, k)
                                                           let! beforeItem = combinations (i-1, k-item.Size)
                                                           let v2 = beforeItem + item.Value
                                                           return max v1 v2
                }
        memoize innerSolver

    let res = combinations (N, K) id
    printfn "%d" res
    0

However, the problem with this implementation is that it's veeeery slow: in practice I'm unable to solve a problem with 50 items and a capacity of ~300000, which my naive C# implementation solves in less than 1s.

Could you tell me if I made a mistake somewhere? Or maybe the implementation is correct and this is simply an inefficient way of solving this problem?

LA.27
  • The standard F# performance comments: probably avoid the continuation. Avoid lists, use Arrays. Try a line by line translation of the C# and compare. Also, be careful with comparison operators that can be slow and check your compiler options. – John Palmer Jul 03 '13 at 11:01
  • Considering the minimal size of your test, I'm going to guess that there is a logic error in your code somewhere. Have you verified your code with, say, 5 items? – N_A Jul 03 '13 at 14:32
  • Have you profiled it? – Daniel Jul 03 '13 at 14:47
  • Yes, I ran this on an instance with 4 items and a knapsack size of 11, and it produced a correct solution without a problem. About profiling: not really. I just ran this on instances of 20 and 30 items and there was no problem. Then I used an instance with 50 items and... there was no response for ~30 minutes. As I said, my C# solution solved even greater problems. In C# I didn't use recursive calls. Might that be the reason? I read that tail recursion gets translated into a while loop in F#, so both implementations should have similar performance. – LA.27 Jul 03 '13 at 14:54
  • 1. Make sure you're testing a release build. 2. The way you use `memoize` inside of `combinations` looks quite suspect to me. – ildjarn Jul 03 '13 at 15:06
  • This case `| (i, k)` doesn't look tail recursive to me. – N_A Jul 03 '13 at 15:55
  • @ildjarn 1.) I double-checked: I'm testing a Release build. Moreover, I learnt that tail-recursion is not supported in Debug by default, so if I ran it on a larger instance, I would get a StackOverflow exception. 2.) What exactly don't you like about this? I followed the instructions in the article, but I've done so for the first time, so there might be mistakes. @mydogisbox That's because the recursion is hidden inside the computation expression: notice the let! statements and the Bind method. – LA.27 Jul 03 '13 at 16:59
  • At a glance, it appears to me that each time you call or pass `combinations` you get a different `memoization` instance, essentially making the effort of memoization unused overhead. – ildjarn Jul 03 '13 at 17:45
  • I found an article about memoization here: http://en.wikibooks.org/wiki/F_Sharp_Programming/Caching And noticed this is the way to do so: mainFunction -> ... memoize innerFun. – LA.27 Jul 03 '13 at 22:47
  • Can you perhaps remove the input parsing from your question and instead provide some concrete inputs so that people who want to profile it and give you some feedback about the performance can run your code straight away? (Thanks!) – Tomas Petricek Jul 03 '13 at 23:40
  • Done - I attached two instances. – LA.27 Jul 04 '13 at 00:07
  • In general that is the way to do so, but if `innerFun` recursively calls `mainFunction` then each recursion calls `memoize` separately and ends up with its own, _separate_ cache (again, effectively making said cache unused overhead). – ildjarn Jul 04 '13 at 01:09

2 Answers

When you naively apply a generic memoizer like this, and use continuation passing, the values in your memoization cache are continuations, not regular "final" results. Thus, when you get a cache hit, you aren't getting back a finalized result, you are getting back some function which promises to compute a result when you invoke it. This invocation might be expensive, might invoke various other continuations, might ultimately hit the memoization cache again itself, etc.
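
A minimal sketch of that effect, separate from the knapsack code. The names `squareCps` and `bodyRuns` are made up for illustration; `memoize` is the same generic memoizer as in the question. The cache hit hands back the stored closure, and the work inside it runs again every time a continuation is supplied:

```fsharp
open System.Collections.Generic

// Hypothetical counter: how many times the body of the CPS function actually runs.
let mutable bodyRuns = 0

// Same generic memoizer as in the question.
let memoize f =
    let cache = Dictionary<_, _>()
    fun x ->
        match cache.TryGetValue x with
        | true, v -> v
        | _ ->
            let res = f x
            cache.[x] <- res
            res

// A CPS "square" function: given n, it returns a function awaiting a continuation k.
// The real work happens only when k is supplied, so the cache stores the closure.
let squareCps : int -> (int -> int) -> int =
    memoize (fun n ->
        fun k ->
            bodyRuns <- bodyRuns + 1
            k (n * n))

let a = squareCps 7 id   // cache miss: stores the closure, then runs it
let b = squareCps 7 id   // cache hit: returns the same closure, which runs again

printfn "a = %d, b = %d, bodyRuns = %d" a b bodyRuns
// prints: a = 49, b = 49, bodyRuns = 2
```

Both calls return 49, but the body ran twice even though the second lookup was a hit. In the knapsack solver each such re-run triggers further lookups, which is where the extra activity comes from.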

Effectively memoizing continuation-passing functions such that a) the caching works to full effect and b) the function remains tail-recursive is quite difficult. Read this discussion and come back when you fully understand it all. ;-)

The author of the blog post you linked is using a more sophisticated, less generic memoizer which is specially fitted to the problem. Admittedly, I don't fully grok it yet (code on the blog is incomplete/broken, so hard to try it out), but I think the gist of it is that it "forces" the chain of continuations before caching the final integer result.

To illustrate the point, here's a quick refactor of your code which is fully self-contained and traces out relevant info:

open System
open System.Collections.Generic

let mutable cacheHits = 0
let mutable cacheMisses = 0

let memoize f =
    let cache = Dictionary<_, _>()
    fun x ->
        match cache.TryGetValue(x) with
        | (true, v) -> 
            cacheHits <- cacheHits + 1
            printfn "Hit for %A - Result is %A" x v
            v
        | _ ->
            cacheMisses <- cacheMisses + 1
            printfn "Miss for %A" x
            let res = f x
            cache.[x] <- res
            res

type Item = { Value : int; Size  : int }  

type ContinuationBuilder() = 
    member b.Bind(x, f) = fun k -> x (fun x -> f x k)
    member b.Return x = fun k ->  k x
    member b.ReturnFrom x = x

let cont = ContinuationBuilder()

let genItems n = 
   [| for i = 1 to n do
         let size = i % 5
         let value = (size * i)
         yield { Value = value; Size = size }
   |]

let N, K = (5, 100)
printfn "N = %d, K = %d" N K

let items = genItems N

let rec combinations_cont =
    memoize (
     fun key ->
       cont {
                match key with
                | (0, _) | (_, 0)                   -> return 0
                | (i, k) when items.[i-1].Size > k  -> return! combinations_cont (i - 1, k) 
                | (i, k)                            -> let item = items.[i-1]
                                                       let! v1 = combinations_cont (i-1, k)
                                                       let! beforeItem = combinations_cont (i-1, k - item.Size)
                                                       let v2 = beforeItem + item.Value
                                                       return max v1 v2
        }
    )

let res = combinations_cont (N, K) id
printfn "Answer: %d" res
printfn "Memo hits: %d" cacheHits
printfn "Memo misses: %d" cacheMisses
printfn ""

let rec combinations_plain =
    memoize (
     fun key ->
                match key with
                | (i, k) when i = 0 || k = 0        -> 0
                | (i, k) when items.[i-1].Size > k  -> combinations_plain (i-1, k) 
                | (i, k)                            -> let item = items.[i-1]
                                                       let v1 = combinations_plain (i-1, k)
                                                       let beforeItem = combinations_plain (i-1, k-item.Size)
                                                       let v2 = beforeItem + item.Value
                                                       max v1 v2
    )

cacheHits <- 0
cacheMisses <- 0

let res2 = combinations_plain (N, K)
printfn "Answer: %d" res2
printfn "Memo hits: %d" cacheHits
printfn "Memo misses: %d" cacheMisses

As you can see, the CPS version is caching continuations (not integers), and there is a lot of extra activity going on toward the end as the continuations are invoked.

If you boost the problem size to `let N, K = (20, 100)` (and remove the printfn statements in the memoizer), you will see that the CPS version ends up doing over 1 million cache lookups, compared to the plain version doing only a few hundred.

latkin

From running this code in FSI:

open System
open System.Diagnostics
open System.Collections.Generic

let time f =
    System.GC.Collect()
    let sw = Stopwatch.StartNew()
    let r = f()
    sw.Stop()
    printfn "Took: %f" sw.Elapsed.TotalMilliseconds
    r

let mutable cacheHits = 0
let mutable cacheMisses = 0

let memoize f =
    let cache = Dictionary<_, _>()
    fun x ->
        match cache.TryGetValue(x) with
        | (true, v) -> 
            cacheHits <- cacheHits + 1
            //printfn "Hit for %A - Result is %A" x v
            v
        | _ ->
            cacheMisses <- cacheMisses + 1
            //printfn "Miss for %A" x
            let res = f x
            cache.[x] <- res
            res

type Item = { Value : int; Size  : int }  

type ContinuationBuilder() = 
    member b.Bind(x, f) = fun k -> x (fun x -> f x k)
    member b.Return x = fun k ->  k x
    member b.ReturnFrom x = x

let cont = ContinuationBuilder()

let genItems n = 
    [| for i = 1 to n do
            let size = i % 5
            let value = (size * i)
            yield { Value = value; Size = size }
    |]

let N, K = (80, 400)
printfn "N = %d, K = %d" N K

let items = genItems N

//let rec combinations_cont =
//    memoize (
//     fun key ->
//       cont {
//                match key with
//                | (0, _) | (_, 0)                   -> return 0
//                | (i, k) when items.[i-1].Size > k  -> return! combinations_cont (i - 1, k) 
//                | (i, k)                            -> let item = items.[i-1]
//                                                       let! v1 = combinations_cont (i-1, k)
//                                                       let! beforeItem = combinations_cont (i-1, k - item.Size)
//                                                       let v2 = beforeItem + item.Value
//                                                       return max v1 v2
//        }
//    )
//
//
//cacheHits <- 0
//cacheMisses <- 0

//let res = time(fun () -> combinations_cont (N, K) id)
//printfn "Answer: %d" res
//printfn "Memo hits: %d" cacheHits
//printfn "Memo misses: %d" cacheMisses
//printfn ""

let rec combinations_plain =
    memoize (
        fun key ->
                match key with
                | (i, k) when i = 0 || k = 0        -> 0
                | (i, k) when items.[i-1].Size > k  -> combinations_plain (i-1, k) 
                | (i, k)                            -> let item = items.[i-1]
                                                       let v1 = combinations_plain (i-1, k)
                                                       let beforeItem = combinations_plain (i-1, k-item.Size)
                                                       let v2 = beforeItem + item.Value
                                                       max v1 v2
    )

cacheHits <- 0
cacheMisses <- 0

printfn "combinations_plain"
let res2 = time (fun () -> combinations_plain (N, K))
printfn "Answer: %d" res2
printfn "Memo hits: %d" cacheHits
printfn "Memo misses: %d" cacheMisses
printfn ""

let recursivelyMemoize f =
    let cache = Dictionary<_, _>()
    let rec memoizeAux x =
        match cache.TryGetValue(x) with
        | (true, v) -> 
            cacheHits <- cacheHits + 1
            //printfn "Hit for %A - Result is %A" x v
            v
        | _ ->
            cacheMisses <- cacheMisses + 1
            //printfn "Miss for %A" x
            let res = f memoizeAux x
            cache.[x] <- res
            res
    memoizeAux

let combinations_plain2 =
    let combinations_plain2Aux combinations_plain2Aux key =
                match key with
                | (i, k) when i = 0 || k = 0        -> 0
                | (i, k) when items.[i-1].Size > k  -> combinations_plain2Aux (i-1, k) 
                | (i, k)                            -> let item = items.[i-1]
                                                       let v1 = combinations_plain2Aux (i-1, k)
                                                       let beforeItem = combinations_plain2Aux (i-1, k-item.Size)
                                                       let v2 = beforeItem + item.Value
                                                       max v1 v2
    let memoized = recursivelyMemoize combinations_plain2Aux
    fun x -> memoized x

cacheHits <- 0
cacheMisses <- 0

printfn "combinations_plain2"
let res3 = time (fun () -> combinations_plain2 (N, K))
printfn "Answer: %d" res3
printfn "Memo hits: %d" cacheHits
printfn "Memo misses: %d" cacheMisses
printfn ""

let recursivelyMemoizeCont f =
    let cache = Dictionary HashIdentity.Structural
    let rec memoizeAux x k =
        match cache.TryGetValue(x) with
        | (true, v) -> 
            cacheHits <- cacheHits + 1
            //printfn "Hit for %A - Result is %A" x v
            k v
        | _ ->
            cacheMisses <- cacheMisses + 1
            //printfn "Miss for %A" x
            f memoizeAux x (fun y ->
                cache.[x] <- y
                k y)
    memoizeAux

let combinations_cont2 =
    let combinations_cont2Aux combinations_cont2Aux key =
        cont {
                match key with
                | (0, _) | (_, 0)                   -> return 0
                | (i, k) when items.[i-1].Size > k  -> return! combinations_cont2Aux (i - 1, k) 
                | (i, k)                            -> let item = items.[i-1]
                                                       let! v1 = combinations_cont2Aux (i-1, k)
                                                       let! beforeItem = combinations_cont2Aux (i-1, k - item.Size)
                                                       let v2 = beforeItem + item.Value
                                                       return max v1 v2
        }
    let memoized = recursivelyMemoizeCont combinations_cont2Aux
    fun x -> memoized x id

cacheHits <- 0
cacheMisses <- 0

printfn "combinations_cont2"
let res4 = time (fun () -> combinations_cont2 (N, K))
printfn "Answer: %d" res4
printfn "Memo hits: %d" cacheHits
printfn "Memo misses: %d" cacheMisses
printfn ""

I get these results:

N = 80, K = 400
combinations_plain
Took: 7.191000
Answer: 6480
Memo hits: 6231
Memo misses: 6552

combinations_plain2
Took: 6.310800
Answer: 6480
Memo hits: 6231
Memo misses: 6552

combinations_cont2
Took: 17.021200
Answer: 6480
Memo hits: 6231
Memo misses: 6552
  • combinations_plain is from latkin's answer.
  • combinations_plain2 exposes the recursive memoization step explicitly.
  • combinations_cont2 adapts the recursive memoization function into one that memoizes the continuation results.
  • combinations_cont2 works by intercepting the result in the continuation before passing it on to the actual continuation. Subsequent calls on the same key provide a continuation and this continuation is fed the answer we intercepted originally.

This demonstrates that we are able to:

  1. Memoize using continuation passing style.
  2. Achieve similar (ish) performance characteristics to the vanilla memoized version.
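
The interception trick used by combinations_cont2 can also be shown in isolation. This is only a sketch on a smaller problem (Fibonacci), assuming the same shape as recursivelyMemoizeCont above; `memoizeCps`, `fibCps`, `fib` and `evaluations` are names invented for this example:

```fsharp
open System.Collections.Generic

// Hypothetical counter: number of times the function body is evaluated.
let mutable evaluations = 0

// Caches final values, not continuations. On a miss, the caller's continuation
// is wrapped so that the result is stored before it is passed on.
let memoizeCps f =
    let cache = Dictionary<_, _>(HashIdentity.Structural)
    let rec go x k =
        match cache.TryGetValue x with
        | true, v -> k v              // hit: feed the stored value to the continuation
        | _ ->
            f go x (fun y ->
                cache.[x] <- y        // intercept the result
                k y)                  // then resume the original continuation
    go

// Fibonacci in CPS, written against an explicit "self" parameter so that
// every recursive call goes through the memoized function.
let fibCps self n k =
    evaluations <- evaluations + 1
    if n < 2 then k n
    else self (n - 1) (fun a -> self (n - 2) (fun b -> k (a + b)))

let fib =
    let memoized = memoizeCps fibCps
    fun n -> memoized n id

let result = fib 30
printfn "fib 30 = %d, evaluations = %d" result evaluations
// prints: fib 30 = 832040, evaluations = 31
```

Each argument from 0 to 30 is evaluated once, which is the behaviour you would expect from the plain memoized version as well.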

I hope this clears things up a little. Sorry, my blog code snippet was incomplete (I think I might have lost it when reformatting recently).

ZachBray