{-# LANGUAGE TemplateHaskell #-}

-- | Types for both supported websocket protocols: Apollo
-- (subscriptions-transport-ws) and graphql-ws.
--
-- See Apollo: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
-- See graphql-ws: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
module Hasura.GraphQL.Transport.WebSocket.Protocol
  ( ClientMsg (CMConnInit, CMConnTerm, CMPing, CMPong, CMStart, CMStop),
    CompletionMsg (CompletionMsg),
    ConnErrMsg (ConnErrMsg, unConnErrMsg),
    ConnParams (_cpHeaders),
    DataMsg (DataMsg),
    ErrorMsg (ErrorMsg),
    OperationId (unOperationId),
    PingPongPayload,
    ServerErrorCode (..),
    ServerMsg (SMComplete, SMConnAck, SMConnErr, SMConnKeepAlive, SMData, SMErr, SMNext, SMPing, SMPong),
    ServerMsgType (..),
    StartMsg (StartMsg),
    StopMsg (StopMsg),
    WSConnInitTimerStatus (Done),
    WSSubProtocol (..),
    encodeServerErrorMsg,
    encodeServerMsg,
    getNewWSTimer,
    getWSTimerState,
    keepAliveMessage,
    showSubProtocol,
    toWSSubProtocol,

    -- * exported for testing
    unsafeMkOperationId,
  )
where

import Control.Concurrent
import Control.Concurrent.Extended (sleep)
import Control.Concurrent.STM
import Data.Aeson qualified as J
import Data.Aeson.TH qualified as J
import Data.ByteString.Lazy qualified as BL
import Data.Text (pack)
import Hasura.EncJSON
import Hasura.GraphQL.Transport.HTTP.Protocol
import Hasura.Prelude

-- | The websocket sub-protocol a client speaks.
--
-- NOTE: the sub-protocol is decided based on the @Sec-WebSocket-Protocol@
-- header on every request sent to the server.
data WSSubProtocol
  = -- | subscriptions-transport-ws (Apollo)
    Apollo
  | -- | graphql-ws
    GraphQLWS
  deriving (Eq, Show)

-- | The @Sec-WebSocket-Protocol@ header value that identifies each
-- sub-protocol.
--
-- NOTE: Please do not change these values: they are used to identify the type
-- of client on every request that reaches the server, and each is unique to
-- its protocol.
showSubProtocol :: WSSubProtocol -> String
showSubProtocol = \case
  -- REF: https://github.com/apollographql/subscriptions-transport-ws/blob/master/src/server.ts#L144
  Apollo -> "graphql-ws"
  -- REF: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#communication
  GraphQLWS -> "graphql-transport-ws"

-- | Map a sub-protocol name to a 'WSSubProtocol'.
--
-- Only @"graphql-transport-ws"@ selects 'GraphQLWS'. Every other string
-- (including @"graphql-ws"@ and unrecognised values) falls back to 'Apollo'.
toWSSubProtocol :: String -> WSSubProtocol
toWSSubProtocol = \case
  "graphql-transport-ws" -> GraphQLWS
  _ -> Apollo

-- | Identifier of a single operation on a websocket connection.
--
-- This is chosen by the client and sent with its start/subscribe and
-- stop/complete messages. The JSON, 'IsString' and 'Hashable' instances are
-- those of the underlying 'Text'.
newtype OperationId = OperationId {unOperationId :: Text}
  deriving stock (Show, Eq)
  deriving newtype (J.ToJSON, J.FromJSON, IsString, Hashable)

-- | Wrap arbitrary 'Text' as an 'OperationId' without any validation.
-- Exported for testing.
unsafeMkOperationId :: Text -> OperationId
unsafeMkOperationId = OperationId

-- | The @type@ tag of messages sent from the server to a client.
-- The exact wire string of each constructor is given by its 'Show' instance.
data ServerMsgType
  = -- specific to `Apollo` clients
    SMT_GQL_CONNECTION_KEEP_ALIVE
  | SMT_GQL_CONNECTION_ERROR
  | SMT_GQL_DATA
  | -- specific to `graphql-ws` clients
    SMT_GQL_NEXT
  | SMT_GQL_PING
  | SMT_GQL_PONG
  | -- common to clients of both protocols
    SMT_GQL_CONNECTION_ACK
  | SMT_GQL_ERROR
  | SMT_GQL_COMPLETE
  deriving stock (Eq)

-- | Wire names of the server message types.
--
-- NOTE: this is not a debugging representation. The 'J.ToJSON' instance of
-- 'ServerMsgType' is defined via 'show', so these strings are emitted as the
-- @type@ field of server messages and must match the protocols exactly.
instance Show ServerMsgType where
  show = \case
    -- specific to `Apollo` clients
    SMT_GQL_CONNECTION_KEEP_ALIVE -> "ka"
    SMT_GQL_CONNECTION_ERROR -> "connection_error"
    SMT_GQL_DATA -> "data"
    -- specific to `graphql-ws` clients
    SMT_GQL_NEXT -> "next"
    SMT_GQL_PING -> "ping"
    SMT_GQL_PONG -> "pong"
    -- common to clients of both protocols
    SMT_GQL_CONNECTION_ACK -> "connection_ack"
    SMT_GQL_ERROR -> "error"
    SMT_GQL_COMPLETE -> "complete"

-- | Encode a message type as the JSON string produced by its 'Show' instance.
instance J.ToJSON ServerMsgType where
  toJSON = J.toJSON . show

-- | Parameters carried in the optional @payload@ of a client's
-- @connection_init@ message.
data ConnParams = ConnParams
  { -- | Optional header map supplied by the client.
    _cpHeaders :: Maybe (HashMap Text Text)
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''ConnParams)

-- | A request to start an operation. Both Apollo's @start@ and graphql-ws's
-- @subscribe@ messages are decoded into this type.
data StartMsg = StartMsg
  { _smId :: !OperationId,
    _smPayload :: !GQLReqUnparsed
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''StartMsg)

-- | A request to stop an operation. Both Apollo's @stop@ and graphql-ws's
-- @complete@ messages are decoded into this type.
data StopMsg = StopMsg
  { _stId :: OperationId
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''StopMsg)

-- | Optional payload of @ping@ / @pong@ messages. Specific to graphql-ws.
data PingPongPayload = PingPongPayload
  { -- | NOTE: this is not within the spec, but is specific to our use case.
    _smMessage :: !(Maybe Text)
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''PingPongPayload)

-- | A ping/pong payload whose message is @"keepalive"@. Specific to graphql-ws.
keepAliveMessage :: PingPongPayload
keepAliveMessage = PingPongPayload (Just "keepalive")

-- | A graphql-ws @subscribe@ message.
--
-- NOTE(review): this type appears unused. It is not exported, and the
-- 'ClientMsg' parser decodes @subscribe@ messages as 'StartMsg'. Confirm
-- before relying on or removing it.
data SubscribeMsg = SubscribeMsg
  { _subId :: !OperationId,
    _subPayload :: !GQLReqUnparsed
  }
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''SubscribeMsg)

