{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE TemplateHaskell #-}

module Hasura.GraphQL.Transport.WebSocket.Server
  ( AcceptWith (AcceptWith),
    HasuraServerApp,
    MessageDetails (MessageDetails),
    MonadWSLog (..),
    OnConnH,
    WSActions (..),
    WSConn,
    WSEvent (EMessageSent),
    WSEventInfo (WSEventInfo, _wseiEventType, _wseiOperationId, _wseiOperationName, _wseiParameterizedQueryHash, _wseiQueryExecutionTime, _wseiResponseSize),
    WSHandlers (WSHandlers),
    WSId,
    WSKeepAliveMessageAction,
    WSLog (WSLog),
    WSOnErrorMessageAction,
    WSQueueResponse (WSQueueResponse),
    WSServer,
    closeConn,
    createServerApp,
    createWSServer,
    getData,
    getRawWebSocketConnection,
    getWSId,
    onClientMessageParseErrorText,
    onConnInitErrorText,
    sendMsg,
    shutdown,

    -- * exported for testing
    mkUnsafeWSId,
  )
where

import Control.Concurrent.Async qualified as A
import Control.Concurrent.Async.Lifted.Safe qualified as LA
import Control.Concurrent.STM qualified as STM
import Control.Exception.Lifted
import Control.Monad.Trans.Control qualified as MC
import Data.Aeson qualified as J
import Data.Aeson.Casing qualified as J
import Data.Aeson.TH qualified as J
import Data.ByteString.Char8 qualified as B
import Data.ByteString.Lazy qualified as BL
import Data.CaseInsensitive qualified as CI
import Data.SerializableBlob qualified as SB
import Data.String
import Data.UUID qualified as UUID
import Data.UUID.V4 qualified as UUID
import Data.Word (Word16)
import GHC.AssertNF.CPP
import GHC.Int (Int64)
import Hasura.GraphQL.ParameterizedQueryHash (ParameterizedQueryHash)
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.WebSocket.Protocol
import Hasura.Logging qualified as L
import Hasura.Prelude
import Hasura.RQL.Types.Numeric qualified as Numeric
import Hasura.Server.Init.Config (WSConnectionInitTimeout (..))
import ListT qualified
import Network.Wai.Extended (IpAddress)
import Network.WebSockets qualified as WS
import StmContainers.Map qualified as STMMap
import System.IO.Error qualified as E

-- | Identifier of a single websocket connection. 'createServerApp' assigns
-- each incoming connection a fresh random UUID wrapped in this newtype.
newtype WSId = WSId {unWSId :: UUID.UUID}
  deriving (Show, Eq, Hashable)

-- | Wrap an existing UUID as a 'WSId' instead of generating a fresh random
-- one. It is exported for testing only.
mkUnsafeWSId :: UUID.UUID -> WSId
mkUnsafeWSId uuid = WSId uuid

-- | A 'WSId' is encoded as the JSON string of its UUID's textual form.
instance J.ToJSON WSId where
  toJSON = J.toJSON . UUID.toText . unWSId

-- | A websocket message payload paired with its reported size. It is used as
-- the detail of the 'EMessageReceived' and 'EMessageSent' log events.
data MessageDetails = MessageDetails
  { _mdMessage :: !SB.SerializableBlob,
    _mdMessageSize :: !Int64
  }
  deriving (Show)

$(J.deriveToJSON hasuraJSON ''MessageDetails)

-- | Connection lifecycle events recorded in 'WSLog'.
--
-- JSON shape:
--
-- * the constructor name, minus its leading @E@ and snake_cased, goes under
--   the @"type"@ key;
-- * the constructor's payload, if it has one, goes under @"detail"@.
data WSEvent
  = EConnectionRequest
  | EAccepted
  | ERejected
  | EMessageReceived !MessageDetails
  | EMessageSent !MessageDetails
  | EJwtExpired
  | ECloseReceived
  | ECloseSent !SB.SerializableBlob
  | EClosed
  deriving (Show)

$( J.deriveToJSON
     J.defaultOptions
       { J.sumEncoding = J.TaggedObject "type" "detail",
         J.constructorTagModifier = J.snakeCase . drop 1
       }
     ''WSEvent
 )

-- | Optional extra information attached to a websocket log entry.
--
-- In JSON, each field name drops its 5-character @_wsei@ prefix and is
-- snake_cased (e.g. @_wseiEventType@ becomes @event_type@). Fields that are
-- 'Nothing' are omitted.
data WSEventInfo = WSEventInfo
  { _wseiEventType :: !(Maybe ServerMsgType),
    _wseiOperationId :: !(Maybe OperationId),
    _wseiOperationName :: !(Maybe OperationName),
    _wseiQueryExecutionTime :: !(Maybe Double),
    _wseiResponseSize :: !(Maybe Int64),
    _wseiParameterizedQueryHash :: !(Maybe ParameterizedQueryHash)
  }
  deriving (Show, Eq)

$( J.deriveToJSON
     J.defaultOptions
       { J.omitNothingFields = True,
         J.fieldLabelModifier = J.snakeCase . drop 5
       }
     ''WSEventInfo
 )

-- | One websocket-server log record: which connection, what happened, and
-- optional extra details.
--
-- In JSON, each field name drops its 4-character @_wsl@ prefix and is
-- snake_cased (@websocket_id@, @event@, @metadata@). A 'Nothing' metadata
-- field is omitted.
data WSLog = WSLog
  { _wslWebsocketId :: !WSId,
    _wslEvent :: !WSEvent,
    _wslMetadata :: !(Maybe WSEventInfo)
  }
  deriving (Show)

$( J.deriveToJSON
     J.defaultOptions
       { J.omitNothingFields = True,
         J.fieldLabelModifier = J.snakeCase . drop 4
       }
     ''WSLog
 )

-- | Monads able to emit websocket-server log records.
class Monad m => MonadWSLog m where
  -- | Emit a single 'WSLog' record through the given logger.
  logWSLog :: L.Logger L.Hasura -> WSLog -> m ()

-- | Logging in 'ExceptT' is delegated to the underlying monad.
instance MonadWSLog m => MonadWSLog (ExceptT e m) where
  logWSLog logger = lift . logWSLog logger

-- | Logging in 'ReaderT' is delegated to the underlying monad.
instance MonadWSLog m => MonadWSLog (ReaderT r m) where
  logWSLog logger = lift . logWSLog logger

-- | Every 'WSLog' is logged at debug level under the internal
-- websocket-server log type, with its JSON encoding as the payload.
instance L.ToEngineLog WSLog L.Hasura where
  toEngineLog wsLog = (level, logType, J.toJSON wsLog)
    where
      level = L.LevelDebug
      logType = L.ELTInternal L.ILTWsServer

-- | An outgoing message enqueued on a connection by 'sendMsg'.
data WSQueueResponse = WSQueueResponse
  { -- | The message bytes.
    _wsqrMessage :: !BL.ByteString,
    -- | Extra metadata used for other actions, such as logging. We don't
    -- want to include it in the websocket message payload.
    _wsqrEventInfo :: !(Maybe WSEventInfo)
  }

-- | Server-side handle for one accepted websocket connection. The type
-- parameter @a@ is caller-supplied per-connection data.
data WSConn a = WSConn
  { _wcConnId :: !WSId,
    _wcLogger :: !(L.Logger L.Hasura),
    _wcConnRaw :: !WS.Connection,
    -- | Outgoing messages. 'sendMsg' writes here instead of to the socket.
    _wcSendQ :: !(STM.TQueue WSQueueResponse),
    -- | Caller-supplied data (retrieved with 'getData').
    _wcExtraData :: !a
  }

-- | The underlying @websockets@ connection of a 'WSConn'.
getRawWebSocketConnection :: WSConn a -> WS.Connection
getRawWebSocketConnection conn = _wcConnRaw conn

-- | The caller-supplied data stored in a 'WSConn'.
getData :: WSConn a -> a
getData conn = _wcExtraData conn

-- | The identifier of a 'WSConn'.
getWSId :: WSConn a -> WSId
getWSId conn = _wcConnId conn

-- | Close a connection with status code 1000 ("normal close"), sending the
-- given reason bytes.
closeConn :: WSConn a -> BL.ByteString -> IO ()
closeConn wsConn reason = closeConnWithCode wsConn normalClose reason
  where
    normalClose :: Word16
    normalClose = 1000

-- | Close a connection with code 1012, which means "Server is restarting".
-- Well-behaved clients implement retry logic with a backoff of a few
-- seconds.
forceConnReconnect :: MonadIO m => WSConn a -> BL.ByteString -> m ()
forceConnReconnect wsConn reason =
  liftIO (closeConnWithCode wsConn serverRestarting reason)
  where
    serverRestarting :: Word16
    serverRestarting = 1012

-- | Close a connection with an explicit status code and reason.
--
-- An 'ECloseSent' event carrying the reason is logged through the
-- connection's logger first, and only then is the close code sent on the raw
-- connection.
closeConnWithCode :: WSConn a -> Word16 -> BL.ByteString -> IO ()
closeConnWithCode wsConn code reason = do
  let closeEvent = ECloseSent (SB.fromLBS reason)
      logEntry = WSLog (_wcConnId wsConn) closeEvent Nothing
  L.unLogger (_wcLogger wsConn) logEntry
  WS.sendCloseCode (_wcConnRaw wsConn) code reason

-- | Enqueue a response on the connection's send queue instead of writing to
-- the raw connection, so that 'sendMsg' doesn't block.
sendMsg :: WSConn a -> WSQueueResponse -> IO ()
sendMsg wsConn !resp = do
  $assertNFHere resp -- so we don't write thunks to mutable vars
  let queue = _wcSendQ wsConn
  STM.atomically (STM.writeTQueue queue resp)

-- | Live connections, keyed by their identifier.
type ConnMap a = STMMap.Map WSId (WSConn a)

-- | Server state. While accepting connections, the server holds the map of
-- live connections. Once shutting down, there is no map.
data ServerStatus a
  = AcceptingConns !(ConnMap a)
  | ShuttingDown

-- | A websocket server: a logger plus its mutable status.
data WSServer a = WSServer
  { _wssLogger :: !(L.Logger L.Hasura),
    -- | See e.g. createServerApp.onAccept for how we use STM to preserve consistency
    _wssStatus :: !(STM.TVar (ServerStatus a))
  }

-- | Create a server that starts out accepting connections, with an empty
-- connection map.
createWSServer :: L.Logger L.Hasura -> STM.STM (WSServer a)
createWSServer logger = do
  status <- STM.newTVar . AcceptingConns =<< STMMap.new
  pure (WSServer logger status)

-- | Run @closer msg@ concurrently on every given connection. The connection
-- ids are ignored, and the closers' results are discarded.
closeAllWith ::
  (BL.ByteString -> WSConn a -> IO ()) ->
  BL.ByteString ->
  [(WSId, WSConn a)] ->
  IO ()
closeAllWith closer msg conns =
  A.mapConcurrently_ (\(_, conn) -> closer msg conn) conns

-- | Empty the connection map in one STM transaction if the server is
-- accepting connections, and return the connections it held beforehand.
-- When the server is shutting down, nothing changes and @[]@ is returned.
flushConnMap :: STM.TVar (ServerStatus a) -> STM.STM [(WSId, WSConn a)]
flushConnMap serverStatus =
  STM.readTVar serverStatus >>= \case
    ShuttingDown -> pure []
    AcceptingConns connMap -> do
      existing <- ListT.toList (STMMap.listT connMap)
      STMMap.reset connMap
      pure existing

-- | What a connection handler returns to accept a connection.
data AcceptWith a = AcceptWith
  { -- | Per-connection data, later stored in the 'WSConn'.
    _awData :: !a,
    -- | Parameters passed to 'WS.acceptRequestWith'.
    _awReq :: !WS.AcceptRequest,
    _awKeepAlive :: !(WSConn a -> IO ()),
    _awOnJwtExpiry :: !(WSConn a -> IO ())
  }

-- | This set of functions (message handlers) is used by the server while
-- communicating with the client. They are particularly useful when the
-- messages sent to the client differ between the sub-protocol(s) the
-- server supports.
--
-- This one is the keep-alive action run against a connection.
type WSKeepAliveMessageAction a = WSConn a -> IO ()

-- | Action taking a connection, an operation id and a 'GQExecError'.
type WSPostExecErrMessageAction a = WSConn a -> OperationId -> GQExecError -> IO ()

-- | Action taking a connection, a 'ConnErrMsg' and optional text. The text
-- is presumably a prefix such as 'onConnInitErrorText' (TODO confirm).
type WSOnErrorMessageAction a = WSConn a -> ConnErrMsg -> Maybe String -> IO ()

-- | Action taking a connection, an operation id and a message.
type WSCloseConnAction a = WSConn a -> OperationId -> String -> IO ()

-- | Sub-protocol-specific actions used within the @onConn@ and @onMessage@
-- handlers.
data WSActions a = WSActions
  { _wsaPostExecErrMessageAction :: !(WSPostExecErrMessageAction a),
    _wsaOnErrorMessageAction :: !(WSOnErrorMessageAction a),
    _wsaConnectionCloseAction :: !(WSCloseConnAction a),
    -- | NOTE: the keep-alive action was made redundant because we need to send
    -- this message after the connection has been successfully established,
    -- i.e. after @connection_init@.
    _wsaKeepAliveAction :: !(WSKeepAliveMessageAction a),
    -- | Builds the 'ServerMsg' that carries a 'DataMsg'.
    _wsaGetDataMessageType :: !(DataMsg -> ServerMsg),
    _wsaAcceptRequest :: !WS.AcceptRequest,
    -- | Combines a list of JSON errors into a single JSON value.
    _wsaErrorMsgFormat :: !([J.Value] -> J.Value)
  }

-- | Error text for failed client-message parsing, intended to be passed to a
-- 'WSOnErrorMessageAction'.
onClientMessageParseErrorText :: Maybe String
onClientMessageParseErrorText = pure "Parsing client message failed: "

-- | Error text for a failed connection initialisation, intended to be passed
-- to a 'WSOnErrorMessageAction'.
onConnInitErrorText :: Maybe String
onConnInitErrorText = pure "Connection initialization failed: "

-- | Connection handler: given the new connection's id, request head, client
-- IP and actions, it either rejects the request or accepts it via
-- 'AcceptWith'.
type OnConnH m a =
  WSId ->
  WS.RequestHead ->
  IpAddress ->
  WSActions a ->
  m (Either WS.RejectRequest (AcceptWith a))

-- type OnMessageH m a = WSConn a -> BL.ByteString -> WSActions a -> m ()

-- | Handler invoked with a connection when it is closed.
type OnCloseH m a = WSConn a -> m ()

-- | aka generalized 'WS.ServerApp' over @m@, which takes an IPAddress
type HasuraServerApp m = IpAddress -> WS.PendingConnection -> m ()

-- | User-provided handlers for the websocket server.
--
-- NOTE: the types of '_hOnConn' and '_hOnMessage' were changed from
-- 'OnConnH' and @OnMessageH@ because the sub-protocol has to be passed to
-- these methods so that it eventually reaches 'OnConnH' and @OnMessageH@.
-- See 'createServerApp' for how these handlers are used.
data WSHandlers m a = WSHandlers
  { _hOnConn :: WSId -> WS.RequestHead -> IpAddress -> WSSubProtocol -> m (Either WS.RejectRequest (AcceptWith a)),
    _hOnMessage :: WSConn a -> BL.ByteString -> WSSubProtocol -> m (),
    _hOnClose :: OnCloseH m a
  }

createServerApp ::
  (MonadIO m, MC.MonadBaseControl IO m, LA.Forall (LA.Pure m), MonadWSLog m) =>
  WSConnectionInitTimeout ->
  WSServer a ->
  -- | user provided handlers
  WSHandlers m a ->
  -- | aka WS.ServerApp
  HasuraServerApp m
{-# INLINE createServerApp #-}
createServerApp :: WSConnectionInitTimeout
-> WSServer a -> WSHandlers m a -> HasuraServerApp m
createServerApp WSConnectionInitTimeout
wsConnInitTimeout (WSServer logger :: Logger Hasura
logger@(L.Logger forall a (m :: * -> *).
(ToEngineLog a Hasura, MonadIO m) =>
a -> m ()
writeLog) TVar (ServerStatus a)
serverStatus) WSHandlers m a
wsHandlers !IpAddress
ipAddress !PendingConnection
pendingConn = do
  WSId
wsId <- UUID -> WSId
WSId (UUID -> WSId) -> m UUID -> m WSId
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO UUID -> m UUID
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO IO UUID
UUID.nextRandom
  Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId WSEvent
EConnectionRequest Maybe WSEventInfo
forall a. Maybe a
Nothing
  -- NOTE: this timer is specific to `graphql-ws`. the server has to close the connection
  -- if the client doesn't send a `connection_init` message within the timeout period
  WSConnInitTimer
wsConnInitTimer <- IO WSConnInitTimer -> m WSConnInitTimer
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO WSConnInitTimer -> m WSConnInitTimer)
-> IO WSConnInitTimer -> m WSConnInitTimer
forall a b. (a -> b) -> a -> b
$ Seconds -> IO WSConnInitTimer
getNewWSTimer (NonNegative Seconds -> Seconds
forall a. NonNegative a -> a
Numeric.getNonNegative (NonNegative Seconds -> Seconds) -> NonNegative Seconds -> Seconds
forall a b. (a -> b) -> a -> b
$ WSConnectionInitTimeout -> NonNegative Seconds
unWSConnectionInitTimeout WSConnectionInitTimeout
wsConnInitTimeout)
  ServerStatus a
status <- IO (ServerStatus a) -> m (ServerStatus a)
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (ServerStatus a) -> m (ServerStatus a))
-> IO (ServerStatus a) -> m (ServerStatus a)
forall a b. (a -> b) -> a -> b
$ TVar (ServerStatus a) -> IO (ServerStatus a)
forall a. TVar a -> IO a
STM.readTVarIO TVar (ServerStatus a)
serverStatus
  case ServerStatus a
