{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Knot-tying monad transformer for recursive graph building.
--
-- Some operations, such as building a graph, are inherently self-recursive;
-- consider the following graph:
--
--    > a -> b
--    > b -> a
--
-- To represent such a graph in Haskell, we might use the following type:
--
--    > data Node = Node
--    >   { nodeName :: Text
--    >   , nodeNeighbours :: [Node]
--    >   }
--
-- To construct our trivial graph, we need @a@ to know about @b@ and @b@ to know
-- about @a@: this is fine as long as we can build them both at the same time:
--
--    > graph = [nodeA, nodeB]
--    >   where
--    >     nodeA = Node "a" [nodeB]
--    >     nodeB = Node "b" [nodeA]
--
-- But this falls apart as soon as building the nodes becomes more complicated,
-- for instance as soon as it becomes monadic. The following recurses forever,
-- since @buildA@ runs @buildB@, which runs @buildA@ again, and so on:
--
--    > graph = do
--    >   a <- buildA
--    >   b <- buildB
--    >   pure [a,b]
--    >   where
--    >     buildA = do
--    >       b <- buildB
--    >       pure $ Node "a" [b]
--    >     buildB = do
--    >       a <- buildA
--    >       pure $ Node "b" [a]
--
-- The non-monadic version works thanks to laziness, and there is a way to
-- recover this laziness in a monadic context: that is what 'MonadFix' is for
-- (see <https://wiki.haskell.org/MonadFix>).
--
-- However, 'MonadFix' is both powerful and unintuitive; the goal of this module
-- is to use its power, but to give it a more restricted interface, to make it
-- easier to use. Using 'CircularT', the graph above can be built monadically
-- like so:
--
--    > graph = runCircularT do
--    >   a <- buildA
--    >   b <- buildB
--    >   pure [a,b]
--    >   where
--    >     buildA = withCircular "a" do
--    >       b <- buildB
--    >       pure $ Node "a" [b]
--    >     buildB = withCircular "b" do
--    >       a <- buildA
--    >       pure $ Node "b" [a]
--
-- It allows each part of a recursive process to be given a name (the type of
-- which is of the user's choosing), and it automatically breaks cycles. The
-- only caveat is that we cannot violate temporal causality: if we attempt to
-- make a cache-building decision based on the value obtained from the cache,
-- then no amount of laziness can save us:
--
--    > broken = runCircularT go
--    >   where
--    >     go = withCircular () do
--    >       x <- go
--    >       pure $ if odd x then 1 else 0
--
-- `CircularT` is somewhat similar to `TardisT` from @Control.Monad.Tardis@ and
-- `SchemaT` from @Hasura.GraphQL.Parser.Monad@, but simpler than both.
module Control.Monad.Circular
  ( CircularT,
    runCircularT,
    withCircular,
  )
where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer.Strict
import Data.HashMap.Lazy (HashMap)
import Data.HashMap.Lazy qualified as Map
import Data.Hashable (Hashable)
import Prelude

-- | CircularT is implemented as a state monad containing a lazy HashMap.
--
-- We use this state both to determine whether we have already encountered a
-- given key and to track the associated result. We use laziness and MonadFix to
-- tie the knot for us (see 'withCircular').
--
-- - type @k@ is the type of cache key, to which a given action is associated.
-- - type @v@ is the values we wish to cache in our process.
-- - type @m@ is the underlying monad on which this transformer operates.
-- - type @a@ is the result of the computation.
--
-- The 'MonadError', 'MonadReader' and 'MonadWriter' instances are inherited
-- from the underlying 'StateT'. 'MonadState' is intentionally absent from the
-- list: deriving it would expose the internal cache. It is instead given a
-- hand-written instance that forwards to the underlying monad.
--
-- NOTE(review): deriving the mtl classes through the newtype requires
-- GeneralizedNewtypeDeriving, which is not among this file's LANGUAGE pragmas;
-- presumably it comes from the package's default extensions — confirm.
newtype CircularT k v m a = CircularT (StateT (HashMap k v) m a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadError e,
      MonadReader r,
      MonadWriter w
    )

-- | Lifting runs the underlying action through 'StateT''s own 'lift', so the
-- cache is passed through unchanged.
instance MonadTrans (CircularT k v) where
  lift = CircularT . lift

-- | Allow code in 'CircularT' to have access to any underlying state
-- capabilities, hiding the fact that 'CircularT' itself is a state monad.
--
-- Every method is lifted straight to @m@, so the internal cache can neither be
-- observed nor modified through this instance.
instance (MonadState s m) => MonadState s (CircularT k v m) where
  get = lift get
  put = lift . put

  -- Lift 'state' directly (as mtl's own lifting instances do) rather than
  -- falling back on the class default built from 'get' and 'put'. This way an
  -- underlying monad that implements 'state' as a single step keeps doing so.
  state = lift . state

-- | Runs a computation in 'CircularT', starting from an empty cache.
--
-- Only the result of the computation is returned; the final cache is
-- discarded.
runCircularT :: (Hashable k, MonadFix m) => CircularT k v m a -> m a
runCircularT (CircularT action) = evalStateT action mempty

-- | Cache a computation under a given key.
--
-- For a given key @k@, and a computation in 'CircularT' that yields a value of
-- type @v@, return an action that builds said value @v@ but that prevents
-- cycles by looking into and populating a stateful cache.
--
-- Any later request for the same key, including a recursive one made by the
-- action itself while it is running, is answered from the cache. That cached
-- value refers to the action's eventual result, so the action must not inspect
-- it in order to decide what to do (see the module documentation).
withCircular ::
  (Hashable k, MonadFix m) =>
  k ->
  CircularT k v m v ->
  CircularT k v m v
withCircular key (CircularT action) = CircularT do
  cache <- get
  case Map.lookup key cache of
    -- If the key is already present in the cache, that means we have
    -- already encountered that particular key in our process; no need to use
    -- the @action@.
    Just value -> pure value
    -- Otherwise, it means we haven't encountered it yet: we need to build it
    -- and cache the result.
    Nothing -> mdo
      -- Insert a thunk referencing the eventual actual value in the cache; we
      -- need the cache to be a lazy map for this to work.
      modify $ Map.insert key actualValue
      -- We compute the actual value by evaluating the action. This will only
      -- happen once per key. Note that we use 'actualValue' before it is built:
      -- this is why we need 'MonadFix' and "recursive do".
      actualValue <- action
      -- And we return the value!
      pure actualValue

-- We don't want to rely on Hasura.Prelude in "third-party" libraries.
{-# ANN withCircular ("HLint: ignore Use onNothing" :: String) #-}