Desugaring the State (Transformer) Monad Part II

Example

This is the example code. We try to count all the occurrences of the character 's' inside a string.
If a char is an 's' then we increment the counter by 1. At the end we print the final value of the counter.

module Main where

import Control.Monad.State.Strict

type OccurrenceState = Int
type OccurrenceValue = Int

-- | Count how often the character 's' occurs in the given string.
--
-- The running count lives in the state. For every character the counter is
-- read and written back, incremented by one when the character is 's'.
-- Once the input is exhausted the current state is returned as the result.
countOccurrence :: String -> State OccurrenceState OccurrenceValue
countOccurrence []     = do
    -- End of input: the result is whatever is currently stored in the state.
    get
countOccurrence (x:xs) = do
    counter <- get
    if 's' == x then
      put $ counter + 1
    else
      put counter
    countOccurrence xs

-- | Run the counter over the sample input "s" with an initial count of 0
-- and print the resulting value.
main :: IO ()
main = print $ evalState (countOccurrence "s") 0

Replacing State with StateT

Definitions of State and StateT from Control.Monad.Trans.State.Strict

First we will replace State by StateT.

newtype StateT s m a = StateT { runStateT :: s -> m (a,s) }

type State s = StateT s Identity

Rewriting

countOccurrence :: String -> State OccurrenceState OccurrenceValue

We see that:
s = OccurrenceState
m = Identity
a = OccurrenceValue

So we can rewrite the type of countOccurrence as:

countOccurrence :: String -> StateT OccurrenceState Identity OccurrenceValue

which yields this new example code

module Main where

import Control.Monad.State.Strict
import Data.Functor.Identity (Identity)

type OccurrenceState = Int
type OccurrenceValue = Int

-- | Count how often the character 's' occurs in the given string.
--
-- Same computation as the 'State' version, but with the type alias expanded
-- to 'StateT' over 'Identity'. For every character the counter is read from
-- the state and written back, incremented by one when the character is 's'.
countOccurrence :: String -> StateT OccurrenceState Identity OccurrenceValue
countOccurrence []     = do
    -- End of input: the result is whatever is currently stored in the state.
    get
countOccurrence (x:xs) = do
    counter <- get
    if 's' == x then
      put $ counter + 1
    else
      put counter
    countOccurrence xs


-- | Run the counter over the sample input "s" with an initial count of 0
-- and print the resulting value.
main :: IO ()
main = print $ evalState (countOccurrence "s") 0

Core output

The core output was obtained via stack ghci --ghci-options -dsuppress-all --ghci-options -ddump-simpl Main.hs.
I extracted the interesting parts.

startState = I# 0#

-- Monad dictionary for StateT over Identity; passed to >>= and >> below.
$dMonad_r2Zo = $fMonadStateT $fMonadIdentity

-- MonadState dictionary for StateT over Identity; passed to get and put below.
$dMonadState_r2Zp = $fMonadStatesStateT $fMonadIdentity