status of
    AcceptingConns ConnMap a
_ -> m () -> m ()
logUnexpectedExceptions (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
      Either RejectRequest (AcceptWith a)
onConnRes <- WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
connHandler WSId
wsId RequestHead
reqHead IpAddress
ipAddress WSSubProtocol
subProtocol
      (RejectRequest -> m ())
-> (AcceptWith a -> m ())
-> Either RejectRequest (AcceptWith a)
-> m ()
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either (WSId -> RejectRequest -> m ()
onReject WSId
wsId) (WSConnInitTimer -> WSId -> AcceptWith a -> m ()
onAccept WSConnInitTimer
wsConnInitTimer WSId
wsId) Either RejectRequest (AcceptWith a)
onConnRes
    ServerStatus a
ShuttingDown ->
      WSId -> RejectRequest -> m ()
onReject WSId
wsId RejectRequest
shuttingDownReject
  where
    reqHead :: RequestHead
reqHead = PendingConnection -> RequestHead
WS.pendingRequest PendingConnection
pendingConn

    getSubProtocolHeader :: RequestHead -> [(CI ByteString, ByteString)]
getSubProtocolHeader RequestHead
rhdrs =
      ((CI ByteString, ByteString) -> Bool)
-> [(CI ByteString, ByteString)] -> [(CI ByteString, ByteString)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\(CI ByteString
x, ByteString
_) -> CI ByteString
x CI ByteString -> CI ByteString -> Bool
forall a. Eq a => a -> a -> Bool
== (ByteString -> CI ByteString
forall s. FoldCase s => s -> CI s
CI.mk (ByteString -> CI ByteString)
-> (String -> ByteString) -> String -> CI ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> ByteString
B.pack (String -> CI ByteString) -> String -> CI ByteString
forall a b. (a -> b) -> a -> b
$ String
"Sec-WebSocket-Protocol")) ([(CI ByteString, ByteString)] -> [(CI ByteString, ByteString)])
-> [(CI ByteString, ByteString)] -> [(CI ByteString, ByteString)]
forall a b. (a -> b) -> a -> b
$ RequestHead -> [(CI ByteString, ByteString)]
WS.requestHeaders RequestHead
rhdrs

    subProtocol :: WSSubProtocol
subProtocol = case RequestHead -> [(CI ByteString, ByteString)]
getSubProtocolHeader RequestHead
reqHead of
      [(CI ByteString, ByteString)
sph] -> String -> WSSubProtocol
toWSSubProtocol (String -> WSSubProtocol)
-> ((CI ByteString, ByteString) -> String)
-> (CI ByteString, ByteString)
-> WSSubProtocol
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ByteString -> String
B.unpack (ByteString -> String)
-> ((CI ByteString, ByteString) -> ByteString)
-> (CI ByteString, ByteString)
-> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (CI ByteString, ByteString) -> ByteString
forall a b. (a, b) -> b
snd ((CI ByteString, ByteString) -> WSSubProtocol)
-> (CI ByteString, ByteString) -> WSSubProtocol
forall a b. (a -> b) -> a -> b
$ (CI ByteString, ByteString)
sph
      [(CI ByteString, ByteString)]
_ -> WSSubProtocol
Apollo -- NOTE: we default to the apollo implemenation
    connHandler :: WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
connHandler = WSHandlers m a
-> WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
forall (m :: * -> *) a.
WSHandlers m a
-> WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
_hOnConn WSHandlers m a
wsHandlers
    messageHandler :: WSConn a -> ByteString -> WSSubProtocol -> m ()
messageHandler = WSHandlers m a -> WSConn a -> ByteString -> WSSubProtocol -> m ()
forall (m :: * -> *) a.
WSHandlers m a -> WSConn a -> ByteString -> WSSubProtocol -> m ()
_hOnMessage WSHandlers m a
wsHandlers
    closeHandler :: OnCloseH m a
closeHandler = WSHandlers m a -> OnCloseH m a
forall (m :: * -> *) a. WSHandlers m a -> OnCloseH m a
_hOnClose WSHandlers m a
wsHandlers

    -- It's not clear what the unexpected exception handling story here should be. So at
    -- least log properly and re-raise:
    logUnexpectedExceptions :: m () -> m ()
logUnexpectedExceptions = (SomeException -> m ()) -> m () -> m ()
forall (m :: * -> *) e a.
(MonadBaseControl IO m, Exception e) =>
(e -> m a) -> m a -> m a
handle ((SomeException -> m ()) -> m () -> m ())
-> (SomeException -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ \(SomeException
e :: SomeException) -> do
      UnstructuredLog -> m ()
forall a (m :: * -> *).
(ToEngineLog a Hasura, MonadIO m) =>
a -> m ()
writeLog (UnstructuredLog -> m ()) -> UnstructuredLog -> m ()
forall a b. (a -> b) -> a -> b
$
        LogLevel -> SerializableBlob -> UnstructuredLog
L.UnstructuredLog LogLevel
L.LevelError (SerializableBlob -> UnstructuredLog)
-> SerializableBlob -> UnstructuredLog
forall a b. (a -> b) -> a -> b
$
          String -> SerializableBlob
forall a. IsString a => String -> a
fromString (String -> SerializableBlob) -> String -> SerializableBlob
forall a b. (a -> b) -> a -> b
$
            String
"Unexpected exception raised in websocket. Please report this as a bug: " String -> ShowS
forall a. Semigroup a => a -> a -> a
<> SomeException -> String
forall a. Show a => a -> String
show SomeException
e
      SomeException -> m ()
forall (m :: * -> *) e a. (MonadBase IO m, Exception e) => e -> m a
throwIO SomeException
e

    shuttingDownReject :: RejectRequest
shuttingDownReject =
      Int
-> ByteString
-> [(CI ByteString, ByteString)]
-> ByteString
-> RejectRequest
WS.RejectRequest
        Int
503
        ByteString
"Service Unavailable"
        [(CI ByteString
"Retry-After", ByteString
"0")]
        ByteString
"Server is shutting down"

    onReject :: WSId -> RejectRequest -> m ()
onReject WSId
wsId RejectRequest
rejectRequest = do
      IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ PendingConnection -> RejectRequest -> IO ()
WS.rejectRequestWith PendingConnection
pendingConn RejectRequest
rejectRequest
      Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId WSEvent
ERejected Maybe WSEventInfo
forall a. Maybe a
Nothing

    onAccept :: WSConnInitTimer -> WSId -> AcceptWith a -> m ()
onAccept WSConnInitTimer
wsConnInitTimer WSId
wsId (AcceptWith a
a AcceptRequest
acceptWithParams WSConn a -> IO ()
keepAlive WSConn a -> IO ()
onJwtExpiry) = do
      Connection
conn <- IO Connection -> m Connection
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO Connection -> m Connection) -> IO Connection -> m Connection
forall a b. (a -> b) -> a -> b
$ PendingConnection -> AcceptRequest -> IO Connection
WS.acceptRequestWith PendingConnection
pendingConn AcceptRequest
acceptWithParams
      Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId WSEvent
EAccepted Maybe WSEventInfo
forall a. Maybe a
Nothing
      TQueue WSQueueResponse
sendQ <- IO (TQueue WSQueueResponse) -> m (TQueue WSQueueResponse)
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO IO (TQueue WSQueueResponse)
forall a. IO (TQueue a)
STM.newTQueueIO
      let !wsConn :: WSConn a
wsConn = WSId
-> Logger Hasura
-> Connection
-> TQueue WSQueueResponse
-> a
-> WSConn a
forall a.
WSId
-> Logger Hasura
-> Connection
-> TQueue WSQueueResponse
-> a
-> WSConn a
WSConn WSId
wsId Logger Hasura
logger Connection
conn TQueue WSQueueResponse
sendQ a
a
      -- TODO there are many thunks here. Difficult to trace how much is retained, and
      --      how much of that would be shared anyway.
      --      Requires a fork of 'wai-websockets' and 'websockets', it looks like.
      --      Adding `package` stanzas with -Xstrict -XStrictData for those two packages
      --      helped, cutting the number of thunks approximately in half.
      IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ String
String -> WSConn a -> IO ()
forall a. String -> a -> IO ()
$assertNFHere WSConn a
wsConn -- so we don't write thunks to mutable vars
      let whenAcceptingInsertConn :: m (ServerStatus a)
whenAcceptingInsertConn = IO (ServerStatus a) -> m (ServerStatus a)
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (ServerStatus a) -> m (ServerStatus a))
-> IO (ServerStatus a) -> m (ServerStatus a)
forall a b. (a -> b) -> a -> b
$
            STM (ServerStatus a) -> IO (ServerStatus a)