-- | Every message a client may send, across both sub-protocols.
data ClientMsg
  = CMConnInit !(Maybe ConnParams)
  | CMStart !StartMsg
  | CMStop !StopMsg
  | -- specific to apollo clients
    CMConnTerm
  | -- specific to graphql-ws clients
    CMPing !(Maybe PingPongPayload)
  | CMPong !(Maybe PingPongPayload)
  deriving stock (Show, Eq)

-- | Decode a client message of either sub-protocol by dispatching on its
-- @type@ field.
--
-- graphql-ws's @subscribe@ and @complete@ map to the same constructors as
-- Apollo's @start@ and @stop@. Any other @type@ fails to parse.
instance J.FromJSON ClientMsg where
  parseJSON = J.withObject "ClientMessage" $ \obj -> do
    t <- obj J..: "type"
    case (t :: String) of
      "connection_init" -> CMConnInit <$> parsePayload obj
      "start" -> CMStart <$> parseObj obj
      "stop" -> CMStop <$> parseObj obj
      "connection_terminate" -> pure CMConnTerm
      -- graphql-ws specific message types
      "complete" -> CMStop <$> parseObj obj
      "subscribe" -> CMStart <$> parseObj obj
      "ping" -> CMPing <$> parsePayload obj
      "pong" -> CMPong <$> parsePayload obj
      _ -> fail $ "unexpected type for ClientMessage: " <> t
    where
      -- Decode the whole message object (including its "type" key).
      parseObj o = J.parseJSON (J.Object o)

      -- Decode the optional "payload" key; a missing or null payload gives Nothing.
      parsePayload py = py J..:? "payload"