-- The two source equations become one lambda with a case on the argument list.
countOccurrence
  = \ ds_d2Z2 ->
      case ds_d2Z2 of {
        -- empty list: the `countOccurrence []` equation
        [] -> get $dMonadState_r2Zp;
        -- cons cell (constructor written prefix): the `countOccurrence (x:xs)` equation
        : x_a1GR xs_a1GS ->
          -- `counter <- get` followed by the rest of the do block
          >>=
            $dMonad_r2Zo
            (get $dMonadState_r2Zp)
            (\ counter_a1GT ->
               -- the if/else statement is sequenced with the recursive call via >>
               >>
                 $dMonad_r2Zo
                 (case == $fEqChar (C# 's'#) x_a1GR of {
                    False -> put $dMonadState_r2Zp counter_a1GT;
                    True -> put $dMonadState_r2Zp (+ $fNumInt counter_a1GT (I# 1#))
                  })
                 (countOccurrence xs_a1GS))
      }

-- print receives the Show dictionary for Int as an explicit first argument.
main
  = print
      $fShowInt
      (evalState (countOccurrence (unpackCString# "s"#)) (I# 0#))



What we can see from the Core output

All variables are postfixed by a _ and some kind of hash to make them uniquely identifiable. So x_a1GR is our x.
The implicit dictionaries are now explicit. A dictionary is prefixed by $d, for example: $dMonad_r2Zo.
We have two dictionaries containing other dictionaries with the implementations of the Monad typeclass
and the MonadState typeclass.

These dictionaries are $f prefixed, for example $fMonadStateT. This dictionary contains the
implementation of the Monad typeclass for StateT, so inside there are the implementations of
return and bind. The dictionaries are now explicitly passed as the first parameter, for
example to >>=.
The # suffix marks unboxed primitives and the constructors that wrap them: 1# is an unboxed Int# literal and I# is the Int constructor that boxes it, so I# 1# is the ordinary Int 1.

Next Step: Inserting the definitions inside the dictionaries

What do the dictionaries contain?

-- contains return, bind
$dMonad = $fMonadStateT $fMonadIdentity
-- contains get, put
$dMonadState = $fMonadStatesStateT $fMonadIdentity

Getting the necessary implementations

We need source code from various packages.

  • Data.Functor.Identity
    newtype Identity a = Identity { runIdentity :: a }
        deriving (Eq, Ord, Data, Traversable, Generic, Generic1)
    
    instance Monad Identity where
        return   = Identity
        m >>= k  = k (runIdentity m)
    
  • Control.Monad.Trans.Identity
    newtype IdentityT f a = IdentityT { runIdentityT :: f a }
    
  • Control.Monad.Trans.State.Strict
    newtype StateT s m a = StateT { runStateT :: s -> m (a,s) }
    
    type State s = StateT s Identity
    
    -- Leaves state unchanged and sets result to the state
    get :: (Monad m) => StateT s m s
    get = state $ \ s -> (s, s)
    
    -- set result to () and set the state
    put :: (Monad m) => s -> StateT s m ()
    put s = state $ \ _ -> ((), s)
    
    state :: (Monad m)
          => (s -> (a, s))  -- ^pure state transformer
          -> StateT s m a   -- ^equivalent state-passing computation
    state f = StateT (return . f)
    
    
    evalState :: State s a  -- ^state-passing computation to execute
              -> s          -- ^initial value
              -> a          -- ^return value of the state computation
    evalState m s = fst (runState m s)
    
    
    runState :: State s a   -- ^state-passing computation to execute
             -> s           -- ^initial state
             -> (a, s)      -- ^return value and final state
    runState m = runIdentity . runStateT m
    
    instance (Monad m) => Monad (StateT s m) where
    --#if !(MIN_VERSION_base(4,8,0))
        -- leaves state unchanged and sets the result
        return a = StateT $ \ s -> return (a, s)
        {-# INLINE return #-}
    --#endif
        m >>= k  = StateT $ \ s -> do
            (a, s') <- runStateT m s
            runStateT (k a) s'
        {-# INLINE (>>=) #-}
        fail str = StateT $ \ _ -> fail str
        {-# INLINE fail #-}
    
  • Control.Monad.State.Class
    import qualified Control.Monad.Trans.State.Strict as Strict (StateT, get, put, state)
    
    instance Monad m => MonadState s (Strict.StateT s m) where
        get = Strict.get
        put = Strict.put
        state = Strict.state
    
    class Monad m => MonadState s m | m -> s where
        -- | Return the state from the internals of the monad.
        get :: m s
        get = state (\s -> (s, s))
    
        -- | Replace the state inside the monad.
        put :: s -> m ()
        put s = state (\_ -> ((), s))
    
        -- | Embed a simple state action into the monad.
        state :: (s -> (a, s)) -> m a
        state f = do
          s <- get
          let ~(a, s') = f s
          put s'
          return a
    
    

Doing the necessary expandings

  • Desugaring the do notation

    To desugar the do notation do the following:

    • eliminate do
    • Replacing <-:

      If there is a <- we look at the line containing the <- and the following line and replace it with >>=:
      For example:

      (a, s') <- runStateT m s
      runStateT (k a) s'
      

      Is the same as:

      runStateT m s >>= (\(a, s') -> runStateT (k a) s')
      
    • Inserting >> and replace it with >>=

      If we have two lines, and none of them contains the <- operator:
      For example:

      put s'
      return a
      

      This is equivalent to:

      put s' >> return a
      

      m >> k is the same as m >>= (\_ -> k), which leads to:

      put s' >>= (\_ -> return a)
      
  • $fMonadStateT
    • SourceCode
      -- from Control.Monad.Trans.State.Strict
      m >>= k  = StateT $ \ s -> do
          (a, s') <- runStateT m s
          runStateT (k a) s'
      
      return a = StateT $ \ s -> return (a, s)
      
    • Expanding return

      We will make our first expansion. The return definition of the
      outer monad (here the StateT monad) in a monad transformer depends on the
      return definition of the inner monad (in our case the Identity monad).

      return a = StateT $ \ s -> return (a, s)

      -- make the dictionary of the inner monad explicit
      return a = StateT $ \ s -> return $dMonad (a, s)

      -- choose the right implementation dictionary (the inner monad is Identity)
      return a = StateT $ \ s -> return $fMonadIdentity (a, s)

      -- inline the declaration from $fMonadIdentity:
      -- return = Identity
      return a = StateT $ \ s -> Identity (a, s)
      
    • Expanding bind
      m >>= k  = StateT $ \ s -> do
          (a, s') <- runStateT m s
          runStateT (k a) s'

      -- Make the dictionary explicit and get rid of the do notation.
      -- The inner >>= is written prefix and takes three arguments:
      -- the dictionary, the action and the continuation.
      m >>= k  = StateT $ \ s ->
          >>=
              $dMonad
              (runStateT m s)
              (\(a, s') ->
                  runStateT (k a) s'
              )

      -- choose the right implementation dictionary (the inner monad is Identity)
      m >>= k  = StateT $ \ s ->
          >>=
              $fMonadIdentity
              (runStateT m s)
              (\(a, s') ->
                  runStateT (k a) s'
              )

      -- apply the $fMonadIdentity definition  j >>= l = l (runIdentity j)  with
      -- (j and l are used to avoid a clash with the outer m and k)
      j = runStateT m s
      l = (\(a, s') ->
          runStateT (k a) s'
      )

      -- final version
      m >>= k  = StateT $ \ s ->
                  (\(a, s') -> runStateT (k a) s') $
                               runIdentity $ runStateT m s
      
      
  • $fMonadStatesStateT
    • SourceCode
      get :: (Monad m) => StateT s m s
      get = state $ \ s -> (s, s)
      
      put :: (Monad m) => s -> StateT s m ()
      put s = state $ \ _ -> ((), s)
      
      state :: (Monad m)
          => (s -> (a, s))  -- ^pure state transformer
          -> StateT s m a   -- ^equivalent state-passing computation
      state f = StateT (return . f)
      
      
      
    • Expanding

      If we look at the definition of Identity

      newtype Identity a = Identity { runIdentity :: a }
      

      we can rewrite the ValueConstructor Identity (the one on the right side of the equation)
      as a function, which takes one parameter a and returns something with the type Identity

      Identity :: a -> Identity a
      
      -- Step-by-step expansion of get for the concrete inner monad Identity.
      get = state $ \ s -> (s, s)
      
      -- replace state by state f = StateT (return . f) from Control.Monad.Trans.State.Strict
      get = StateT (return . (\s -> (s, s)))
      
      -- inline from $fMonadIdentity return def: return = Identity
      get = StateT (Identity . (\s -> (s,s)))
      
      -- rewrite Value Constructor Identity as a function: \a -> Identity (a)
      get = StateT ((\a -> Identity a). (\s -> (s,s)))
      
      -- fuse the composition into a single lambda: (f . g) x = f (g x)
      get = StateT (\s -> Identity (s, s)) 
      
      
      -- The same expansion for put: the incoming state is ignored and replaced by s.
      put s = state $ \ _ -> ((), s)
      
      -- replace state by state f = StateT (return . f) from Control.Monad.Trans.State.Strict
      put s = StateT (return . (\ _ -> ((), s)))
      
      -- inline from $fMonadIdentity return def: return = Identity
      put s = StateT (Identity . (\ _ -> ((), s)))
      
      -- rewrite Value Constructor Identity as a function: \a -> Identity (a)
      put s = StateT ((\a -> Identity a). (\ _ -> ((), s)))
      
      -- fuse the composition into a single lambda: (f . g) x = f (g x)
      put s = StateT $ \_ -> Identity ((), s)
      
      
      
      
      

Results

$dMonad

$fMonadStateT

-- return for StateT over Identity: leave the state untouched, yield a
return a = StateT $ \ s -> Identity (a, s)

-- bind for StateT over Identity: run m, then feed its result and the
-- new state into k
m >>= k =
  StateT $ \ s ->
           (\(a, s') ->
              runStateT (k a) s') $ runIdentity $ runStateT m s

$fMonadIdentity

return = Identity

j >>= l  = l (runIdentity j)

$dMonadState

$fMonadStatesStateT

put s = StateT $ \_ -> Identity ((), s)

get = StateT (\s -> Identity (s, s)) 

Leave a Reply

Your email address will not be published. Required fields are marked *