forall a. STM a -> IO a
STM.atomically (STM (ServerStatus a) -> IO (ServerStatus a))
-> STM (ServerStatus a) -> IO (ServerStatus a)
forall a b. (a -> b) -> a -> b
$ do
              ServerStatus a
status <- TVar (ServerStatus a) -> STM (ServerStatus a)
forall a. TVar a -> STM a
STM.readTVar TVar (ServerStatus a)
serverStatus
              case ServerStatus a
status of
                ServerStatus a
ShuttingDown -> () -> STM ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
                AcceptingConns ConnMap a
connMap -> WSConn a -> WSId -> ConnMap a -> STM ()
forall key value.
(Eq key, Hashable key) =>
value -> key -> Map key value -> STM ()
STMMap.insert WSConn a
wsConn WSId
wsId ConnMap a
connMap
              ServerStatus a -> STM (ServerStatus a)
forall (m :: * -> *) a. Monad m => a -> m a
return ServerStatus a
status

      -- ensure we clean up connMap even if an unexpected exception is raised from our worker
      -- threads, or an async exception is raised somewhere in the body here:
      m (ServerStatus a)
-> (ServerStatus a -> m ()) -> (ServerStatus a -> m ()) -> m ()
forall (m :: * -> *) a b c.
MonadBaseControl IO m =>
m a -> (a -> m b) -> (a -> m c) -> m c
bracket
        m (ServerStatus a)