-- | The result of an operation. It is sent as Apollo @data@ or graphql-ws
-- @next@ (see 'encodeServerMsg').
data DataMsg = DataMsg
  { _dmId :: !OperationId,
    _dmPayload :: !GQResponse
  }

-- | An error for a specific operation. It is encoded with an @id@ and an
-- arbitrary JSON @payload@ (see 'encodeServerMsg').
data ErrorMsg = ErrorMsg
  { _emId :: !OperationId,
    _emPayload :: !J.Value
  }
  deriving stock (Show, Eq)

-- | Signals that the operation with the wrapped id has completed.
newtype CompletionMsg = CompletionMsg {unCompletionMsg :: OperationId}
  deriving stock (Show, Eq)

-- | Decode a completion from an object carrying the operation's @id@.
instance J.FromJSON CompletionMsg where
  parseJSON = J.withObject "CompletionMsg" $ \obj ->
    CompletionMsg <$> obj J..: "id"

-- | Encode a completion as the bare operation id, as a JSON string.
--
-- The raw id is used rather than 'show'/'tshow', which would render the
-- derived record syntax (@OperationId {unOperationId = ...}@).
--
-- NOTE(review): the 'J.FromJSON' instance expects an object @{"id": ...}@,
-- so the two instances do not round-trip. Confirm which shape callers rely on.
instance J.ToJSON CompletionMsg where
  toJSON (CompletionMsg opId) = J.String (unOperationId opId)

-- | Text sent as the @payload@ of an Apollo @connection_error@ message
-- (see 'encodeServerMsg'). JSON and 'IsString' come from 'Text'.
newtype ConnErrMsg = ConnErrMsg {unConnErrMsg :: Text}
  deriving stock (Show, Eq)
  deriving newtype (J.ToJSON, J.FromJSON, IsString)

-- | The JSON body produced by 'encodeServerErrorMsg'.
data ServerErrorMsg = ServerErrorMsg {unServerErrorMsg :: Text}
  deriving stock (Show, Eq)

$(J.deriveJSON hasuraJSON ''ServerErrorMsg)

-- | Every message the server may send, across both sub-protocols.
-- 'encodeServerMsg' defines the wire format of each.
data ServerMsg
  = -- | common to both protocols
    SMConnAck
  | -- | Apollo only ("ka")
    SMConnKeepAlive
  | -- | Apollo only ("connection_error")
    SMConnErr !ConnErrMsg
  | -- | Apollo only ("data")
    SMData !DataMsg
  | -- | common to both protocols
    SMErr !ErrorMsg
  | -- | common to both protocols
    SMComplete !CompletionMsg
  | -- graphql-ws specific values
    SMNext !DataMsg
  | SMPing !(Maybe PingPongPayload)
  | SMPong !(Maybe PingPongPayload)

-- | This is sent from the server to the client when closing the websocket
--   on encountering an error. Each constructor's numeric suffix is the code
--   that prefixes its message in 'encodeServerErrorMsg'.
data ServerErrorCode
  = ProtocolError1002
  | GenericError4400 !String
  | Unauthorized4401
  | Forbidden4403
  | ConnectionInitTimeout4408
  | NonUniqueSubscription4409 !OperationId
  | TooManyRequests4429
  deriving stock (Show)

-- | Encode the message sent when closing the websocket with the given
-- 'ServerErrorCode'. Each message is prefixed with its numeric code, e.g.
-- @"4401: Unauthorized"@.
encodeServerErrorMsg :: ServerErrorCode -> BL.ByteString
encodeServerErrorMsg ecode = encJToLBS . encJFromJValue $ case ecode of
  ProtocolError1002 -> packMsg "1002: Protocol Error"
  GenericError4400 msg -> packMsg $ "4400: " <> msg
  Unauthorized4401 -> packMsg "4401: Unauthorized"
  Forbidden4403 -> packMsg "4403: Forbidden"
  ConnectionInitTimeout4408 -> packMsg "4408: Connection initialisation timeout"
  NonUniqueSubscription4409 opId ->
    -- Use the raw id the client sent. 'show' would instead render the derived
    -- record syntax (@OperationId {unOperationId = ...}@).
    ServerErrorMsg $ "4409: Subscriber for " <> unOperationId opId <> " already exists"
  TooManyRequests4429 -> packMsg "4429: Too many requests"
  where
    packMsg :: String -> ServerErrorMsg
    packMsg = ServerErrorMsg . pack

