module Hasura.GraphQL.Execute.Subscription.TMap
  ( TMap,
    new,
    reset,
    null,
    lookup,
    insert,
    delete,
    toList,
    replace,
    union,
    filterWithKey,
    getMap,
  )
where

import Control.Concurrent.STM
import Data.HashMap.Strict qualified as Map
import Hasura.Prelude hiding (lookup, null, toList, union)

-- | A coarse-grained transactional map implemented by simply wrapping a
-- 'Map.HashMap' in a single 'TVar'.
--
-- Compared to "StmContainers.Map", this provides much faster iteration over
-- the elements, at the cost of significantly increased contention on writes:
-- the whole map lives behind one 'TVar', so any two transactions that write
-- to the same 'TMap' conflict with each other.
newtype TMap k v = TMap {unTMap :: TVar (Map.HashMap k v)}

-- | Create a new, empty 'TMap'.
new :: STM (TMap k v)
new = TMap <$> newTVar Map.empty

-- | Remove every entry from the map, leaving it empty.
reset :: TMap k v -> STM ()
reset mapTv = writeTVar (unTMap mapTv) Map.empty

-- | Check whether the map currently holds no entries.
null :: TMap k v -> STM Bool
null = fmap Map.null . readTVar . unTMap

-- | Look up the value stored under the given key, returning 'Nothing' if the
-- key is absent.
lookup :: (Eq k, Hashable k) => k -> TMap k v -> STM (Maybe v)
lookup k = fmap (Map.lookup k) . readTVar . unTMap

-- | Insert a value under the given key, replacing any existing value for it.
--
-- Note the argument order: the value comes /before/ the key, unlike
-- 'Map.insert'.
insert :: (Eq k, Hashable k) => v -> k -> TMap k v -> STM ()
-- The bang forces the value to WHNF up front; 'modifyTVar'' then forces the
-- updated map so no thunk is left inside the 'TVar'.
insert !v k mapTv = modifyTVar' (unTMap mapTv) $ Map.insert k v

-- | Remove the entry for the given key; the map is left unchanged if the key
-- is absent.
delete :: (Eq k, Hashable k) => k -> TMap k v -> STM ()
delete k mapTv = modifyTVar' (unTMap mapTv) $ Map.delete k

-- | Snapshot the map's entries as a list of key/value pairs. As with
-- 'Map.toList', the order of the pairs is unspecified.
toList :: TMap k v -> STM [(k, v)]
toList = fmap Map.toList . readTVar . unTMap

-- | Keep only the entries for which the predicate holds, dropping the rest.
-- The map is updated in place.
filterWithKey :: (k -> v -> Bool) -> TMap k v -> STM ()
filterWithKey f mapTV = modifyTVar' (unTMap mapTV) $ Map.filterWithKey f

-- | Overwrite the entire contents of the map with the given 'Map.HashMap'.
replace :: TMap k v -> Map.HashMap k v -> STM ()
-- A plain 'writeTVar' suffices: the previous contents are not needed, so
-- there is no reason to fetch them with 'swapTVar' and throw them away.
replace mapTV newMap = writeTVar (unTMap mapTV) newMap

-- | Build a /new/ 'TMap' holding the union of both maps' entries.
--
-- Neither input map is modified. When a key is present in both, the value
-- from the first (left) map wins, as with 'Map.union'.
union :: (Eq k, Hashable k) => TMap k v -> TMap k v -> STM (TMap k v)
union mapA mapB = do
  l <- readTVar $ unTMap mapA
  r <- readTVar $ unTMap mapB
  TMap <$> newTVar (Map.union l r)

-- | Read the current contents of the map as an immutable 'Map.HashMap'
-- snapshot.
getMap :: TMap k v -> STM (Map.HashMap k v)
getMap = readTVar . unTMap