149

Any pointers on how to efficiently solve the following function in Haskell for large numbers (n > 10^8)?

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

I've seen examples of memoization in Haskell to solve fibonacci numbers, which involved computing (lazily) all the fibonacci numbers up to the required n. But in this case, for a given n, we only need to compute very few intermediate results.

Thanks

Chetan
Angel de Vicente

8 Answers

278

We can do this very efficiently by making a structure that we can index in sub-linear time.

But first,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Let's define f, but make it use 'open recursion' rather than call itself directly.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

You can get an unmemoized f by using fix f

This lets you test that f does what you mean for small values of n by calling, for example: fix f 123 = 144
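
As a quick sanity check of the open-recursion idea, the sketch below compares fix f against a directly recursive version. The names direct and agreesUpTo are assumptions introduced here for illustration; they are not part of the answer.

```haskell
import Data.Function (fix)

-- Open-recursive step: every recursive call goes through the 'mf' argument.
f :: (Int -> Int) -> Int -> Int
f _  0 = 0
f mf n = max n (mf (n `div` 2) + mf (n `div` 3) + mf (n `div` 4))

-- Hypothetical helper: the same function written with ordinary recursion.
direct :: Int -> Int
direct 0 = 0
direct n = max n (direct (n `div` 2) + direct (n `div` 3) + direct (n `div` 4))

-- Hypothetical helper: True when 'fix f' and 'direct' agree on [0 .. limit].
agreesUpTo :: Int -> Bool
agreesUpTo limit = all (\n -> fix f n == direct n) [0 .. limit]
```

Neither version memoizes anything, so this is only practical for small limits.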

We could memoize this by defining:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

That performs passably well, and replaces what was going to take O(n^3) time with something that memoizes the intermediate results.

But it still takes linear time just to index to find the memoized answer for mf. This means that results like:

*Main Data.List> faster_f 123801
248604

are tolerable, but the result doesn't scale much better than that. We can do better!

First, let's define an infinite tree:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

And then we'll define a way to index into it, so we can find a node with index n in O(log n) time instead:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

... and we may find a tree full of natural numbers to be convenient so we don't have to fiddle around with those indices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Since we can index, you can just convert a tree into a list:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

You can check the work so far by verifying that toList nats gives you [0..]
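
One way to run that check is sketched below. It repeats the definitions above so that it loads on its own, uses seq in place of the bang patterns so that no extension is needed, and adds a helper natsLooksRight, which is a name assumed here for illustration.

```haskell
data Tree a = Tree (Tree a) a (Tree a)

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q, 0) -> index l q
    (q, _) -> index r q

nats :: Tree Int
nats = go 0 1
  where
    -- 'seq' forces n and s, standing in for the bang patterns.
    go n s = n `seq` s `seq` Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0 ..]

-- Hypothetical helper: compares a finite prefix of 'toList nats' with [0 ..].
natsLooksRight :: Int -> Bool
natsLooksRight k = take k (toList nats) == [0 .. k - 1]
```

Only a finite prefix can be compared, since both lists are infinite.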

Now,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

This works just like the list version above, but instead of taking linear time to find each node, it can chase it down in logarithmic time.

The result is considerably faster:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

In fact, it is so much faster that you can replace Int with Integer above and get ridiculously large answers almost instantaneously:

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358

For an out-of-the-box library that implements tree-based memoization, use MemoTrie:

$ stack repl --package MemoTrie
Prelude> import Data.MemoTrie
Prelude Data.MemoTrie> :set -XLambdaCase
Prelude Data.MemoTrie> :{
Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
Prelude Data.MemoTrie| fastest_f' = memo $ \case
Prelude Data.MemoTrie|   0 -> 0
Prelude Data.MemoTrie|   n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
Prelude Data.MemoTrie| :}
Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
Yang Bo
Edward Kmett
  • 5
    I tried this code and, interestingly, f_faster seemed to be slower than f. I guess those list references really slowed things down. The definition of nats and index seemed pretty mysterious to me, so I've added my own answer which might make things clearer. – Pitarou Jun 16 '12 at 04:41
  • @EdwardKmett I have spent hours learning / researching how this works and its very clever. But what i cant find is, why does the infinite list take so much more memory then the infinite tree? for example if you call "fastest_f 111111111" while watching ghci's memory usage you can see it uses next to nothing. But when you call faster_f 111111111 it uses around 1.5gb then ghci ends because I'm out of memory. I've tested their subsequent calls using ghci's :set +s and fastest_f does improve its speed to next to nothing and so does faster_f. So whats going on? – QuantumKarl Dec 08 '13 at 22:01
  • 6
    The infinite list case has to deal with a linked list 111111111 items long. The tree case is dealing with log n * the number of nodes reached. – Edward Kmett Dec 17 '13 at 07:15
  • 2
    i.e. the list version has to create thunks for all nodes in the list, whereas the tree version avoids creating a lot of them. – Tom Ellis Dec 17 '13 at 08:48
  • 8
    I know this is a rather old post, but shouldn't `f_tree` be defined in a `where` clause to avoid saving unneeded paths in the tree across calls? – dfeuer Aug 25 '14 at 17:22
  • 19
    The reason for stuffing it in a CAF was that you could get memoization across calls. If I had an expensive call I was memoizing, then I'd probably leave it in a CAF, hence the technique shown here. In a real application there is a trade-off between the benefits and costs of permanent memoization of course. Though, given the question was about how to achieve memoization, I think it'd be misleading to answer with a technique that deliberately avoids memoization across calls, and if nothing else then this commentary here will point folks to the fact that there are subtleties. ;) – Edward Kmett Aug 26 '14 at 07:47
  • 1
    Given a suitable definition of `tmap` (tree `map`), `nats` can be defined as `nats = Tree (tmap (succ . (*2)) nats) 0 (tmap ((*2) . succ) nats)`. – Brian McCutchon Sep 09 '17 at 21:12
  • Is there a name for the infinite tree used in this solution? It's like a Stern–Brocot tree but with naturals instead of rationals, but I cannot figure out how to search for it on google since I don't know its name. – Andrew Thaddeus Martin Oct 26 '17 at 11:46
  • The runtime of the original program is more like O(n^1.1). Calling it O(n^3) is technically true, but misleading. I believe the runtime is Θ((n^k)log(n)) where (1/2)^k+(1/3)^k+(1/4)^k=1. – Tesseract Jul 25 '18 at 20:02
  • @AndrewThaddeusMartin I made it up on the spot. In retrospect it is related to the usual 1-based implicit heap folks use, where the children are at positions 2k and 2k+1, but modified to start from a 0 base, so they wind up at 2k+1 and 2k+2. – Edward Kmett Oct 25 '18 at 06:57
  • 1
    Could someone comment on if/why the BangPatterns are necessary? – dainichi Apr 09 '19 at 13:15
  • @dainichi I think the extension is not _required_ but without it `go !n !s = Tree (go l s') n (go r s')` causes a [space leak](https://apfelmus.nfshost.com/articles/lazy-eval-intro.html). The bangs result in arguments of `go` to be reduced to weak head normal form, reducing memory usage significantly. – lindblandro Mar 07 '20 at 19:33
20

Edward's answer is such a wonderful gem that I've duplicated it and provided implementations of memoList and memoTree combinators that memoize a function in open-recursive form.

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- The memoizing functionality depends on this being in eta reduced form!
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- The memoizing functionality depends on this being in eta reduced form!
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
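
The combinator is not tied to this particular f. As a sketch under that assumption, the block below applies memoTree to a different open-recursive function. The names fibStep and fastFib are hypothetical and introduced only for this example; seq replaces the bang patterns so the block needs no extension.

```haskell
data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
    fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q, 0) -> index l q
    (q, _) -> index r q

nats :: Tree Integer
nats = go 0 1
  where
    go n s = n `seq` s `seq` Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

-- Same combinator as above; 'memoized' and 'memo' refer to each other,
-- so every recursive call made by 'step' is answered from the tree.
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree step = memoized
  where
    memoized = index memo
    memo     = fmap (step memoized) nats

-- Hypothetical example: Fibonacci in open-recursive form.
fibStep :: (Integer -> Integer) -> Integer -> Integer
fibStep _    0 = 0
fibStep _    1 = 1
fibStep self n = self (n - 1) + self (n - 2)

fastFib :: Integer -> Integer
fastFib = memoTree fibStep
```

Because fastFib is a top-level value, the tree behind it is kept between calls, which is the trade-off discussed in the comments on Edward's answer.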
Tom Ellis
12

Not the most efficient way, but does memoize:

f = 0 : [ g n | n <- [1..] ]
    where g n = max n $ f!!(n `div` 2) + f!!(n `div` 3) + f!!(n `div` 4)

When requesting f !! 144, it is checked that f !! 143 exists, but its exact value is not calculated. It's still set as some unknown result of a calculation. The only exact values calculated are the ones needed.

So initially, as far as how much has been calculated, the program knows nothing.

f = .... 

When we make the request f !! 12, it starts doing some pattern matching:

f = 0 : g 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Now it starts calculating

f !! 12 = g 12 = max 12 $ f!!6 + f!!4 + f!!3

This recursively makes another demand on f, so we calculate

f !! 6 = g 6 = max 6 $ f !! 3 + f !! 2 + f !! 1
f !! 3 = g 3 = max 3 $ f !! 1 + f !! 1 + f !! 0
f !! 1 = g 1 = max 1 $ f !! 0 + f !! 0 + f !! 0
f !! 0 = 0

Now we can trickle back up some

f !! 1 = g 1 = max 1 $ 0 + 0 + 0 = 1

Which means the program now knows:

f = 0 : 1 : g 2 : g 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Continuing to trickle up:

f !! 3 = g 3 = max 3 $ 1 + 1 + 0 = 3

Which means the program now knows:

f = 0 : 1 : g 2 : 3 : g 4 : g 5 : g 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Now we continue with our calculation of f!!6:

f !! 6 = g 6 = max 6 $ 3 + f !! 2 + 1
f !! 2 = g 2 = max 2 $ f !! 1 + f !! 0 + f !! 0 = max 2 $ 1 + 0 + 0 = 2
f !! 6 = g 6 = max 6 $ 3 + 2 + 1 = 6

Which means the program now knows:

f = 0 : 1 : 2 : 3 : g 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : g 12 : ...

Now we continue with our calculation of f!!12:

f !! 12 = g 12 = max 12 $ 6 + f!!4 + 3
f !! 4 = g 4 = max 4 $ f !! 2 + f !! 1 + f !! 1 = max 4 $ 2 + 1 + 1 = 4
f !! 12 = g 12 = max 12 $ 6 + 4 + 3 = 13

Which means the program now knows:

f = 0 : 1 : 2 : 3 : 4 : g 5 : 6 : g 7 : g 8 : g 9 : g 10 : g 11 : 13 : ...

So the calculation is done fairly lazily. The program knows that some value for f !! 8 exists, that it's equal to g 8, but it has no idea what g 8 is.
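
The walkthrough can be replayed with a sketch like the following. The name fList is an assumption made here to avoid clashing with the other definitions of f on this page, and the Int type is chosen only so that the block loads on its own; the definition is otherwise the one given at the top of this answer.

```haskell
-- Lazy list whose n-th element is f(n); each element is computed at most once.
fList :: [Int]
fList = 0 : [ g n | n <- [1 ..] ]
  where
    g n = max n (fList !! (n `div` 2) + fList !! (n `div` 3) + fList !! (n `div` 4))

-- The values reached in the walkthrough: f(6) and f(12).
walkthroughValues :: (Int, Int)
walkthroughValues = (fList !! 6, fList !! 12)
```

Evaluating walkthroughValues should give (6, 13), matching the trace above.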

rampion
  • Thank you for this one. How would you create and use a 2 dimensional solution space? Would that be a list of lists? and `g n m = (something with) f!!a!!b` – vikingsteve Jan 06 '14 at 08:21
  • 1
    Sure, you could. For a real solution, though, i'd probably use a memoization library, like [memocombinators](http://ocharles.org.uk/blog/posts/2013-12-08-24-days-of-hackage-data-memocombinators.html) – rampion Jan 07 '14 at 03:14
  • It's O(n^2) unfortunately. – Qumeric Oct 03 '16 at 10:24
9

This is an addendum to Edward Kmett's excellent answer.

When I tried his code, the definitions of nats and index seemed pretty mysterious, so I wrote an alternative version that I found easier to understand.

I define index and nats in terms of index' and nats'.

index' t n is defined over the range [1..]. (Recall that index t is defined over the range [0..].) It searches the tree by treating n as a string of bits and reading the bits in reverse, from least significant to most significant. If the bit is 1, it takes the right-hand branch. If the bit is 0, it takes the left-hand branch. It stops when it reaches the last bit (which must be a 1).

index' (Tree l m r) 1 = m
index' (Tree l m r) n = case n `divMod` 2 of
                          (n', 0) -> index' l n'
                          (n', 1) -> index' r n'

Just as nats is defined for index so that index nats n == n is always true, nats' is defined for index'.

nats' = Tree l 1 r
  where
    l = fmap (\n -> n*2)     nats'
    r = fmap (\n -> n*2 + 1) nats'

Now, nats and index are simply nats' and index' but with the values shifted by 1:

index t n = index' t (n+1)
nats = fmap (\n -> n-1) nats'
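
To convince yourself that the two pairs of definitions line up, a sketch like the one below can be loaded into GHCi. It repeats the definitions with explicit Integer types (an assumption, since the snippets above carry no signatures) and adds a hypothetical helper roundTrips.

```haskell
data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
    fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

-- 1-based lookup: follow the bits of n, lowest bit first, until only 1 is left.
index' :: Tree a -> Integer -> a
index' (Tree _ m _) 1 = m
index' (Tree l _ r) n = case n `divMod` 2 of
    (n', 0) -> index' l n'
    (n', _) -> index' r n'

-- Tree in which the node reached by index' for n holds n itself.
nats' :: Tree Integer
nats' = Tree l 1 r
  where
    l = fmap (\n -> n * 2)     nats'
    r = fmap (\n -> n * 2 + 1) nats'

-- 0-based versions, obtained by shifting everything by one.
index :: Tree a -> Integer -> a
index t n = index' t (n + 1)

nats :: Tree Integer
nats = fmap (subtract 1) nats'

-- Hypothetical helper: checks index nats n == n on [0 .. limit].
roundTrips :: Integer -> Bool
roundTrips limit = all (\n -> index nats n == n) [0 .. limit]
```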
Pitarou
  • Thanks. I'm memoizing a multivariate function, and this really helped me work out what index and nats were really doing. – Kittsil Mar 03 '17 at 05:55
9

As stated in Edward Kmett's answer, to speed things up, you need to cache costly computations and be able to access them quickly.

To keep the function non-monadic, you can build an infinite lazy tree with an appropriate way to index it (as shown in previous posts). If you give up the non-monadic nature of the function, you can instead use the standard associative containers available in Haskell in combination with “state-like” monads (like State or ST).

While the main drawback is that you get a monadic function, you do not have to index the structure yourself anymore, and can just use standard implementations of associative containers.

To do so, you first need to rewrite your function to accept any kind of monad:

fm :: (Integral a, Monad m) => (a -> m a) -> a -> m a
fm _    0 = return 0
fm recf n = do
   recs <- mapM recf $ div n <$> [2, 3, 4]
   return $ max n (sum recs)

For your tests, you can still define a function that does no memoization using Data.Function.fix, although it is a bit more verbose:

import Data.Function (fix)
import Data.Functor.Identity (runIdentity)

noMemoF :: (Integral n) => n -> n
noMemoF = runIdentity . fix fm

You can then use the State monad in combination with Data.Map to speed things up:

import Control.Monad.State (evalState, get, modify)
import qualified Data.Map.Strict as MS

withMemoStMap :: (Integral n) => n -> n
withMemoStMap n = evalState (fm recF n) MS.empty
   where
      recF i = do
         v <- MS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ MS.insert i v'
               return v'

With minor changes, you can adapt the code to work with Data.HashMap instead:

import Data.Hashable (Hashable)
import qualified Data.HashMap.Strict as HMS

withMemoStHMap :: (Integral n, Hashable n) => n -> n
withMemoStHMap n = evalState (fm recF n) HMS.empty
   where
      recF i = do
         v <- HMS.lookup i <$> get
         case v of
            Just v' -> return v' 
            Nothing -> do
               v' <- fm recF i
               modify $ HMS.insert i v'
               return v'

Instead of persistent data structures, you may also try mutable data structures (like Data.HashTable) in combination with the ST monad:

import Control.Monad.ST (runST)
import qualified Data.HashTable.ST.Linear as MHM

withMemoMutMap :: (Integral n, Hashable n) => n -> n
withMemoMutMap n = runST $
   do ht <- MHM.new
      recF ht n
   where
      recF ht i = do
         k <- MHM.lookup ht i
         case k of
            Just k' -> return k'
            Nothing -> do 
               k' <- fm (recF ht) i
               MHM.insert ht i k'
               return k'

Compared to the implementation without any memoization, any of these implementations gives you results for huge inputs in microseconds instead of several seconds.

Using Criterion for benchmarking, I observed that the implementation with Data.HashMap actually performed slightly better (around 20%) than those with Data.Map and Data.HashTable, whose timings were very similar.

I found the results of the benchmark a bit surprising. My initial feeling was that the HashTable would outperform the HashMap implementation because it is mutable. There might be some performance defect hidden in this last implementation.

Quentin
  • 3
    GHC does a very good job of optimizing around immutable structures. Intuition from C doesn't always pan out. – John Tyree May 24 '15 at 16:47
4

A couple years later, I looked at this and realized there's a simple way to memoize this in linear time using zipWith and a helper function:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate has the handy property that dilate n xs !! i == xs !! div i n.
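
A small sketch that checks this property on a finite list; dilateProperty is a hypothetical helper name introduced here for illustration.

```haskell
dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

-- Hypothetical helper: checks dilate n xs !! i == xs !! div i n
-- for every valid index i of the dilated (finite) list.
dilateProperty :: Int -> [Int] -> Bool
dilateProperty n xs =
    and [ dilate n xs !! i == xs !! (i `div` n) | i <- [0 .. n * length xs - 1] ]
```

For instance, dilate 3 "ab" is "aaabbb".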

So, supposing we're given f(0), this simplifies the computation to

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

This looks a lot like our original problem description and gives a linear solution (sum $ take n fs takes O(n) time).
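
Here is a self-contained sketch of the same idea, assuming f(0) = 0 as in the question and fixing the element type to Int. The function naive is a hypothetical reference implementation added only for comparison.

```haskell
dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

-- fs !! n is f(n); the three dilated copies of fs supply f(n/2), f(n/3), f(n/4).
fs :: [Int]
fs = 0 : zipWith max [1 ..] (tail (fs #/ 2 .+. fs #/ 3 .+. fs #/ 4))
  where
    (.+.) = zipWith (+)
    infixl 6 .+.
    (#/) = flip dilate
    infixl 7 #/

-- Hypothetical reference: plain recursion without memoization.
naive :: Int -> Int
naive 0 = 0
naive n = max n (naive (n `div` 2) + naive (n `div` 3) + naive (n `div` 4))

-- True when the stream agrees with the reference on its first k elements.
agreesOnPrefix :: Int -> Bool
agreesOnPrefix k = take k fs == map naive [0 .. k - 1]
```

The stream is productive because its first element is known before any of the dilated copies are consumed.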

rampion
  • 2
    so it's a generative (corecursive?), or dynamic programming, solution. Taking O(1) time per each generated value, like the usual Fibonacci is doing. Great! And EKMETT's solution is like the logarithmic big-Fibonacci, getting to the big numbers much faster, skipping over much of in-betweens. Is this about right? – Will Ness Aug 28 '18 at 15:48
  • or maybe it's closer the one for the Hamming numbers, with the three back-pointers into the sequence which is being produced, and the different speeds for each of them advancing along it. really pretty. – Will Ness Aug 28 '18 at 15:58
2

Yet another addendum to Edward Kmett's answer: a self-contained example:

data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 arg_to_index index_to_arg f = (\n -> index nats (arg_to_index n))
  where nats = go 0 1
        go i s = NatTrie (go (i+s) s') (f (index_to_arg i)) (go (i+s') s')
          where s' = 2*s
        index (NatTrie l v r) i
          | i <  0    = f (index_to_arg i)
          | i == 0    = v
          | otherwise = case (i-1) `divMod` 2 of
             (i',0) -> index l i'
             (i',1) -> index r i'

memoNat = memo1 id id 

Use it as follows to memoize a function with a single integer arg (e.g. fibonacci):

fib = memoNat f
  where f 0 = 0
        f 1 = 1
        f n = fib (n-1) + fib (n-2)

Only values for non-negative arguments will be cached.

To also cache values for negative arguments, use memoInt, defined as follows:

memoInt = memo1 arg_to_index index_to_arg
  where arg_to_index n
         | n < 0     = -2*n
         | otherwise =  2*n + 1
        index_to_arg i = case i `divMod` 2 of
           (n,0) -> -n
           (n,1) ->  n

To cache values for functions with two integer arguments use memoIntInt, defined as follows:

memoIntInt f = memoInt (\n -> memoInt (f n))
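
As a usage sketch for the two-argument case, the block below memoizes binomial coefficients. It repeats the definitions above with explicit Integer signatures (an assumption, since the originals have none), and the function choose is a hypothetical example that is not part of the answer.

```haskell
data NatTrie v = NatTrie (NatTrie v) v (NatTrie v)

memo1 :: (a -> Integer) -> (Integer -> a) -> (a -> v) -> a -> v
memo1 argToIndex indexToArg f = \n -> index nats (argToIndex n)
  where
    nats = go 0 1
    go i s = NatTrie (go (i + s) s') (f (indexToArg i)) (go (i + s') s')
      where s' = 2 * s
    index (NatTrie l v r) i
      | i <  0    = f (indexToArg i)   -- negative indices are not cached
      | i == 0    = v
      | otherwise = case (i - 1) `divMod` 2 of
          (i', 0) -> index l i'
          (i', _) -> index r i'

-- Maps negative arguments to even indices and the rest to odd indices.
memoInt :: (Integer -> v) -> Integer -> v
memoInt = memo1 argToIndex indexToArg
  where
    argToIndex n
      | n < 0     = -2 * n
      | otherwise = 2 * n + 1
    indexToArg i = case i `divMod` 2 of
      (n, 0) -> negate n
      (n, _) -> n

memoIntInt :: (Integer -> Integer -> v) -> Integer -> Integer -> v
memoIntInt f = memoInt (\n -> memoInt (f n))

-- Hypothetical example: binomial coefficients via Pascal's rule.
choose :: Integer -> Integer -> Integer
choose = memoIntInt go
  where
    go n k
      | k < 0 || k > n   = 0
      | k == 0 || k == n = 1
      | otherwise        = choose (n - 1) (k - 1) + choose (n - 1) k
```

The outer trie is keyed on the first argument and stores one inner memoized function per value, so each pair (n, k) is computed once.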
Neal Young
2

A solution without indexing, and not based on Edward Kmett's.

I factor out common subtrees to a common parent (f(n/4) is shared between f(n) and f(n/2), and f(n/6) is shared between f(n/2) and f(n/3)). By saving them as a single variable in the parent, the calculation of the subtree is done once.

data Tree a =
  Node {datum :: a, child2 :: Tree a, child3 :: Tree a}

f :: Int -> Int
f n = datum root
  where root = f' n Nothing Nothing


-- Pass in the arg
  -- and this node's lifted children (if any).
f' :: Integral a => a -> Maybe (Tree a) -> Maybe (Tree a) -> Tree a
f' 0 _ _ = leaf
    where leaf = Node 0 leaf leaf
f' n m2 m3 = Node d c2 c3
  where
    d = if n < 12 then n
            else max n (d2 + d3 + d4)
    [n2,n3,n4,n6] = map (n `div`) [2,3,4,6]
    [d2,d3,d4,d6] = map datum [c2,c3,c4,c6]
    c2 = case m2 of    -- Check for a passed-in subtree before recursing.
      Just c2' -> c2'
      Nothing -> f' n2 Nothing (Just c6)
    c3 = case m3 of
      Just c3' -> c3'
      Nothing -> f' n3 (Just c6) Nothing
    c4 = child2 c2
    c6 = f' n6 Nothing Nothing

main =
  print (f 123801)
  -- Should print 248604.

The code doesn't easily extend to a general memoization function (at least, I wouldn't know how to do it), and you really have to think out how subproblems overlap, but the strategy should work for general multiple non-integer parameters. (I thought it up for two string parameters.)

The memo is discarded after each calculation. (Again, I was thinking about two string parameters.)

I don't know if this is more efficient than the other answers. Each lookup is technically only one or two steps ("Look at your child or your child's child"), but there might be a lot of extra memory use.

Edit: This solution isn't correct yet. The sharing is incomplete.

Edit: It should be sharing subchildren properly now, but I realized that this problem has a lot of nontrivial sharing: n/2/2/2 and n/3/3 might be the same. The problem is not a good fit for my strategy.

leewz