whenAcceptingInsertConn
        (WSConn a -> ServerStatus a -> m ()
onConnClose WSConn a
wsConn)
        ((ServerStatus a -> m ()) -> m ())
-> (ServerStatus a -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \case
          ServerStatus a
ShuttingDown -> do
            -- Bad luck, we were in the process of shutting the server down but a new
            -- connection was accepted. Let's just close it politely
            WSConn a -> ByteString -> m ()
forall (m :: * -> *) a. MonadIO m => WSConn a -> ByteString -> m ()
forceConnReconnect WSConn a
wsConn ByteString
"shutting server down"
            OnCloseH m a
closeHandler WSConn a
wsConn
          AcceptingConns ConnMap a
_ -> do
            let rcv :: m ()
rcv = m () -> m ()
forall (f :: * -> *) a b. Applicative f => f a -> f b
forever (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
                  -- Process all messages serially (important!), in a separate thread:
                  ByteString
msg <-
                    IO ByteString -> m ByteString
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO ByteString -> m ByteString) -> IO ByteString -> m ByteString
forall a b. (a -> b) -> a -> b
$
                      -- Re-throw "receiveloop: resource vanished (Connection reset by peer)" :
                      --   https://github.com/yesodweb/wai/blob/master/warp/Network/Wai/Handler/Warp/Recv.hs#L112
                      -- as WS exception signaling cleanup below. It's not clear why exactly this gets
                      -- raised occasionally; I suspect an equivalent handler is missing from WS itself.
                      -- Regardless this should be safe:
                      (IOError -> Maybe ())
-> (() -> IO ByteString) -> IO ByteString -> IO ByteString
forall (m :: * -> *) e b a.
(MonadBaseControl IO m, Exception e) =>
(e -> Maybe b) -> (b -> m a) -> m a -> m a
handleJust (Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> (IOError -> Bool) -> IOError -> Maybe ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IOError -> Bool
E.isResourceVanishedError) (\() -> ConnectionException -> IO ByteString
forall a e. Exception e => e -> a
throw ConnectionException
WS.ConnectionClosed) (IO ByteString -> IO ByteString) -> IO ByteString -> IO ByteString
forall a b. (a -> b) -> a -> b
$
                        Connection -> IO ByteString
forall a. WebSocketsData a => Connection -> IO a
WS.receiveData Connection
conn
                  let message :: MessageDetails
message = SerializableBlob -> Int64 -> MessageDetails
MessageDetails (ByteString -> SerializableBlob
SB.fromLBS ByteString
msg) (ByteString -> Int64
BL.length ByteString
msg)
                  Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId (MessageDetails -> WSEvent
EMessageReceived MessageDetails
message) Maybe WSEventInfo
forall a. Maybe a
Nothing
                  WSConn a -> ByteString -> WSSubProtocol -> m ()
messageHandler WSConn a
wsConn ByteString
msg WSSubProtocol
subProtocol

            let send :: m ()
send = m () -> m ()
forall (f :: * -> *) a b. Applicative f => f a -> f b
forever (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
                  WSQueueResponse ByteString
msg Maybe WSEventInfo
wsInfo <- IO WSQueueResponse -> m WSQueueResponse
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO WSQueueResponse -> m WSQueueResponse)
-> IO WSQueueResponse -> m WSQueueResponse
forall a b. (a -> b) -> a -> b
$ STM WSQueueResponse -> IO WSQueueResponse
forall a. STM a -> IO a
STM.atomically (STM WSQueueResponse -> IO WSQueueResponse)
-> STM WSQueueResponse -> IO WSQueueResponse
forall a b. (a -> b) -> a -> b
$ TQueue WSQueueResponse -> STM WSQueueResponse
forall a. TQueue a -> STM a
STM.readTQueue TQueue WSQueueResponse
sendQ
                  let message :: MessageDetails
message = SerializableBlob -> Int64 -> MessageDetails
MessageDetails (ByteString -> SerializableBlob
SB.fromLBS ByteString
msg) (ByteString -> Int64
BL.length ByteString
msg)
                  IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ Connection -> ByteString -> IO ()
forall a. WebSocketsData a => Connection -> a -> IO ()
WS.sendTextData Connection
conn ByteString
msg
                  Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId (MessageDetails -> WSEvent
EMessageSent MessageDetails
message) Maybe WSEventInfo
wsInfo

            -- withAsync lets us be very sure that if e.g. an async exception is raised while we're
            -- forking that the threads we launched will be cleaned up. See also below.
            m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync m ()
rcv ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
rcvRef -> do
              m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync m ()
send ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
sendRef -> do
                m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync (IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ WSConn a -> IO ()
keepAlive WSConn a
wsConn) ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
keepAliveRef -> do
                  m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync (IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ WSConn a -> IO ()
onJwtExpiry WSConn a
wsConn) ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
onJwtExpiryRef -> do
                    -- once connection is accepted, check the status of the timer, and if it's expired, close the connection for `graphql-ws`
                    WSConnInitTimerStatus
timeoutStatus <- IO WSConnInitTimerStatus -> m WSConnInitTimerStatus
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO WSConnInitTimerStatus -> m WSConnInitTimerStatus)
-> IO WSConnInitTimerStatus -> m WSConnInitTimerStatus
forall a b. (a -> b) -> a -> b
$ WSConnInitTimer -> IO WSConnInitTimerStatus
getWSTimerState WSConnInitTimer
wsConnInitTimer
                    Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (WSConnInitTimerStatus
timeoutStatus WSConnInitTimerStatus -> WSConnInitTimerStatus -> Bool
forall a. Eq a => a -> a -> Bool
== WSConnInitTimerStatus
Done Bool -> Bool -> Bool
&& WSSubProtocol
subProtocol WSSubProtocol -> WSSubProtocol -> Bool
forall a. Eq a => a -> a -> Bool
== WSSubProtocol
GraphQLWS) (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$
                      IO () -> m ()
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ WSConn a -> Word16 -> ByteString -> IO ()
forall a. WSConn a -> Word16 -> ByteString -> IO ()
closeConnWithCode WSConn a
wsConn Word16
4408 ByteString
"Connection initialisation timed out"

                    -- terminates on WS.ConnectionException and JWT expiry
                    let waitOnRefs :: [Async ()]
waitOnRefs = [Async ()
keepAliveRef, Async ()
onJwtExpiryRef, Async ()
rcvRef, Async ()
sendRef]
                    -- waitAnyCancel re-raises exceptions from the forked threads, and is guaranteed to cancel them in
                    -- case of async exceptions raised while blocking here:
                    m (Async (), ()) -> m (Either ConnectionException (Async (), ()))
forall (m :: * -> *) e a.
(MonadBaseControl IO m, Exception e) =>
m a -> m (Either e a)
try ([Async ()] -> m (Async (), ())
forall (m :: * -> *) a.
(MonadBase IO m, Forall (Pure m)) =>
[Async a] -> m (Async a, a)
LA.waitAnyCancel [Async ()]
waitOnRefs) m (Either ConnectionException (Async (), ()))
-> (Either ConnectionException (Async (), ()) -> m ()) -> m ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
                      -- NOTE: 'websockets' is a bit of a rat's nest at the moment wrt
                      -- exceptions; for now handle all ConnectionException by closing
                      -- and cleaning up, see: https://github.com/jaspervdj/websockets/issues/48
                      Left (ConnectionException
_ :: WS.ConnectionException) -> do
                        Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog (WSConn a -> WSId
forall a. WSConn a -> WSId
_wcConnId WSConn a
wsConn) WSEvent
ECloseReceived Maybe WSEventInfo
forall a. Maybe a
Nothing
                      -- this will happen when jwt is expired
                      Right (Async (), ())
_ -> do
                        Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog (WSConn a -> WSId
forall a. WSConn a -> WSId
_wcConnId WSConn a
wsConn) WSEvent
EJwtExpired Maybe WSEventInfo
forall a. Maybe a
Nothing

    -- | Release action for the per-connection 'bracket'.
    --
    -- The 'ServerStatus' argument is the value returned by the acquire step
    -- ('whenAcceptingInsertConn'). This is the status observed when the
    -- connection was accepted, not the current one.
    --
    -- * 'ShuttingDown': the connection was never inserted into a connection
    --   map, so there is nothing to remove. The bracket body already invoked
    --   'closeHandler' in that branch.
    -- * 'AcceptingConns': remove the connection from the map it was inserted
    --   into, run 'closeHandler', and log an 'EClosed' event.
    onConnClose :: WSConn a -> ServerStatus a -> m ()
    onConnClose wsConn status = case status of
      ShuttingDown ->
        pure ()
      AcceptingConns connMap -> do
        let connId = _wcConnId wsConn
        -- unregister first, so the connection is no longer reachable through
        -- the map while the close handler runs
        liftIO . STM.atomically $ STMMap.delete connId connMap
        closeHandler wsConn
        logWSLog logger (WSLog connId EClosed Nothing)

-- | Stop the websocket server and disconnect every registered connection.
--
-- Draining the connection map and switching the status to 'ShuttingDown'
-- happen in one STM transaction. Any connection accepted afterwards therefore
-- sees 'ShuttingDown' and is not inserted into a map. Each connection that was
-- drained is then closed with 'forceConnReconnect', using the reason
-- @"shutting server down"@.
shutdown :: WSServer a -> IO ()
shutdown (WSServer (L.Logger writeLog) serverStatus) = do
  writeLog $ L.debugT "Shutting websockets server down"
  -- flushConnMap runs first; its result is kept, and the status write
  -- follows within the same transaction
  drained <-
    STM.atomically $
      flushConnMap serverStatus <* STM.writeTVar serverStatus ShuttingDown
  closeAllWith
    (\reason conn -> forceConnReconnect conn reason)
    "shutting server down"
    drained