Desugaring the State (Transformer) Monad Part IV

Desugared Example Code

module Main where

import Control.Monad.State.Strict (StateT(..))
import Data.Functor.Identity (Identity(..))

-- | Alias for the result value of the counting computation (an 'Int').
type OccurrenceValue = Int
-- | Alias for the state threaded through the computation (an 'Int' counter).
type OccurrenceState = Int


-- | Desugared state computation that walks the input string.
--
-- For the empty string the current state is returned both as the result
-- and as the final state. For a non-empty string the current state is read
-- and handed to 'k' as the counter; the computation 'k' returns is then run
-- from that same state.
countOccurrence :: String -> StateT OccurrenceState Identity OccurrenceValue
countOccurrence [] = StateT $ \s -> Identity (s, s)
countOccurrence (x:xs) = StateT $ \s -> runStateT (k x xs s) s

-- | Single counting step: ignores the incoming state and replaces it with
-- @counter + 1@ when the character is @\'s\'@, or with @counter@ unchanged
-- otherwise. The result value is always @()@.
--
-- The state type must be 'OccurrenceState': the new state is built from
-- @counter@, so it cannot be an arbitrary type variable.
m1 :: Char -> OccurrenceState -> StateT OccurrenceState Identity ()
m1 x counter = case ('s' == x) of
          True -> StateT $ \_ -> Identity ((), counter+1)
          False -> StateT $ \_ -> Identity ((), counter)

-- | Desugared bind: run 'm1' for the current character, then continue with
-- 'countOccurrence' on the rest of the string, starting from the state that
-- 'm1' produced.
--
-- The @()@ result of 'm1' is not needed, hence the underscore-prefixed name.
k :: Char -> String -> OccurrenceState -> StateT OccurrenceState Identity OccurrenceValue
k x xs counter = StateT $ \s ->
    (\(_a, s') -> runStateT (countOccurrence xs) s') $
                    runIdentity $ runStateT (m1 x counter) s


-- | Run 'countOccurrence' on the string @"s"@ with initial state 0 and print
-- the result component of the returned (result, state) pair.
main :: IO ()
main = print $ fst ((runIdentity . runStateT (countOccurrence "s")) 0)

Executing the example

Here we will execute the example to see with concrete values how the code executes.
What we can see is that every call to countOccurrence gives us back a StateT.
If we apply runStateT to that StateT, we get the function which is stored inside
it, and we can then apply that function to the “state” argument.

First Step: Executing (countOccurrence "s")

countOccurrence ('s': []) = StateT $ \s -> runStateT (k 's' [] s) s

Reducing k

k 's' [] s = StateT $ \s2 ->
  (\(a, s') -> runStateT (countOccurrence []) s') $
      runIdentity $ runStateT (m1 's' s) s2

-- 's' == x so choose the True Branch for m1
m1 's' s = StateT $ \_ -> Identity ((), s+1)

-- insert m1 inside k
k 's' [] s = StateT $ \s2 ->
  (\(a, s') -> runStateT (countOccurrence []) s') $
      runIdentity $ runStateT (StateT $ \_ -> Identity ((), s+1)) s2

-- Calc countOccurrence []
countOccurrence [] = StateT $ \s3 -> Identity (s3, s3)

-- insert countOccurrence [] in k
k 's' [] s = StateT $ \s2 ->
  (\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
      runIdentity $ runStateT (StateT $ \_ -> Identity ((), s+1)) s2







Insert k

countOccurrence ('s': []) = StateT $ \s -> runStateT (StateT $ \s2 ->
  (\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ runStateT (StateT $ \_ -> Identity ((), s+1)) s2
            ) s


main = print $ fst ((runIdentity . runStateT (countOccurrence "s")) 0)

Execute runStateT (countOccurrence "s") 0

-- insert countOccurrence
runStateT (StateT $ \s -> runStateT (StateT $ \s2 ->
  (\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ runStateT (StateT $ \_ -> Identity ((), s+1)) s2
            ) s) 0

-- apply runStateT
(\s -> runStateT (StateT $ \s2 ->
  (\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ runStateT (StateT $ \_ -> Identity ((), s+1)) s2
            ) s) 0

-- apply s=0
runStateT (StateT $ \s2 ->
  (\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ runStateT (StateT $ \_ -> Identity ((), 1)) s2
            ) 0

-- apply runStateT
(\s2 ->
  (\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ runStateT (StateT $ \_ -> Identity ((), 1)) s2
            ) 0

-- apply s2=0
(\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ runStateT (StateT $ \_ -> Identity ((), 1)) 0

-- apply runStateT
(\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ (\_ -> Identity ((), 1)) 0

-- execute (\_ -> Identity ((), 1)) 0
(\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    runIdentity $ Identity ((), 1)

-- apply runIdentity
(\(a, s') -> runStateT (StateT $ \s3 -> Identity (s3, s3)) s') $
    ((), 1)

-- apply ((), 1)
runStateT (StateT $ \s3 -> Identity (s3, s3)) 1 

-- apply runStateT
(\s3 -> Identity (s3, s3)) 1

-- apply 1 
Identity (1, 1)


Executing main

main = print $ fst (runIdentity (Identity (1,1)))

main = print $ fst (1, 1)

main = print 1

The End

This is the end of the series, I hope it gave you some insights.

If you turn on the optimization flag and look at the Core output, what we see
is that the compiler transforms countOccurrence into a function where the state
is passed explicitly as an argument to the function.

A simplified version of the output looks like this:

-- Simplified shape of the optimized function: the counter is an ordinary
-- second argument, and no StateT or Identity wrappers appear.
countOccurrence l counter = case l of
   -- End of input: the counter is returned as both components of the pair.
   [] -> (counter, counter)
   x:xs -> case x of
      -- An 's' recurses on the rest of the list with the counter incremented.
      's' -> countOccurrence xs (counter+1)
      -- Any other character recurses with the counter unchanged.
      _ -> countOccurrence xs counter

Leave a Reply

Your email address will not be published. Required fields are marked *