-- | Serialise a 'ServerMsg' into the JSON object sent over the websocket.
--
-- Every message has a @"type"@ field (the 'Show' rendering of its
-- 'ServerMsgType'). Depending on the message, it is followed by an @"id"@
-- and/or a @"payload"@ field, in that order.
encodeServerMsg :: ServerMsg -> BL.ByteString
encodeServerMsg msg =
  encJToLBS $
    encJFromAssocList $ case msg of
      SMConnAck -> [encTy SMT_GQL_CONNECTION_ACK]
      SMConnKeepAlive -> [encTy SMT_GQL_CONNECTION_KEEP_ALIVE]
      SMConnErr connErr ->
        [ encTy SMT_GQL_CONNECTION_ERROR,
          ("payload", encJFromJValue connErr)
        ]
      -- Apollo's "data" and graphql-ws's "next" differ only in their type tag.
      SMData dataMsg -> encodeDataMsg SMT_GQL_DATA dataMsg
      SMNext dataMsg -> encodeDataMsg SMT_GQL_NEXT dataMsg
      SMErr (ErrorMsg opId payload) ->
        [ encTy SMT_GQL_ERROR,
          ("id", encJFromJValue opId),
          ("payload", encJFromJValue payload)
        ]
      SMComplete (CompletionMsg opId) ->
        [ encTy SMT_GQL_COMPLETE,
          ("id", encJFromJValue opId)
        ]
      SMPing mPayload -> encodePingPong SMT_GQL_PING mPayload
      SMPong mPayload -> encodePingPong SMT_GQL_PONG mPayload
  where
    -- The "type" field present on every message.
    encTy :: ServerMsgType -> (Text, EncJSON)
    encTy ty = ("type", encJFromJValue ty)

    -- Shared encoding of 'SMData' and 'SMNext'.
    encodeDataMsg :: ServerMsgType -> DataMsg -> [(Text, EncJSON)]
    encodeDataMsg ty (DataMsg opId payload) =
      [ encTy ty,
        ("id", encJFromJValue opId),
        ("payload", encodeGQResp payload)
      ]

    -- Ping/pong messages include "payload" only when one was supplied.
    encodePingPong :: ServerMsgType -> Maybe PingPongPayload -> [(Text, EncJSON)]
    encodePingPong ty mPayload =
      encTy ty : [("payload", encJFromJValue payload) | Just payload <- [mPayload]]

-- | Status of the connection-initialisation timer.
--
-- This "timer" is necessary while initialising the connection with the
-- server. It is specific to the graphql-ws protocol.
data WSConnInitTimerStatus = Running | Done
  deriving (Show, Eq)

-- | A timer's mutable status, paired with a 'TMVar' that 'getNewWSTimer'
-- fills with @()@ when the timer expires while still 'Running'.
type WSConnInitTimer = (TVar WSConnInitTimerStatus, TMVar ())

-- | Read a timer's current status. This reads the 'TVar' outside of any
-- larger STM transaction.
getWSTimerState :: WSConnInitTimer -> IO WSConnInitTimerStatus
getWSTimerState = readTVarIO . fst

-- | Start a new connection-initialisation timer.
--
-- The returned status starts as 'Running'. A forked thread sleeps for
-- @timeout@. When it wakes, it checks the status in one STM transaction:
--
-- * if the status is still 'Running', it sets it to 'Done' and fills the
--   returned 'TMVar' with @()@;
-- * if the status is already 'Done' (presumably set by a caller that saw
--   the connection initialise in time; confirm at call sites), it does nothing.
--
-- NOTE(review): the forked thread always lives for the full @timeout@, even
-- when the status was set to 'Done' early.
getNewWSTimer :: Seconds -> IO WSConnInitTimer
getNewWSTimer timeout = do
  timerState <- newTVarIO Running
  timer <- newEmptyTMVarIO
  void $
    forkIO $ do
      sleep (seconds timeout)
      atomically $
        readTVar timerState >>= \case
          Running -> do
            -- time's up, we set status to "Done"
            writeTVar timerState Done
            putTMVar timer ()
          Done -> pure ()
  pure (timerState, timer)