{-# LANGUAGE NondecreasingIndentation #-}
{-# LANGUAGE TemplateHaskell #-}

module Hasura.GraphQL.Transport.WebSocket.Server
  ( AcceptWith (AcceptWith),
    HasuraServerApp,
    MessageDetails (MessageDetails),
    MonadWSLog (..),
    OnConnH,
    WSActions (..),
    WSConn,
    WSErrorMessage (..),
    WSEvent (EMessageSent),
    WSEventInfo (WSEventInfo, _wseiEventType, _wseiOperationId, _wseiOperationName, _wseiParameterizedQueryHash, _wseiQueryExecutionTime, _wseiResponseSize),
    WSHandlers (WSHandlers),
    WSId,
    WSKeepAliveMessageAction,
    WSLog (WSLog),
    WSOnErrorMessageAction,
    WSQueueResponse (WSQueueResponse),
    WSServer (..),
    websocketConnectionReaper,
    closeConn,
    sendMsgAndCloseConn,
    createServerApp,
    createWSServer,
    closeAllConnectionsWithReason,
    getData,
    getRawWebSocketConnection,
    getWSId,
    mkWSServerErrorCode,
    sendMsg,
    shutdown,

    -- * exported for testing
    mkUnsafeWSId,
  )
where

import Control.Concurrent.Async qualified as A
import Control.Concurrent.Async.Lifted.Safe qualified as LA
import Control.Concurrent.Extended (sleep)
import Control.Concurrent.STM (readTVarIO)
import Control.Concurrent.STM qualified as STM
import Control.Exception.Lifted
import Control.Monad.Trans.Control qualified as MC
import Data.Aeson qualified as J
import Data.Aeson.Casing qualified as J
import Data.Aeson.TH qualified as J
import Data.ByteString.Char8 qualified as B
import Data.ByteString.Lazy qualified as BL
import Data.CaseInsensitive qualified as CI
import Data.HashSet qualified as Set
import Data.SerializableBlob qualified as SB
import Data.String
import Data.Text qualified as T
import Data.UUID qualified as UUID
import Data.UUID.V4 qualified as UUID
import Data.Word (Word16)
import GHC.AssertNF.CPP
import GHC.Int (Int64)
import Hasura.GraphQL.ParameterizedQueryHash (ParameterizedQueryHash)
import Hasura.GraphQL.Schema.NamingCase (hasNamingConventionChanged)
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.GraphQL.Transport.WebSocket.Protocol
import Hasura.Logging qualified as L
import Hasura.Prelude
import Hasura.RQL.Types.Common (MetricsConfig (..), SQLGenCtx (..))
import Hasura.RQL.Types.NamingCase (NamingCase (..))
import Hasura.RQL.Types.SchemaCache
import Hasura.Server.Auth (AuthMode, compareAuthMode)
import Hasura.Server.Cors (CorsPolicy)
import Hasura.Server.Init.Config (AllowListStatus (..), WSConnectionInitTimeout (..))
import Hasura.Server.Prometheus
  ( PrometheusMetrics (..),
  )
import Hasura.Server.Types (ExperimentalFeature (..))
import ListT qualified
import Network.Wai.Extended (IpAddress)
import Network.Wai.Handler.Warp qualified as Warp
import Network.WebSockets qualified as WS
import Refined (unrefine)
import StmContainers.Map qualified as STMMap
import System.IO.Error qualified as E
import System.Metrics.Prometheus.Counter qualified as Prometheus.Counter
import System.Metrics.Prometheus.Histogram qualified as Prometheus.Histogram
import System.TimeManager qualified as TM

-- | Identifier used to key websocket connections (see 'ConnMap'), backed by a
-- UUID.
newtype WSId = WSId {unWSId :: UUID.UUID}
  deriving (Eq, Show, Hashable)

-- | Wrap an arbitrary, caller-supplied UUID as a 'WSId'. No uniqueness is
-- enforced, hence "unsafe"; it is exported for tests that need deterministic
-- identifiers.
mkUnsafeWSId :: UUID.UUID -> WSId
mkUnsafeWSId = WSId

-- | Encoded as a JSON string containing the text produced by 'UUID.toText'.
instance J.ToJSON WSId where
  toJSON = J.toJSON . UUID.toText . unWSId

-- | A websocket message payload paired with its size, as carried by the
-- 'EMessageReceived' / 'EMessageSent' log events.
data MessageDetails = MessageDetails
  { -- | The message body.
    _mdMessage :: !SB.SerializableBlob,
    -- | Size of the message as supplied by the caller (presumably bytes —
    -- confirm at call sites).
    _mdMessageSize :: !Int64
  }
  deriving (Show)

$(J.deriveToJSON hasuraJSON ''MessageDetails)

-- | Connection lifecycle events recorded in 'WSLog' entries.
--
-- JSON shape: the constructor name with its leading @E@ dropped and then
-- snake-cased goes under @"type"@; any payload goes under @"detail"@.
data WSEvent
  = EConnectionRequest
  | EAccepted
  | ERejected
  | EMessageReceived !MessageDetails
  | EMessageSent !MessageDetails
  | EJwtExpired
  | ECloseReceived
  | -- | Payload is the close reason that was sent.
    ECloseSent !SB.SerializableBlob
  | EClosed
  deriving (Show)

$( J.deriveToJSON
     J.defaultOptions
       { J.sumEncoding = J.TaggedObject "type" "detail",
         J.constructorTagModifier = J.snakeCase . drop 1
       }
     ''WSEvent
 )

-- | Optional extra metadata attached to a websocket log entry.
--
-- When serialised, the @_wsei@ prefix is stripped, the remainder is
-- snake-cased, and fields holding 'Nothing' are omitted.
data WSEventInfo = WSEventInfo
  { _wseiEventType :: !(Maybe ServerMsgType),
    _wseiOperationId :: !(Maybe OperationId),
    _wseiOperationName :: !(Maybe OperationName),
    -- | Unit is not fixed here (presumably seconds — confirm with callers).
    _wseiQueryExecutionTime :: !(Maybe Double),
    _wseiResponseSize :: !(Maybe Int64),
    _wseiParameterizedQueryHash :: !(Maybe ParameterizedQueryHash)
  }
  deriving (Eq, Show)

$( J.deriveToJSON
     J.defaultOptions
       { J.omitNothingFields = True,
         J.fieldLabelModifier = J.snakeCase . drop 5
       }
     ''WSEventInfo
 )

-- | A single websocket-server log entry: which connection, what happened, and
-- optional metadata. Serialised with the @_wsl@ prefix stripped and
-- snake-cased; a missing metadata field is omitted.
data WSLog = WSLog
  { _wslWebsocketId :: !WSId,
    _wslEvent :: !WSEvent,
    _wslMetadata :: !(Maybe WSEventInfo)
  }
  deriving (Show)

$( J.deriveToJSON
     J.defaultOptions
       { J.omitNothingFields = True,
         J.fieldLabelModifier = J.snakeCase . drop 4
       }
     ''WSLog
 )

-- | Monads that can emit websocket-server log entries.
class Monad m => MonadWSLog m where
  -- | Emit a 'WSLog' entry; the logger is passed in for implementations that
  -- write through it.
  logWSLog :: L.Logger L.Hasura -> WSLog -> m ()

-- | Logging passes straight through 'ExceptT' to the underlying monad.
instance MonadWSLog m => MonadWSLog (ExceptT e m) where
  logWSLog logger = lift . logWSLog logger

-- | Logging passes straight through 'ReaderT' to the underlying monad.
instance MonadWSLog m => MonadWSLog (ReaderT r m) where
  logWSLog logger = lift . logWSLog logger

-- | Websocket-server events are logged at debug level under the internal
-- 'L.ILTWsServer' log type, with the entry's JSON as the payload.
instance L.ToEngineLog WSLog L.Hasura where
  toEngineLog wsLog = (level, logType, J.toJSON wsLog)
    where
      level = L.LevelDebug
      logType = L.ELTInternal L.ILTWsServer

-- | Log entry written when all connections are closed for maintenance (see
-- 'closeAllConnectionsWithReason'), carrying a human-readable message.
data WSReaperThreadLog = WSReaperThreadLog
  { -- | Free-form description of the action being taken. The field is strict,
    -- like the fields of the other log records in this module, so the entry
    -- never holds an unevaluated message thunk.
    _wrtlMessage :: !Text
  }
  deriving (Show)

-- | Reaper messages are logged at info level under 'L.ILTWsServer'; the
-- payload is just the message text as a JSON string.
instance L.ToEngineLog WSReaperThreadLog L.Hasura where
  toEngineLog (WSReaperThreadLog msg) = (level, logType, J.toJSON msg)
    where
      level = L.LevelInfo
      logType = L.ELTInternal L.ILTWsServer

-- | A message waiting in a connection's send queue (enqueued by 'sendMsg').
data WSQueueResponse = WSQueueResponse
  { -- | The message payload.
    _wsqrMessage :: BL.ByteString,
    -- | Extra metadata used for auxiliary actions such as logging. It is
    -- deliberately kept out of the websocket message payload.
    _wsqrEventInfo :: Maybe WSEventInfo,
    -- | Timer used to measure how long the message stays queued.
    _wsqrTimer :: IO DiffTime
  }

-- | A websocket connection bundled with its identifier, logger, outgoing
-- queue and caller-defined data of type @a@.
data WSConn a = WSConn
  { _wcConnId :: !WSId,
    -- | Logger used for this connection's 'WSLog' entries.
    _wcLogger :: !(L.Logger L.Hasura),
    -- | The underlying @websockets@ connection.
    _wcConnRaw :: !WS.Connection,
    -- | Outgoing queue; 'sendMsg' writes here rather than to the socket.
    _wcSendQ :: !(STM.TQueue WSQueueResponse),
    _wcExtraData :: !a
  }

-- | The underlying @websockets@ connection.
getRawWebSocketConnection :: WSConn a -> WS.Connection
getRawWebSocketConnection WSConn {_wcConnRaw = rawConn} = rawConn

-- | The caller-defined data stored with the connection.
getData :: WSConn a -> a
getData WSConn {_wcExtraData = extra} = extra

-- | The connection's identifier.
getWSId :: WSConn a -> WSId
getWSId WSConn {_wcConnId = connId} = connId

-- | Close the connection with status code 1000 ("normal close"), logging the
-- close and sending the given reason in the close frame.
closeConn :: WSConn a -> BL.ByteString -> IO ()
closeConn wsConn reason = closeConnWithCode wsConn normalClose reason
  where
    normalClose :: Word16
    normalClose = 1000

-- | Close the connection with status code 1012 ("server is restarting").
-- Well-behaved clients are expected to retry after a backoff of a few
-- seconds.
forceConnReconnect :: (MonadIO m) => WSConn a -> BL.ByteString -> m ()
forceConnReconnect wsConn reason =
  liftIO (closeConnWithCode wsConn serverRestarting reason)
  where
    serverRestarting :: Word16
    serverRestarting = 1012

-- | Log an 'ECloseSent' event for the connection, then send a close frame
-- with the given status code and reason.
closeConnWithCode :: WSConn a -> Word16 -> BL.ByteString -> IO ()
closeConnWithCode wsConn code reason = do
  L.unLogger (_wcLogger wsConn) closeSentLog
  WS.sendCloseCode (_wcConnRaw wsConn) code reason
  where
    closeSentLog =
      WSLog (_wcConnId wsConn) (ECloseSent (SB.fromLBS reason)) Nothing

-- | Write @serverErr@ directly to the raw connection (bypassing the send
-- queue), then send a close frame with @errCode@ and the given reason.
sendMsgAndCloseConn :: WSConn a -> Word16 -> BL.ByteString -> ServerMsg -> IO ()
sendMsgAndCloseConn wsConn errCode reason serverErr = do
  let rawConn = _wcConnRaw wsConn
  WS.sendTextData rawConn (encodeServerMsg serverErr)
  WS.sendCloseCode rawConn errCode reason

-- | Enqueue a response on the connection's send queue rather than writing to
-- the raw connection, so that the caller does not block on the socket.
sendMsg :: WSConn a -> WSQueueResponse -> IO ()
sendMsg wsConn !resp = do
  -- The queue is shared mutable state; make sure no thunks end up in it.
  $assertNFHere resp
  let sendQueue = _wcSendQ wsConn
  STM.atomically (STM.writeTQueue sendQueue resp)

-- | Open connections, keyed by their identifier.
type ConnMap a = STMMap.Map WSId (WSConn a)

-- | Server lifecycle state. A connection map exists only while the server is
-- accepting connections.
data ServerStatus a
  = AcceptingConns !(ConnMap a)
  | ShuttingDown

data WSServer a = WSServer
  { _wssLogger :: L.Logger L.Hasura,
    -- | Security-sensitive user configuration, tracked so that maintenance
    -- actions can be performed when it changes.
    _wssSecuritySensitiveUserConfig :: STM.TVar SecuritySensitiveUserConfig,
    -- | Current status. See e.g. @createServerApp.onAccept@ for how STM is
    -- used to keep it consistent.
    _wssStatus :: STM.TVar (ServerStatus a)
  }

-- | Security-sensitive user configuration: if any of these settings changes,
-- maintenance actions such as closing all websocket connections must be
-- performed.
data SecuritySensitiveUserConfig = SecuritySensitiveUserConfig
  { ssucAuthMode :: AuthMode,
    ssucEnableAllowlist :: AllowListStatus,
    ssucAllowlist :: InlinedAllowlist,
    ssucCorsPolicy :: CorsPolicy,
    ssucSQLGenCtx :: SQLGenCtx,
    ssucExperimentalFeatures :: Set.HashSet ExperimentalFeature,
    ssucDefaultNamingCase :: NamingCase
  }

-- | Build a fresh 'WSServer' that accepts connections, with an empty
-- connection map and the given security-sensitive configuration as its
-- initial snapshot.
createWSServer ::
  AuthMode ->
  AllowListStatus ->
  InlinedAllowlist ->
  CorsPolicy ->
  SQLGenCtx ->
  Set.HashSet ExperimentalFeature ->
  NamingCase ->
  L.Logger L.Hasura ->
  STM.STM (WSServer a)
createWSServer authMode enableAllowlist allowlist corsPolicy sqlGenCtx experimentalFeatures defaultNamingCase logger = do
  connMap <- STMMap.new
  userConfRef <- STM.newTVar initialUserConf
  serverStatus <- STM.newTVar (AcceptingConns connMap)
  pure (WSServer logger userConfRef serverStatus)
  where
    initialUserConf =
      SecuritySensitiveUserConfig
        authMode
        enableAllowlist
        allowlist
        corsPolicy
        sqlGenCtx
        experimentalFeatures
        defaultNamingCase

-- | Apply @closer msg@ to every connection in the list concurrently and wait
-- for all of them to finish. The 'WSId' keys are ignored.
--
-- Uses 'A.mapConcurrently_' instead of @void . A.mapConcurrently@, so no
-- result list of @()@ values is built only to be discarded. As with
-- 'A.mapConcurrently', an exception from any closer cancels the others and is
-- re-thrown.
closeAllWith ::
  (BL.ByteString -> WSConn a -> IO ()) ->
  BL.ByteString ->
  [(WSId, WSConn a)] ->
  IO ()
closeAllWith closer msg conns =
  A.mapConcurrently_ (closer msg . snd) conns

-- | Log @logMsg@, then in a single STM transaction apply @updateConf@ to the
-- stored security-sensitive config and take (and empty) the connection map.
-- Finally, ask every removed connection to reconnect (close code 1012),
-- sending @reason@.
closeAllConnectionsWithReason ::
  WSServer a ->
  String ->
  BL.ByteString ->
  (SecuritySensitiveUserConfig -> SecuritySensitiveUserConfig) ->
  IO ()
closeAllConnectionsWithReason (WSServer (L.Logger writeLog) userConfRef serverStatus) logMsg reason updateConf = do
  writeLog (WSReaperThreadLog (fromString logMsg))
  flushedConns <- STM.atomically updateAndFlush
  closeAllWith (\why conn -> forceConnReconnect conn why) reason flushedConns
  where
    -- Updating the config and flushing the map happen atomically together.
    updateAndFlush = do
      STM.modifyTVar' userConfRef updateConf
      flushConnMap serverStatus

-- | If the server is accepting connections, empty its connection map and
-- return the entries it held before the reset. When shutting down, return
-- an empty list.
flushConnMap :: STM.TVar (ServerStatus a) -> STM.STM [(WSId, WSConn a)]
flushConnMap serverStatus =
  STM.readTVar serverStatus >>= \status -> case status of
    ShuttingDown -> pure []
    AcceptingConns connMap -> do
      snapshot <- ListT.toList (STMMap.listT connMap)
      STMMap.reset connMap
      pure snapshot

-- | What a connection handler returns when it accepts a connection: the
-- per-connection data plus the accept request and per-connection actions.
data AcceptWith a = AcceptWith
  { -- | Caller-defined data to attach to the new 'WSConn'.
    _awData :: !a,
    _awReq :: !WS.AcceptRequest,
    _awKeepAlive :: !(WSConn a -> IO ()),
    _awOnJwtExpiry :: !(WSConn a -> IO ())
  }

-- | Message handlers the server uses while talking to a client. They exist
--   mainly because the messages sent to the client differ between the
--   sub-protocols the server supports.
type WSKeepAliveMessageAction ctx = WSConn ctx -> IO ()

type WSPostExecErrMessageAction ctx = WSConn ctx -> OperationId -> GQExecError -> IO ()

type WSOnErrorMessageAction ctx = WSConn ctx -> ConnErrMsg -> WSErrorMessage -> IO ()

type WSCloseConnAction ctx = WSConn ctx -> OperationId -> String -> IO ()

-- | Sub-protocol-specific actions used inside the @onConn@ and @onMessage@
-- handlers.
data WSActions a = WSActions
  { _wsaPostExecErrMessageAction :: !(WSPostExecErrMessageAction a),
    _wsaOnErrorMessageAction :: !(WSOnErrorMessageAction a),
    _wsaConnectionCloseAction :: !(WSCloseConnAction a),
    -- | NOTE: the keep-alive action was made redundant because this message
    -- must be sent only after the connection has been successfully
    -- established following @connection_init@.
    _wsaKeepAliveAction :: !(WSKeepAliveMessageAction a),
    _wsaGetDataMessageType :: !(DataMsg -> ServerMsg),
    _wsaAcceptRequest :: !WS.AcceptRequest,
    _wsaErrorMsgFormat :: !([J.Encoding] -> J.Encoding)
  }

-- | Failure categories that 'mkWSServerErrorCode' maps to server error codes.
data WSErrorMessage
  = ClientMessageParseFailed
  | ConnInitFailed

-- | Choose the server error code for a failure.
--
-- Parse failures always map to 'GenericError4400'. Connection-init failures
-- map to 'GenericError4400' under 'Apollo' and to 'Forbidden4403' under
-- 'GraphQLWS'. Generic errors carry a prefix followed by the connection error
-- text.
mkWSServerErrorCode :: WSSubProtocol -> WSErrorMessage -> ConnErrMsg -> ServerErrorCode
mkWSServerErrorCode subProtocol errorMessage connErrMsg =
  case errorMessage of
    ClientMessageParseFailed ->
      withDetail "Parsing client message failed: "
    ConnInitFailed -> case subProtocol of
      Apollo -> withDetail "Connection initialization failed: "
      GraphQLWS -> Forbidden4403
  where
    withDetail prefix =
      GenericError4400 (prefix <> T.unpack (unConnErrMsg connErrMsg))

-- | Connection handler: given the new connection's id, request head, client
-- IP and actions, either reject the request or accept it.
type OnConnH m ctx = WSId -> WS.RequestHead -> IpAddress -> WSActions ctx -> m (Either WS.RejectRequest (AcceptWith ctx))

-- type OnMessageH m a = WSConn a -> BL.ByteString -> WSActions a -> m ()

-- | Handler run for a connection on close.
type OnCloseH m ctx = WSConn ctx -> m ()

-- | A generalised 'WS.ServerApp' over @m@ that also receives the client's
-- 'IpAddress'.
type HasuraServerApp m = IpAddress -> WS.PendingConnection -> m ()

-- | User-supplied callbacks that drive a websocket server.
--
-- NOTE: The types of `_hOnConn` and `_hOnMessage` were updated from `OnConnH` and `OnMessageH`
-- because we needed to pass the subprotocol here to these methods to eventually get to `OnConnH` and `OnMessageH`.
-- Please see `createServerApp` to get a better understanding of how these handlers are used.
data WSHandlers m a = WSHandlers
  { -- | Decides whether to reject or accept a connection request, given the
    -- connection's 'WSId', the request head, the client's 'IpAddress' and the
    -- requested 'WSSubProtocol'.
    _hOnConn :: (WSId -> WS.RequestHead -> IpAddress -> WSSubProtocol -> m (Either WS.RejectRequest (AcceptWith a))),
    -- | Handles a raw message received on an accepted connection
    -- (presumably invoked once per incoming message; see `createServerApp`).
    _hOnMessage :: (WSConn a -> BL.ByteString -> WSSubProtocol -> m ()),
    -- | Connection teardown callback (see `createServerApp` for when it runs).
    _hOnClose :: OnCloseH m a
  }

-- | The background thread responsible for closing all websocket connections
-- when security sensitive user configuration changes. It checks for changes in
-- the auth mode, allowlist, cors config, stringify num, dangerous boolean collapse,
-- stringify big query numeric, experimental features and invalidates/closes all
-- connections if there are any changes.
--
-- On each iteration it reads the latest configuration (@getLatestConfig@) and the
-- schema cache's allowlist, compares them with the last recorded values held in
-- the server's 'SecuritySensitiveUserConfig' TVar, acts on at most one detected
-- change, and then sleeps for one second. Each branch only updates the field(s)
-- it reacted to, so any other pending changes are picked up on later iterations.
websocketConnectionReaper :: IO (AuthMode, AllowListStatus, CorsPolicy, SQLGenCtx, Set.HashSet ExperimentalFeature, NamingCase) -> IO SchemaCache -> WSServer a -> IO Void
websocketConnectionReaper getLatestConfig getSchemaCache ws@(WSServer _ userConfRef _) =
  forever $ do
    (currAuthMode, currEnableAllowlist, currCorsPolicy, currSqlGenCtx, currExperimentalFeatures, currDefaultNamingCase) <- getLatestConfig
    currAllowlist <- scAllowlist <$> getSchemaCache
    SecuritySensitiveUserConfig prevAuthMode prevEnableAllowlist prevAllowlist prevCorsPolicy prevSqlGenCtx prevExperimentalFeatures prevDefaultNamingCase <- readTVarIO userConfRef
    -- check and close all connections if required
    checkAndReapConnections
      (currAuthMode, prevAuthMode)
      (currCorsPolicy, prevCorsPolicy)
      (currEnableAllowlist, prevEnableAllowlist)
      (currAllowlist, prevAllowlist)
      (currSqlGenCtx, prevSqlGenCtx)
      (currExperimentalFeatures, prevExperimentalFeatures)
      (currDefaultNamingCase, prevDefaultNamingCase)
    sleep $ seconds 1
  where
    -- Close all connections based on -
    -- if CorsPolicy changed -> close
    -- if AuthMode changed -> close
    -- if AllowlistEnabled -> enabled from disabled -> close
    -- if AllowlistEnabled -> allowlist collection changed -> close
    -- if Allowlist disabled from enabled -> only record the new status (nothing to close)
    -- if HASURA_GRAPHQL_STRINGIFY_NUMERIC_TYPES  changed -> close
    -- if HASURA_GRAPHQL_V1_BOOLEAN_NULL_COLLAPSE changed -> close
    -- if 'bigquery_string_numeric_input', 'hide_aggregation_predicates', 'hide_stream_fields' values added/removed from experimental features -> close
    -- if naming convention changes -> close
    checkAndReapConnections
      (currAuthMode, prevAuthMode)
      (currCorsPolicy, prevCorsPolicy)
      (currEnableAllowlist, prevEnableAllowlist)
      (currAllowlist, prevAllowlist)
      (currSqlGenCtx, prevSqlGenCtx)
      (currExperimentalFeatures, prevExperimentalFeatures)
      (currDefaultNamingCase, prevDefaultNamingCase) = do
        hasAuthModeChanged <- not <$> compareAuthMode currAuthMode prevAuthMode
        let hasCorsPolicyChanged = currCorsPolicy /= prevCorsPolicy
            hasAllowlistEnabled =
              prevEnableAllowlist == AllowListDisabled && currEnableAllowlist == AllowListEnabled
            hasAllowlistDisabled =
              prevEnableAllowlist == AllowListEnabled && currEnableAllowlist == AllowListDisabled
            hasAllowlistUpdated =
              prevEnableAllowlist == AllowListEnabled
                && currEnableAllowlist == AllowListEnabled
                && currAllowlist /= prevAllowlist
            hasStringifyNumChanged =
              stringifyNum currSqlGenCtx /= stringifyNum prevSqlGenCtx
            hasDangerousBooleanCollapseChanged =
              dangerousBooleanCollapse currSqlGenCtx /= dangerousBooleanCollapse prevSqlGenCtx
            -- The bigqueryStringNumericInput of SQLGenCtx is built from the experimentalFeature, hence no need to check for this field
            -- in experimentalFeatures again.
            hasBigqueryStringNumericInputChanged =
              bigqueryStringNumericInput currSqlGenCtx /= bigqueryStringNumericInput prevSqlGenCtx
            -- A feature flag has changed when it is present in exactly one of the
            -- two sets (i.e. it was added or removed). Requiring it in *both* sets
            -- would miss every toggle and, while the flag stays enabled, reap all
            -- connections on every iteration.
            hasExperimentalFeatureToggled feature =
              Set.member feature currExperimentalFeatures /= Set.member feature prevExperimentalFeatures
            hasHideAggregationPredicatesChanged = hasExperimentalFeatureToggled EFHideAggregationPredicates
            hasHideStreamFieldsChanged = hasExperimentalFeatureToggled EFHideStreamFields
            hasDefaultNamingCaseChanged =
              hasNamingConventionChanged
                (prevExperimentalFeatures, prevDefaultNamingCase)
                (currExperimentalFeatures, currDefaultNamingCase)
        if
          -- if CORS policy has changed, close all connections
          | hasCorsPolicyChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the cors policy changed"
                "cors policy changed"
                (\conf -> conf {ssucCorsPolicy = currCorsPolicy})
          -- if any auth config has changed, close all connections
          | hasAuthModeChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the auth mode changed"
                "auth mode changed"
                (\conf -> conf {ssucAuthMode = currAuthMode})
          -- In case of allowlist, we need to check if the allowlist has changed.
          -- If the allowlist is disabled, we keep all the connections as is.
          -- If the allowlist is enabled from a disabled state, we need to close all the
          -- connections. The stored allowlist is refreshed as well: it is not tracked
          -- while the allowlist is disabled, so leaving it stale would trigger a
          -- second, redundant reap on the next iteration.
          | hasAllowlistEnabled ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as allow list is enabled"
                "allow list enabled"
                (\conf -> conf {ssucEnableAllowlist = currEnableAllowlist, ssucAllowlist = currAllowlist})
          -- If the allowlist is already enabled and there are any changes made to the
          -- allowlist, we need to close all the connections.
          | hasAllowlistUpdated ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the allow list has been updated"
                "allow list updated"
                (\conf -> conf {ssucAllowlist = currAllowlist})
          -- Disabling the allowlist only relaxes restrictions, so connections are kept.
          -- The new status must still be recorded; otherwise a later re-enable would
          -- not be recognised as a disabled -> enabled transition above.
          | hasAllowlistDisabled ->
              STM.atomically $
                STM.modifyTVar' userConfRef (\conf -> conf {ssucEnableAllowlist = currEnableAllowlist})
          -- if HASURA_GRAPHQL_STRINGIFY_NUMERIC_TYPES has changed, close all connections
          | hasStringifyNumChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the HASURA_GRAPHQL_STRINGIFY_NUMERIC_TYPES setting changed"
                "HASURA_GRAPHQL_STRINGIFY_NUMERIC_TYPES env var changed"
                (\conf -> conf {ssucSQLGenCtx = currSqlGenCtx})
          -- if HASURA_GRAPHQL_V1_BOOLEAN_NULL_COLLAPSE has changed, close all connections
          | hasDangerousBooleanCollapseChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the HASURA_GRAPHQL_V1_BOOLEAN_NULL_COLLAPSE setting changed"
                "HASURA_GRAPHQL_V1_BOOLEAN_NULL_COLLAPSE env var changed"
                (\conf -> conf {ssucSQLGenCtx = currSqlGenCtx})
          -- if 'bigquery_string_numeric_input' option added/removed from experimental features, close all connections
          | hasBigqueryStringNumericInputChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the 'bigquery_string_numeric_input' option has been added/removed from HASURA_GRAPHQL_EXPERIMENTAL_FEATURES"
                "'bigquery_string_numeric_input' removed/added in HASURA_GRAPHQL_EXPERIMENTAL_FEATURES env var"
                (\conf -> conf {ssucSQLGenCtx = currSqlGenCtx})
          -- if 'hide_aggregation_predicates' option added/removed from experimental features, close all connections
          | hasHideAggregationPredicatesChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the 'hide-aggregation-predicates' option has been added/removed from HASURA_GRAPHQL_EXPERIMENTAL_FEATURES"
                "'hide-aggregation-predicates' removed/added in HASURA_GRAPHQL_EXPERIMENTAL_FEATURES env var"
                (\conf -> conf {ssucExperimentalFeatures = currExperimentalFeatures})
          -- if 'hide_stream_fields' option added/removed from experimental features, close all connections
          | hasHideStreamFieldsChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the 'hide-stream-fields' option has been added/removed from HASURA_GRAPHQL_EXPERIMENTAL_FEATURES"
                "'hide-stream-fields' removed/added in HASURA_GRAPHQL_EXPERIMENTAL_FEATURES env var"
                (\conf -> conf {ssucExperimentalFeatures = currExperimentalFeatures})
          -- if naming convention has been changed, close all connections
          | hasDefaultNamingCaseChanged ->
              closeAllConnectionsWithReason
                ws
                "closing all websocket connections as the 'naming_convention' option has been added/removed from HASURA_GRAPHQL_EXPERIMENTAL_FEATURES and the HASURA_GRAPHQL_DEFAULT_NAMING_CONVENTION has changed"
                "naming convention has been changed"
                (\conf -> conf {ssucExperimentalFeatures = currExperimentalFeatures, ssucDefaultNamingCase = currDefaultNamingCase})
          | otherwise -> pure ()

createServerApp ::
  (MonadIO m, MC.MonadBaseControl IO m, LA.Forall (LA.Pure m), MonadWSLog m) =>
  IO MetricsConfig ->
  WSConnectionInitTimeout ->
  WSServer a ->
  PrometheusMetrics ->
  -- | user provided handlers
  WSHandlers m a ->
  -- | aka WS.ServerApp
  HasuraServerApp m
{-# INLINE createServerApp #-}
createServerApp :: forall (m :: * -> *) a.
(MonadIO m, MonadBaseControl IO m, Forall (Pure m),
 MonadWSLog m) =>
IO MetricsConfig
-> WSConnectionInitTimeout
-> WSServer a
-> PrometheusMetrics
-> WSHandlers m a
-> HasuraServerApp m
createServerApp IO MetricsConfig
getMetricsConfig WSConnectionInitTimeout
wsConnInitTimeout (WSServer logger :: Logger Hasura
logger@(L.Logger forall a (m :: * -> *).
(ToEngineLog a Hasura, MonadIO m) =>
a -> m ()
writeLog) TVar SecuritySensitiveUserConfig
_ TVar (ServerStatus a)
serverStatus) PrometheusMetrics
prometheusMetrics WSHandlers m a
wsHandlers !IpAddress
ipAddress !PendingConnection
pendingConn = do
  WSId
wsId <- UUID -> WSId
WSId (UUID -> WSId) -> m UUID -> m WSId
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO UUID -> m UUID
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO IO UUID
UUID.nextRandom
  Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId WSEvent
EConnectionRequest Maybe WSEventInfo
forall a. Maybe a
Nothing
  -- NOTE: this timer is specific to `graphql-ws`. the server has to close the connection
  -- if the client doesn't send a `connection_init` message within the timeout period
  WSConnInitTimer
wsConnInitTimer <- IO WSConnInitTimer -> m WSConnInitTimer
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO WSConnInitTimer -> m WSConnInitTimer)
-> IO WSConnInitTimer -> m WSConnInitTimer
forall a b. (a -> b) -> a -> b
$ Seconds -> IO WSConnInitTimer
getNewWSTimer (Refined NonNegative Seconds -> Seconds
forall {k} (p :: k) x. Refined p x -> x
unrefine (Refined NonNegative Seconds -> Seconds)
-> Refined NonNegative Seconds -> Seconds
forall a b. (a -> b) -> a -> b
$ WSConnectionInitTimeout -> Refined NonNegative Seconds
unWSConnectionInitTimeout WSConnectionInitTimeout
wsConnInitTimeout)
  ServerStatus a
status <- IO (ServerStatus a) -> m (ServerStatus a)
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO (ServerStatus a) -> m (ServerStatus a))
-> IO (ServerStatus a) -> m (ServerStatus a)
forall a b. (a -> b) -> a -> b
$ TVar (ServerStatus a) -> IO (ServerStatus a)
forall a. TVar a -> IO a
STM.readTVarIO TVar (ServerStatus a)
serverStatus
  case ServerStatus a
status of
    AcceptingConns ConnMap a
_ -> m () -> m ()
logUnexpectedExceptions (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
      Either RejectRequest (AcceptWith a)
onConnRes <- WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
connHandler WSId
wsId RequestHead
reqHead IpAddress
ipAddress WSSubProtocol
subProtocol
      (RejectRequest -> m ())
-> (AcceptWith a -> m ())
-> Either RejectRequest (AcceptWith a)
-> m ()
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either (WSId -> RejectRequest -> m ()
onReject WSId
wsId) (WSConnInitTimer -> WSId -> AcceptWith a -> m ()
onAccept WSConnInitTimer
wsConnInitTimer WSId
wsId) Either RejectRequest (AcceptWith a)
onConnRes
    ServerStatus a
ShuttingDown ->
      WSId -> RejectRequest -> m ()
onReject WSId
wsId RejectRequest
shuttingDownReject
  where
    reqHead :: RequestHead
reqHead = PendingConnection -> RequestHead
WS.pendingRequest PendingConnection
pendingConn

    getSubProtocolHeader :: RequestHead -> [(CI ByteString, ByteString)]
getSubProtocolHeader RequestHead
rhdrs =
      ((CI ByteString, ByteString) -> Bool)
-> [(CI ByteString, ByteString)] -> [(CI ByteString, ByteString)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\(CI ByteString
x, ByteString
_) -> CI ByteString
x CI ByteString -> CI ByteString -> Bool
forall a. Eq a => a -> a -> Bool
== (ByteString -> CI ByteString
forall s. FoldCase s => s -> CI s
CI.mk (ByteString -> CI ByteString)
-> (String -> ByteString) -> String -> CI ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> ByteString
B.pack (String -> CI ByteString) -> String -> CI ByteString
forall a b. (a -> b) -> a -> b
$ String
"Sec-WebSocket-Protocol")) ([(CI ByteString, ByteString)] -> [(CI ByteString, ByteString)])
-> [(CI ByteString, ByteString)] -> [(CI ByteString, ByteString)]
forall a b. (a -> b) -> a -> b
$ RequestHead -> [(CI ByteString, ByteString)]
WS.requestHeaders RequestHead
rhdrs

    subProtocol :: WSSubProtocol
subProtocol = case RequestHead -> [(CI ByteString, ByteString)]
getSubProtocolHeader RequestHead
reqHead of
      [(CI ByteString, ByteString)
sph] -> String -> WSSubProtocol
toWSSubProtocol (String -> WSSubProtocol)
-> ((CI ByteString, ByteString) -> String)
-> (CI ByteString, ByteString)
-> WSSubProtocol
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ByteString -> String
B.unpack (ByteString -> String)
-> ((CI ByteString, ByteString) -> ByteString)
-> (CI ByteString, ByteString)
-> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (CI ByteString, ByteString) -> ByteString
forall a b. (a, b) -> b
snd ((CI ByteString, ByteString) -> WSSubProtocol)
-> (CI ByteString, ByteString) -> WSSubProtocol
forall a b. (a -> b) -> a -> b
$ (CI ByteString, ByteString)
sph
      [(CI ByteString, ByteString)]
_ -> WSSubProtocol
Apollo -- NOTE: we default to the apollo implemenation
    connHandler :: WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
connHandler = WSHandlers m a
-> WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
forall (m :: * -> *) a.
WSHandlers m a
-> WSId
-> RequestHead
-> IpAddress
-> WSSubProtocol
-> m (Either RejectRequest (AcceptWith a))
_hOnConn WSHandlers m a
wsHandlers
    messageHandler :: WSConn a -> ByteString -> WSSubProtocol -> m ()
messageHandler = WSHandlers m a -> WSConn a -> ByteString -> WSSubProtocol -> m ()
forall (m :: * -> *) a.
WSHandlers m a -> WSConn a -> ByteString -> WSSubProtocol -> m ()
_hOnMessage WSHandlers m a
wsHandlers
    closeHandler :: OnCloseH m a
closeHandler = WSHandlers m a -> OnCloseH m a
forall (m :: * -> *) a. WSHandlers m a -> OnCloseH m a
_hOnClose WSHandlers m a
wsHandlers

    logUnexpectedExceptions :: m () -> m ()
logUnexpectedExceptions = (m () -> [Handler m ()] -> m ()) -> [Handler m ()] -> m () -> m ()
forall a b c. (a -> b -> c) -> b -> a -> c
flip m () -> [Handler m ()] -> m ()
forall (m :: * -> *) a.
MonadBaseControl IO m =>
m a -> [Handler m a] -> m a
catches [Handler m ()]
handlers
      where
        handlers :: [Handler m ()]
handlers =
          [ -- this exception occurs under the normal course of the web server running. Also fairly common during shutdowns.
            -- Common suggestion is to gobble it.
            -- Refer: https://hackage.haskell.org/package/warp-3.3.24/docs/src/Network.Wai.Handler.Warp.Settings.html#defaultShouldDisplayException
            (TimeoutThread -> m ()) -> Handler m ()
forall (m :: * -> *) a e. Exception e => (e -> m a) -> Handler m a
Handler ((TimeoutThread -> m ()) -> Handler m ())
-> (TimeoutThread -> m ()) -> Handler m ()
forall a b. (a -> b) -> a -> b
$ \(TimeoutThread
_ :: TM.TimeoutThread) -> () -> m ()
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (),
            (InvalidRequest -> m ()) -> Handler m ()
forall (m :: * -> *) a e. Exception e => (e -> m a) -> Handler m a
Handler ((InvalidRequest -> m ()) -> Handler m ())
-> (InvalidRequest -> m ()) -> Handler m ()
forall a b. (a -> b) -> a -> b
$ \(InvalidRequest
e :: Warp.InvalidRequest) -> do
              UnstructuredLog -> m ()
forall a (m :: * -> *).
(ToEngineLog a Hasura, MonadIO m) =>
a -> m ()
writeLog
                (UnstructuredLog -> m ()) -> UnstructuredLog -> m ()
forall a b. (a -> b) -> a -> b
$ LogLevel -> SerializableBlob -> UnstructuredLog
L.UnstructuredLog LogLevel
L.LevelError
                (SerializableBlob -> UnstructuredLog)
-> SerializableBlob -> UnstructuredLog
forall a b. (a -> b) -> a -> b
$ String -> SerializableBlob
forall a. IsString a => String -> a
fromString
                (String -> SerializableBlob) -> String -> SerializableBlob
forall a b. (a -> b) -> a -> b
$ String
"Client exception: "
                String -> ShowS
forall a. Semigroup a => a -> a -> a
<> InvalidRequest -> String
forall a. Show a => a -> String
show InvalidRequest
e
              InvalidRequest -> m ()
forall (m :: * -> *) e a. (MonadBase IO m, Exception e) => e -> m a
throwIO InvalidRequest
e,
            (SomeException -> m ()) -> Handler m ()
forall (m :: * -> *) a e. Exception e => (e -> m a) -> Handler m a
Handler ((SomeException -> m ()) -> Handler m ())
-> (SomeException -> m ()) -> Handler m ()
forall a b. (a -> b) -> a -> b
$ \(SomeException
e :: SomeException) -> do
              UnstructuredLog -> m ()
forall a (m :: * -> *).
(ToEngineLog a Hasura, MonadIO m) =>
a -> m ()
writeLog
                (UnstructuredLog -> m ()) -> UnstructuredLog -> m ()
forall a b. (a -> b) -> a -> b
$ LogLevel -> SerializableBlob -> UnstructuredLog
L.UnstructuredLog LogLevel
L.LevelError
                (SerializableBlob -> UnstructuredLog)
-> SerializableBlob -> UnstructuredLog
forall a b. (a -> b) -> a -> b
$ String -> SerializableBlob
forall a. IsString a => String -> a
fromString
                (String -> SerializableBlob) -> String -> SerializableBlob
forall a b. (a -> b) -> a -> b
$ String
"Unexpected exception raised in websocket. Please report this as a bug: "
                String -> ShowS
forall a. Semigroup a => a -> a -> a
<> SomeException -> String
forall a. Show a => a -> String
show SomeException
e
              SomeException -> m ()
forall (m :: * -> *) e a. (MonadBase IO m, Exception e) => e -> m a
throwIO SomeException
e
          ]

    shuttingDownReject :: RejectRequest
shuttingDownReject =
      Int
-> ByteString
-> [(CI ByteString, ByteString)]
-> ByteString
-> RejectRequest
WS.RejectRequest
        Int
503
        ByteString
"Service Unavailable"
        [(CI ByteString
"Retry-After", ByteString
"0")]
        ByteString
"Server is shutting down"

    onReject :: WSId -> RejectRequest -> m ()
onReject WSId
wsId RejectRequest
rejectRequest = do
      IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ PendingConnection -> RejectRequest -> IO ()
WS.rejectRequestWith PendingConnection
pendingConn RejectRequest
rejectRequest
      Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId WSEvent
ERejected Maybe WSEventInfo
forall a. Maybe a
Nothing

    onAccept :: WSConnInitTimer -> WSId -> AcceptWith a -> m ()
onAccept WSConnInitTimer
wsConnInitTimer WSId
wsId (AcceptWith a
a AcceptRequest
acceptWithParams WSConn a -> IO ()
keepAlive WSConn a -> IO ()
onJwtExpiry) = do
      Connection
conn <- IO Connection -> m Connection
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO Connection -> m Connection) -> IO Connection -> m Connection
forall a b. (a -> b) -> a -> b
$ PendingConnection -> AcceptRequest -> IO Connection
WS.acceptRequestWith PendingConnection
pendingConn AcceptRequest
acceptWithParams
      Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId WSEvent
EAccepted Maybe WSEventInfo
forall a. Maybe a
Nothing
      TQueue WSQueueResponse
sendQ <- IO (TQueue WSQueueResponse) -> m (TQueue WSQueueResponse)
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO IO (TQueue WSQueueResponse)
forall a. IO (TQueue a)
STM.newTQueueIO
      let !wsConn :: WSConn a
wsConn = WSId
-> Logger Hasura
-> Connection
-> TQueue WSQueueResponse
-> a
-> WSConn a
forall a.
WSId
-> Logger Hasura
-> Connection
-> TQueue WSQueueResponse
-> a
-> WSConn a
WSConn WSId
wsId Logger Hasura
logger Connection
conn TQueue WSQueueResponse
sendQ a
a
      -- TODO there are many thunks here. Difficult to trace how much is retained, and
      --      how much of that would be shared anyway.
      --      Requires a fork of 'wai-websockets' and 'websockets', it looks like.
      --      Adding `package` stanzas with -Xstrict -XStrictData for those two packages
      --      helped, cutting the number of thunks approximately in half.
      IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ $String
String -> WSConn a -> IO ()
forall a. String -> a -> IO ()
assertNFHere WSConn a
wsConn -- so we don't write thunks to mutable vars
      let whenAcceptingInsertConn :: m (ServerStatus a)
whenAcceptingInsertConn = IO (ServerStatus a) -> m (ServerStatus a)
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO
            (IO (ServerStatus a) -> m (ServerStatus a))
-> IO (ServerStatus a) -> m (ServerStatus a)
forall a b. (a -> b) -> a -> b
$ STM (ServerStatus a) -> IO (ServerStatus a)
forall a. STM a -> IO a
STM.atomically
            (STM (ServerStatus a) -> IO (ServerStatus a))
-> STM (ServerStatus a) -> IO (ServerStatus a)
forall a b. (a -> b) -> a -> b
$ do
              ServerStatus a
status <- TVar (ServerStatus a) -> STM (ServerStatus a)
forall a. TVar a -> STM a
STM.readTVar TVar (ServerStatus a)
serverStatus
              case ServerStatus a
status of
                ServerStatus a
ShuttingDown -> () -> STM ()
forall a. a -> STM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
                AcceptingConns ConnMap a
connMap -> WSConn a -> WSId -> ConnMap a -> STM ()
forall key value.
(Eq key, Hashable key) =>
value -> key -> Map key value -> STM ()
STMMap.insert WSConn a
wsConn WSId
wsId ConnMap a
connMap
              ServerStatus a -> STM (ServerStatus a)
forall a. a -> STM a
forall (m :: * -> *) a. Monad m => a -> m a
return ServerStatus a
status

      -- ensure we clean up connMap even if an unexpected exception is raised from our worker
      -- threads, or an async exception is raised somewhere in the body here:
      m (ServerStatus a)
-> (ServerStatus a -> m ()) -> (ServerStatus a -> m ()) -> m ()
forall (m :: * -> *) a b c.
MonadBaseControl IO m =>
m a -> (a -> m b) -> (a -> m c) -> m c
bracket
        m (ServerStatus a)
whenAcceptingInsertConn
        (WSConn a -> ServerStatus a -> m ()
onConnClose WSConn a
wsConn)
        ((ServerStatus a -> m ()) -> m ())
-> (ServerStatus a -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \case
          ServerStatus a
ShuttingDown -> do
            -- Bad luck, we were in the process of shutting the server down but a new
            -- connection was accepted. Let's just close it politely
            WSConn a -> ByteString -> m ()
forall (m :: * -> *) a. MonadIO m => WSConn a -> ByteString -> m ()
forceConnReconnect WSConn a
wsConn ByteString
"shutting server down"
            OnCloseH m a
closeHandler WSConn a
wsConn
          AcceptingConns ConnMap a
_ -> do
            let rcv :: m ()
rcv = m () -> m ()
forall (f :: * -> *) a b. Applicative f => f a -> f b
forever (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
                  Bool
shouldCaptureVariables <- IO Bool -> m Bool
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO Bool -> m Bool) -> IO Bool -> m Bool
forall a b. (a -> b) -> a -> b
$ MetricsConfig -> Bool
_mcAnalyzeQueryVariables (MetricsConfig -> Bool) -> IO MetricsConfig -> IO Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO MetricsConfig
getMetricsConfig
                  -- Process all messages serially (important!), in a separate thread:
                  ByteString
msg <-
                    IO ByteString -> m ByteString
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO
                      (IO ByteString -> m ByteString) -> IO ByteString -> m ByteString
forall a b. (a -> b) -> a -> b
$
                      -- Re-throw "receiveloop: resource vanished (Connection reset by peer)" :
                      --   https://github.com/yesodweb/wai/blob/master/warp/Network/Wai/Handler/Warp/Recv.hs#L112
                      -- as WS exception signaling cleanup below. It's not clear why exactly this gets
                      -- raised occasionally; I suspect an equivalent handler is missing from WS itself.
                      -- Regardless this should be safe:
                      (IOError -> Maybe ())
-> (() -> IO ByteString) -> IO ByteString -> IO ByteString
forall (m :: * -> *) e b a.
(MonadBaseControl IO m, Exception e) =>
(e -> Maybe b) -> (b -> m a) -> m a -> m a
handleJust (Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> (IOError -> Bool) -> IOError -> Maybe ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IOError -> Bool
E.isResourceVanishedError) (\() -> ConnectionException -> IO ByteString
forall a e. Exception e => e -> a
throw ConnectionException
WS.ConnectionClosed)
                      (IO ByteString -> IO ByteString) -> IO ByteString -> IO ByteString
forall a b. (a -> b) -> a -> b
$ Connection -> IO ByteString
forall a. WebSocketsData a => Connection -> IO a
WS.receiveData Connection
conn
                  let messageLength :: Int64
messageLength = ByteString -> Int64
BL.length ByteString
msg
                      censoredMessage :: MessageDetails
censoredMessage =
                        SerializableBlob -> Int64 -> MessageDetails
MessageDetails
                          (ByteString -> SerializableBlob
SB.fromLBS (if Bool
shouldCaptureVariables then ByteString
msg else ByteString
"<censored>"))
                          Int64
messageLength
                  IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO
                    (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ Counter -> Int64 -> IO ()
Prometheus.Counter.add
                      (PrometheusMetrics -> Counter
pmWebSocketBytesReceived PrometheusMetrics
prometheusMetrics)
                      Int64
messageLength
                  Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId (MessageDetails -> WSEvent
EMessageReceived MessageDetails
censoredMessage) Maybe WSEventInfo
forall a. Maybe a
Nothing
                  WSConn a -> ByteString -> WSSubProtocol -> m ()
messageHandler WSConn a
wsConn ByteString
msg WSSubProtocol
subProtocol

            let send :: m ()
send = m () -> m ()
forall (f :: * -> *) a b. Applicative f => f a -> f b
forever (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do
                  WSQueueResponse ByteString
msg Maybe WSEventInfo
wsInfo IO DiffTime
wsTimer <- IO WSQueueResponse -> m WSQueueResponse
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO WSQueueResponse -> m WSQueueResponse)
-> IO WSQueueResponse -> m WSQueueResponse
forall a b. (a -> b) -> a -> b
$ STM WSQueueResponse -> IO WSQueueResponse
forall a. STM a -> IO a
STM.atomically (STM WSQueueResponse -> IO WSQueueResponse)
-> STM WSQueueResponse -> IO WSQueueResponse
forall a b. (a -> b) -> a -> b
$ TQueue WSQueueResponse -> STM WSQueueResponse
forall a. TQueue a -> STM a
STM.readTQueue TQueue WSQueueResponse
sendQ
                  IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ Connection -> ByteString -> IO ()
forall a. WebSocketsData a => Connection -> a -> IO ()
WS.sendTextData Connection
conn ByteString
msg
                  Double
messageQueueTime <- IO Double -> m Double
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO Double -> m Double) -> IO Double -> m Double
forall a b. (a -> b) -> a -> b
$ DiffTime -> Double
forall a b. (Real a, Fractional b) => a -> b
realToFrac (DiffTime -> Double) -> IO DiffTime -> IO Double
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IO DiffTime
wsTimer
                  let messageLength :: Int64
messageLength = ByteString -> Int64
BL.length ByteString
msg
                      messageDetails :: MessageDetails
messageDetails = SerializableBlob -> Int64 -> MessageDetails
MessageDetails (ByteString -> SerializableBlob
SB.fromLBS ByteString
msg) Int64
messageLength
                  IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ do
                    Counter -> Int64 -> IO ()
Prometheus.Counter.add
                      (PrometheusMetrics -> Counter
pmWebSocketBytesSent PrometheusMetrics
prometheusMetrics)
                      Int64
messageLength
                    Histogram -> Double -> IO ()
Prometheus.Histogram.observe
                      (PrometheusMetrics -> Histogram
pmWebsocketMsgQueueTimeSeconds PrometheusMetrics
prometheusMetrics)
                      Double
messageQueueTime
                  Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog WSId
wsId (MessageDetails -> WSEvent
EMessageSent MessageDetails
messageDetails) Maybe WSEventInfo
wsInfo

            -- withAsync lets us be very sure that if e.g. an async exception is raised while we're
            -- forking that the threads we launched will be cleaned up. See also below.
            m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync m ()
rcv ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
rcvRef -> do
              m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync m ()
send ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
sendRef -> do
                m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync (IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ WSConn a -> IO ()
keepAlive WSConn a
wsConn) ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
keepAliveRef -> do
                  m () -> (Async () -> m ()) -> m ()
forall (m :: * -> *) a b.
(MonadBaseControl IO m, Forall (Pure m)) =>
m a -> (Async a -> m b) -> m b
LA.withAsync (IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ WSConn a -> IO ()
onJwtExpiry WSConn a
wsConn) ((Async () -> m ()) -> m ()) -> (Async () -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \Async ()
onJwtExpiryRef -> do
                    -- once connection is accepted, check the status of the timer, and if it's expired, close the connection for `graphql-ws`
                    WSConnInitTimerStatus
timeoutStatus <- IO WSConnInitTimerStatus -> m WSConnInitTimerStatus
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO WSConnInitTimerStatus -> m WSConnInitTimerStatus)
-> IO WSConnInitTimerStatus -> m WSConnInitTimerStatus
forall a b. (a -> b) -> a -> b
$ WSConnInitTimer -> IO WSConnInitTimerStatus
getWSTimerState WSConnInitTimer
wsConnInitTimer
                    Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (WSConnInitTimerStatus
timeoutStatus WSConnInitTimerStatus -> WSConnInitTimerStatus -> Bool
forall a. Eq a => a -> a -> Bool
== WSConnInitTimerStatus
Done Bool -> Bool -> Bool
&& WSSubProtocol
subProtocol WSSubProtocol -> WSSubProtocol -> Bool
forall a. Eq a => a -> a -> Bool
== WSSubProtocol
GraphQLWS)
                      (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ IO () -> m ()
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO
                      (IO () -> m ()) -> IO () -> m ()
forall a b. (a -> b) -> a -> b
$ WSConn a -> Word16 -> ByteString -> IO ()
forall a. WSConn a -> Word16 -> ByteString -> IO ()
closeConnWithCode WSConn a
wsConn Word16
4408 ByteString
"Connection initialisation timed out"

                    -- terminates on WS.ConnectionException and JWT expiry
                    let waitOnRefs :: [Async ()]
waitOnRefs = [Async ()
keepAliveRef, Async ()
onJwtExpiryRef, Async ()
rcvRef, Async ()
sendRef]
                    -- waitAnyCancel re-raises exceptions from the forked threads, and is guaranteed to cancel in
                    -- case of async exceptions raised while blocking here:
                    m (Async (), ()) -> m (Either ConnectionException (Async (), ()))
forall (m :: * -> *) e a.
(MonadBaseControl IO m, Exception e) =>
m a -> m (Either e a)
try ([Async ()] -> m (Async (), ())
forall (m :: * -> *) a.
(MonadBase IO m, Forall (Pure m)) =>
[Async a] -> m (Async a, a)
LA.waitAnyCancel [Async ()]
waitOnRefs) m (Either ConnectionException (Async (), ()))
-> (Either ConnectionException (Async (), ()) -> m ()) -> m ()
forall a b. m a -> (a -> m b) -> m b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
                      -- NOTE: 'websockets' is a bit of a rat's nest at the moment wrt
                      -- exceptions; for now handle all ConnectionException by closing
                      -- and cleaning up, see: https://github.com/jaspervdj/websockets/issues/48
                      Left (ConnectionException
_ :: WS.ConnectionException) -> do
                        Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog (WSConn a -> WSId
forall a. WSConn a -> WSId
_wcConnId WSConn a
wsConn) WSEvent
ECloseReceived Maybe WSEventInfo
forall a. Maybe a
Nothing
                      -- this will happen when jwt is expired
                      Right (Async (), ())
_ -> do
                        Logger Hasura -> WSLog -> m ()
forall (m :: * -> *).
MonadWSLog m =>
Logger Hasura -> WSLog -> m ()
logWSLog Logger Hasura
logger (WSLog -> m ()) -> WSLog -> m ()
forall a b. (a -> b) -> a -> b
$ WSId -> WSEvent -> Maybe WSEventInfo -> WSLog
WSLog (WSConn a -> WSId
forall a. WSConn a -> WSId
_wcConnId WSConn a
wsConn) WSEvent
EJwtExpired Maybe WSEventInfo
forall a. Maybe a
Nothing

    -- Release action of the per-connection 'bracket' in the accept path.
    --
    -- The 'ServerStatus' argument is the value the acquire step returned:
    --   * 'ShuttingDown'  - the acquire step did not register the connection,
    --                       so there is nothing to undo here.
    --   * 'AcceptingConns' - the connection was inserted into @connMap@, so it
    --                       is removed again, the user-supplied close handler
    --                       runs, and an 'EClosed' event is logged.
    onConnClose :: WSConn a -> ServerStatus a -> m ()
    onConnClose wsConn status =
      case status of
        ShuttingDown ->
          pure ()
        AcceptingConns connMap -> do
          let connId = _wcConnId wsConn
          liftIO . STM.atomically $ STMMap.delete connId connMap
          closeHandler wsConn
          logWSLog logger (WSLog connId EClosed Nothing)

-- | Put the websocket server into the 'ShuttingDown' state and ask every
-- connection collected from the connection map to reconnect.
--
-- Collecting the connections (via 'flushConnMap') and switching the status to
-- 'ShuttingDown' happen inside a single STM transaction. The accept path only
-- inserts a connection while it observes 'AcceptingConns' in its own
-- transaction, and closes the connection itself when it observes
-- 'ShuttingDown'. NOTE(review): this relies on 'flushConnMap' returning the
-- entries of the current map; confirm against its definition.
shutdown :: WSServer a -> IO ()
shutdown (WSServer (L.Logger writeLog) _ serverStatus) = do
  writeLog (L.debugT "Shutting websockets server down")
  liveConns <-
    STM.atomically $
      flushConnMap serverStatus <* STM.writeTVar serverStatus ShuttingDown
  closeAllWith
    (\reason conn -> forceConnReconnect conn reason)
    "shutting server down"
    liveConns