You are computing the fixed point of your state transformation. Since we are in a monadic context, let's use a monadic fixed-point combinator instead of plain `fix`. Enter `mfix`:
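```haskell
import Control.Monad (unless)
import Control.Monad.State (MonadState, StateT, get, modify, put, runState, runStateT)
import Control.Monad.Fix (mfix)
import Control.Monad.IO.Class (liftIO)

-- Run `st` repeatedly until `p` holds for the states before and after a run.
-- `f` stands for the recursive call (mfix at the function instance ties the knot).
untilStable :: MonadState s m => (s -> s -> Bool) -> m a -> m ()
untilStable p = mfix $ \f st -> p <$> get <* st <*> get >>= (`unless` f)
```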
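Note which `mfix` this is: it is used at the function (reader) instance `(->) r`, which base defines as `mfix f = \r -> let a = f a r in a`. Inside the lambda, `f` is therefore the whole loop already applied to `st`, so the definition behaves like ordinary `fix` with the argument passed through. The sketch below compares the two on a pure `State` computation; it assumes mtl, and the names `untilStableMfix`, `untilStableFix` and `halve` are hypothetical.

```haskell
import Control.Monad (unless)
import Control.Monad.Fix (mfix)
import Control.Monad.State (MonadState, State, get, modify, runState)
import Data.Function (fix)

-- The definition above: mfix at the function instance, so `f` is the
-- recursive call already applied to `st`.
untilStableMfix :: MonadState s m => (s -> s -> Bool) -> m a -> m ()
untilStableMfix p = mfix $ \f st -> p <$> get <* st <*> get >>= (`unless` f)

-- The same loop written with plain fix inside the monad.
untilStableFix :: MonadState s m => (s -> s -> Bool) -> m a -> m ()
untilStableFix p st = fix $ \loop -> do
  before <- get
  _ <- st
  after <- get
  unless (p before after) loop

-- Example step: integer-halve the state; it stops changing once it reaches 0.
halve :: State Int ()
halve = modify (`div` 2)
```

In GHCi, both `runState (untilStableMfix (==) halve) 100` and `runState (untilStableFix (==) halve) 100` give `((),0)`.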
I also took the liberty of generalizing your function so that you can supply your own binary predicate on the old and new states.
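One case where a custom predicate helps: with a floating-point state, an exact `(==)` test may never be met if the iteration ends up alternating between neighbouring values, so a tolerance is often the safer choice. A minimal sketch, assuming mtl; `newtonStep`, `closeEnough` and `sqrt2` are hypothetical names, and the `1e-12` threshold is arbitrary.

```haskell
import Control.Monad (unless)
import Control.Monad.Fix (mfix)
import Control.Monad.State (MonadState, State, execState, get, modify)

untilStable :: MonadState s m => (s -> s -> Bool) -> m a -> m ()
untilStable p = mfix $ \f st -> p <$> get <* st <*> get >>= (`unless` f)

-- One Newton step towards sqrt 2: x -> (x + 2/x) / 2.
newtonStep :: State Double ()
newtonStep = modify (\x -> (x + 2 / x) / 2)

-- Hypothetical predicate: the state counts as stable once a run moves it by < 1e-12.
closeEnough :: Double -> Double -> Bool
closeEnough old new = abs (old - new) < 1e-12

-- Iterate from 1 until two successive states are close enough.
sqrt2 :: Double
sqrt2 = execState (untilStable closeEnough newtonStep) 1
```

`sqrt2` then agrees with `sqrt 2` to well within `1e-9`.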
In GHCi, `runState (untilStable (==) $ modify (+1)) 2` will never terminate, because every run changes the state. But with:
```haskell
comp :: StateT Int IO ()
comp = do
  s1 <- (+1) <$> get                   -- increment the current state
  liftIO $ print s1                    -- log the candidate value
  let s2 = if s1 >= 3 then 3 else s1   -- cap the state at 3
  put s2
```
You get:

```
> runStateT (untilStable (==) comp) 0
1
2
3
4
((),3)
```
This `untilStable` can be generalized further into:

```haskell
untilStable :: MonadState s m => (s -> s -> Bool) -> m a -> m a
untilStable p = mfix $ \f st -> do
  before <- get
  a <- st
  after <- get
  if p before after then pure a else f
```
Now the result type is no longer fixed to `()`: `untilStable` returns the result of the final run of the computation, the one that left the state stable.
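For instance, the result can report how many runs it took. In this sketch (assuming mtl; `halveStep` and `halveUntilZero` are hypothetical) the state pairs a value with a run counter, and the predicate compares only the value, so the ever-growing counter does not prevent stability.

```haskell
import Control.Monad.Fix (mfix)
import Control.Monad.State (MonadState, State, get, put, runState)
import Data.Function (on)

untilStable :: MonadState s m => (s -> s -> Bool) -> m a -> m a
untilStable p = mfix $ \f st -> do
  before <- get
  a <- st
  after <- get
  if p before after then pure a else f

-- Halve the value, bump the run counter, and return the new counter.
halveStep :: State (Int, Int) Int
halveStep = do
  (n, runs) <- get
  let runs' = runs + 1
  put (n `div` 2, runs')
  pure runs'

-- Stability is judged on the value only (first component).
halveUntilZero :: Int -> (Int, (Int, Int))
halveUntilZero n = runState (untilStable ((==) `on` fst) halveStep) (n, 0)
```

`halveUntilZero 100` evaluates to `(8,(0,8))`: seven runs take 100 down to 0, and the eighth leaves it unchanged, so that run's counter is the result.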
If you want to implement `idempotently` with `fix`, you can do it like so:
```haskell
import Data.Function (fix)

idempotently :: Eq a => (a -> a) -> a -> a
idempotently = fix $ \i f a ->
  let a' = f a
  in if a' == a then a else i f a'
```
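As an illustration of a step that has to be iterated to its fixed point, here is a hypothetical `squeezeOnce` that merges non-overlapping pairs of spaces. One pass only shortens a run of spaces, so `idempotently` repeats it until nothing changes. This is a sketch; `squeezeOnce` and `squeeze` are illustrative names.

```haskell
import Data.Function (fix)

idempotently :: Eq a => (a -> a) -> a -> a
idempotently = fix $ \i f a ->
  let a' = f a
  in if a' == a then a else i f a'

-- One pass: each non-overlapping pair of spaces becomes a single space,
-- so a run of k spaces only shrinks to roughly k/2.
squeezeOnce :: String -> String
squeezeOnce (' ' : ' ' : rest) = ' ' : squeezeOnce rest
squeezeOnce (c : rest) = c : squeezeOnce rest
squeezeOnce [] = []

-- Repeat the pass until the string no longer changes.
squeeze :: String -> String
squeeze = idempotently squeezeOnce
```

For example, `squeeze "a     b  c"` gives `"a b c"` after three effective passes plus one that confirms nothing changed.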