> is there no easier way to roll your own memoization?
Easier than what? A state monad is really easy, and if you are used to thinking imperatively it should also be intuitive.
The full, inlined version, using a vector instead of a list, is:
```haskell
{-# LANGUAGE MultiWayIf #-}
import Control.Monad.Trans.State (State, evalState, get)
import qualified Control.Monad.Trans.State as S
import Data.Vector (Vector)
import qualified Data.Vector as V
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M

goGood :: [Int] -> Int -> Int -> Int
goGood xs t0 i0 =
  let v = V.fromList xs
  in evalState (explicitMemo v t0 i0) mempty
  where
    explicitMemo :: Vector Int -> Int -> Int -> State (Map (Int,Int) Int) Int
    explicitMemo v t i = do
      m <- M.lookup (t,i) <$> get
      case m of
        Nothing ->
          do res <- if | t == 0 -> pure 1
                       | t < 0 -> pure 0
                       | i < 0 -> pure 0
                       | t < (v V.! i) -> explicitMemo v t (i-1)
                       | otherwise -> (+) <$> explicitMemo v (t - (v V.! i)) (i-1)
                                          <*> explicitMemo v t (i-1)
             S.modify (M.insert (t,i) res)
             pure res
        Just r -> pure r
```
That is, we look in a map to see whether we've already computed the result. If so, we return it; if not, we compute the result and store it before returning it.
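For reference, here is the same recurrence with the cache stripped out. It is a sketch: the name `goNaive` is hypothetical, and it uses `Data.Array` in place of `Data.Vector` only so that it needs nothing beyond the libraries bundled with GHC. For positive entries it counts the subsets of the first `i + 1` elements that sum to `t`; every memoized version in this answer should agree with it, the difference being that this one re-solves the same `(t, i)` subproblems over and over.

```haskell
import Data.Array (Array, listArray, (!))

-- Unmemoized reference version of the recurrence (hypothetical name).
-- For positive entries it counts the subsets of xs[0..i] that sum to t.
goNaive :: [Int] -> Int -> Int -> Int
goNaive xs = go
  where
    -- O(1) indexing; valid indices are 0 .. length xs - 1
    v :: Array Int Int
    v = listArray (0, length xs - 1) xs

    go t i
      | t == 0    = 1
      | t < 0     = 0
      | i < 0     = 0
      | t < v ! i = go t (i - 1)                          -- element too big: skip it
      | otherwise = go (t - v ! i) (i - 1) + go t (i - 1) -- take it, or skip it
```

For example, `goNaive [1,2,3] 3 2` evaluates to 2 (the subsets `{3}` and `{1,2}`).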
We can clean this up a lot with just a couple of helper functions:
```haskell
prettyMemo :: Vector Int -> Int -> Int -> State (Map (Int,Int) Int) Int
prettyMemo v t i = cachedReturn =<< cachedEval (
    if | t == 0 -> pure 1
       | t < 0 -> pure 0
       | i < 0 -> pure 0
       | t < (v V.! i) -> prettyMemo v t (i-1)
       | otherwise ->
           (+) <$> prettyMemo v (t - (v V.! i)) (i-1)
               <*> prettyMemo v t (i-1)
  )
  where
    key = (t,i)
    -- Store the value in the cache and return it
    cachedReturn res = S.modify (M.insert key res) >> pure res
    -- Use the cached value if present, otherwise run the operation
    cachedEval oper = maybe oper pure =<< (M.lookup key <$> get)
```
Now the map lookup and the map update live in two small helper functions (simple, at least to the experienced Haskell developer) that wrap the entire computation. One small difference: this version writes to the map even when the result was already cached, which costs a little extra work.
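If you want to avoid that redundant write, the two helpers can be merged into one that only inserts on a cache miss. This is a sketch rather than a drop-in replacement: `memoizeWith` and `countMemo` are hypothetical names, it uses the strict `State` from transformers, and `Data.Array` stands in for `Data.Vector`.

```haskell
{-# LANGUAGE MultiWayIf #-}
import Control.Monad.Trans.State.Strict (State, evalState, get, modify')
import qualified Data.Map.Strict as M
import Data.Array (Array, listArray, (!))

-- Hypothetical helper: return the cached value for `key` if there is one;
-- otherwise run `compute`, store its result under `key`, and return it.
-- On a hit the map is left untouched.
memoizeWith :: Ord k => k -> State (M.Map k v) v -> State (M.Map k v) v
memoizeWith key compute = do
  cache <- get
  case M.lookup key cache of
    Just r  -> pure r
    Nothing -> do
      r <- compute
      modify' (M.insert key r)
      pure r

-- The subset-count recurrence, with every recursive call routed through the cache.
countMemo :: [Int] -> Int -> Int -> Int
countMemo xs t0 i0 = evalState (go t0 i0) M.empty
  where
    v :: Array Int Int
    v = listArray (0, length xs - 1) xs

    go :: Int -> Int -> State (M.Map (Int, Int) Int) Int
    go t i = memoizeWith (t, i) $
      if | t == 0         -> pure 1
         | t < 0 || i < 0 -> pure 0
         | t < v ! i      -> go t (i - 1)
         | otherwise      -> (+) <$> go (t - v ! i) (i - 1) <*> go t (i - 1)
```

Because `memoizeWith` is polymorphic in the key, the same helper works for any recursion whose arguments can be packed into an `Ord` key. With the cache in place, `countMemo (replicate 40 1) 20 39` (which is C(40,20)) finishes immediately, whereas the unmemoized recursion would take exponential time.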
We can make this even cleaner by dropping the monad (see the linked related questions). There is a popular package (MemoTrie) that handles the guts for you:
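```haskell
import Data.MemoTrie (memo2)
import qualified Data.Vector as V

memoTrieVersion :: [Int] -> Int -> Int -> Int
memoTrieVersion xs = memoGo
  where
    v = V.fromList xs
    -- Build the memo table once and route every recursive call through it
    memoGo = memo2 go
    go t i | t == 0 = 1
           | t < 0 = 0
           | i < 0 = 0
           | t < v V.! i = memoGo t (i-1)
           | otherwise = memoGo (t - (v V.! i)) (i-1) + memoGo t (i-1)
```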
If you like the monadic style, you could always use the monad-memo package.
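If you want to drop the monad without pulling in a package, a swapped-in technique is to let lazy evaluation build the table: define a boxed array whose cells are written in terms of other cells. This is a sketch, not what MemoTrie does internally: `countLazy` is a hypothetical name, and it assumes non-negative entries so that `t` never grows and every lookup stays inside the table's bounds.

```haskell
import Data.Array (Array, array, listArray, (!))

-- Hypothetical name. Assumes non-negative entries, so every lookup with
-- 1 <= t <= t0 and 0 <= i < n stays inside the table.
countLazy :: [Int] -> Int -> Int -> Int
countLazy xs t0 i0 = lookupT t0 i0
  where
    n = length xs

    v :: Array Int Int
    v = listArray (0, n - 1) xs

    -- A boxed Array is lazy in its elements: each cell is a thunk that is
    -- evaluated at most once, the first time it is demanded.
    table :: Array (Int, Int) Int
    table = array ((1, 0), (t0, n - 1))
      [ ((t, i), step t i) | t <- [1 .. t0], i <- [0 .. n - 1] ]

    -- Base cases are answered directly and never touch the table.
    lookupT :: Int -> Int -> Int
    lookupT t i
      | t == 0         = 1
      | t < 0 || i < 0 = 0
      | otherwise      = table ! (t, i)

    -- The recurrence itself, reading sub-results back out of the table.
    step :: Int -> Int -> Int
    step t i
      | t < v ! i = lookupT t (i - 1)
      | otherwise = lookupT (t - v ! i) (i - 1) + lookupT t (i - 1)
```

Unlike a Map or a MemoTrie, the table's shape is fixed up front (`t0 * n` cells), but only the cells that are actually demanded get evaluated.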
EDIT: A mostly direct translation of your Python code to Haskell shows that an important difference is the immutability of the variables. In your `otherwise` (or `else`) case you use `go` twice, and implicitly one invocation updates the cache (`m`) that the second call uses, thus saving computation in a memoizing manner. In Haskell, if you're avoiding monads, and avoiding lazy evaluation to recursively define a vector (which can be quite powerful), then the simplest solution left is to explicitly pass your map (dictionary) around:
```haskell
import Data.Vector (Vector)
import qualified Data.Vector as V
import Data.Map (Map)
import qualified Data.Map as M

goWrapped :: Vector Int -> Int -> Int -> Int
goWrapped xxs t i = fst $ goPythonVersion xxs t i mempty

goPythonVersion :: Vector Int -> Int -> Int -> Map (Int,Int) Int -> (Int,Map (Int,Int) Int)
goPythonVersion xxs t i m =
  let k = (t,i)
  in case M.lookup k m of               -- if k in m:
       Just r -> (r,m)                  --     return m[k]
       Nothing ->
         let (res,m') | t == 0 = (1,m)
                      | t < 0 = (0,m)
                      | i < 0 = (0,m)
                      | t < xxs V.! i = goPythonVersion xxs t (i-1) m
                      | otherwise =
                          let (r1,m1) = goPythonVersion xxs (t - (xxs V.! i)) (i-1) m
                              (r2,m2) = goPythonVersion xxs t (i-1) m1
                          in (r1 + r2, m2)
         in (res, M.insert k res m')
```
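Note that `goPythonVersion xxs t i` has type `Map (Int,Int) Int -> (Int, Map (Int,Int) Int)`, which is exactly the `s -> (a, s)` function that `State` wraps. The sketch below makes that concrete with transformers' `state`; `goExplicit`, `goAsState`, `countVia` and the `Cache` synonym are hypothetical names, and `Data.Array` stands in for `Data.Vector`. It is one way to see that the monadic versions above run the same algorithm with the map-threading hidden.

```haskell
import Control.Monad.Trans.State.Strict (State, evalState, state)
import qualified Data.Map.Strict as M
import Data.Array (Array, listArray, (!))

type Cache = M.Map (Int, Int) Int

-- Explicit map threading, mirroring goPythonVersion (hypothetical name).
goExplicit :: Array Int Int -> Int -> Int -> Cache -> (Int, Cache)
goExplicit v t i m =
  case M.lookup (t, i) m of
    Just r  -> (r, m)
    Nothing ->
      let (res, m')
            | t == 0         = (1, m)
            | t < 0 || i < 0 = (0, m)
            | t < v ! i      = goExplicit v t (i - 1) m
            | otherwise      =
                let (r1, m1) = goExplicit v (t - v ! i) (i - 1) m
                    (r2, m2) = goExplicit v t (i - 1) m1
                in (r1 + r2, m2)
      in (res, M.insert (t, i) res m')

-- `state` wraps a Cache -> (Int, Cache) function as a State action;
-- runState would unwrap it again, so the two views are interchangeable.
goAsState :: Array Int Int -> Int -> Int -> State Cache Int
goAsState v t i = state (goExplicit v t i)

countVia :: [Int] -> Int -> Int -> Int
countVia xs t i = evalState (goAsState v t i) M.empty
  where
    v = listArray (0, length xs - 1) xs
```

Inside a `State Cache` block, two consecutive `goAsState` calls share the cache automatically, which is the plumbing you otherwise write by hand with `m1` and `m2`.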
And while this version is a decent translation of the Python, I'd rather see a more idiomatic solution such as the one below. Notice that we bind variables to the result of the computation (`comp` for the Int and `m'` for the updated map), but thanks to lazy evaluation that computation is only performed when the cache lookup misses.
```haskell
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TupleSections #-}

goMoreIdiomatic :: Vector Int -> Int -> Int -> Map (Int,Int) Int -> (Int,Map (Int,Int) Int)
goMoreIdiomatic xxs t i m =
  let cached = M.lookup (t,i) m
      -- Lazily bound: only evaluated if `cached` is Nothing
      ~(comp, M.insert (t,i) comp -> m')
        | t == 0 = (1,m)
        | t < 0 = (0,m)
        | i < 0 = (0,m)
        | t < xxs V.! i = goMoreIdiomatic xxs t (i-1) m
        | otherwise =
            let (r1,m1) = goMoreIdiomatic xxs (t - (xxs V.! i)) (i-1) m
                (r2,m2) = goMoreIdiomatic xxs t (i-1) m1
            in (r1 + r2, m2)
  in maybe (comp,m') (,m) cached
```
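The claim that laziness keeps a cache hit cheap can be checked directly. In this sketch, `cachedOr` is a hypothetical helper with the same shape as the last line of `goMoreIdiomatic`: the miss-branch result is bound with a `let` pattern (which is lazy), so it is only evaluated if the key is absent.

```haskell
{-# LANGUAGE TupleSections #-}
import qualified Data.Map.Strict as M

-- Hypothetical helper with the same shape as goMoreIdiomatic's last line.
-- `computed` is the (result, updated map) pair a cache miss would produce.
cachedOr :: Ord k => k -> (v, M.Map k v) -> M.Map k v -> (v, M.Map k v)
cachedOr k computed m =
  let (comp, mComp) = computed    -- let patterns are lazy: nothing forced yet
      m' = M.insert k comp mComp  -- only built if the miss branch is taken
  in maybe (comp, m') (, m) (M.lookup k m)
```

On a hit, even passing `error "..."` as `computed` is harmless because it is never evaluated; that is what lets `goMoreIdiomatic` mention its recursive calls unconditionally.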