module Hasura.Server.Auth.WebHook
  ( AuthHookType (..),
    AuthHook (..),
    userInfoFromAuthHook,
  )
where

import Control.Exception.Lifted (try)
import Control.Lens
import Control.Monad.Trans.Control (MonadBaseControl)
import Data.Aeson
import Data.Aeson qualified as J
import Data.ByteString.Lazy qualified as BL
import Data.HashMap.Strict qualified as Map
import Data.Parser.CacheControl (parseMaxAge)
import Data.Parser.Expires
import Data.Text qualified as T
import Data.Time.Clock (UTCTime, addUTCTime, getCurrentTime)
import Hasura.Base.Error
import Hasura.GraphQL.Transport.HTTP.Protocol qualified as GH
import Hasura.HTTP
import Hasura.Logging
import Hasura.Prelude
import Hasura.Server.Logging
import Hasura.Server.Utils
import Hasura.Session
import Hasura.Tracing qualified as Tracing
import Network.HTTP.Client.Transformable qualified as HTTP
import Network.Wreq qualified as Wreq

-- | How the authentication webhook is invoked.
--
--   * 'AHTGet': the client's request headers (minus the commonly ignored
--     ones) are forwarded as the webhook request's headers.
--   * 'AHTPost': the client's headers, together with the optional batch of
--     GraphQL requests, are sent as a JSON request body.
data AuthHookType = AHTGet | AHTPost
  deriving (Eq)

-- | Renders the HTTP method name (@"GET"@ / @"POST"@) rather than the
--   Haskell constructor name, so 'show' is not valid Haskell syntax here.
instance Show AuthHookType where
  show hookType = case hookType of
    AHTGet -> "GET"
    AHTPost -> "POST"

-- | Location and invocation style of the authentication webhook.
data AuthHook = AuthHook
  { -- | URL the authentication request is sent to.
    ahUrl :: Text,
    -- | Whether the webhook is called with GET or POST.
    ahType :: AuthHookType
  }
  deriving (Show, Eq)

-- | The HTTP method corresponding to the hook's 'AuthHookType'; used when
--   logging webhook calls.
hookMethod :: AuthHook -> HTTP.StdMethod
hookMethod = methodFor . ahType
  where
    methodFor AHTGet = HTTP.GET
    methodFor AHTPost = HTTP.POST

-- | Makes an authentication request to the given AuthHook and returns
--   UserInfo parsed from the response, plus an expiration time if one
--   was returned. Optionally passes a batch of raw GraphQL requests
--   for finer-grained auth. (#2666)
--
--   The third component of the result contains only the @Set-Cookie@
--   headers of the webhook's response. Transport-level failures
--   ('HTTP.HttpException', including an unparseable hook URL) are logged
--   and rethrown as a 500 error; interpreting the status code and body is
--   delegated to 'mkUserInfoFromResp'.
userInfoFromAuthHook ::
  forall m.
  (MonadIO m, MonadBaseControl IO m, MonadError QErr m, Tracing.MonadTrace m) =>
  Logger Hasura ->
  HTTP.Manager ->
  AuthHook ->
  [HTTP.Header] ->
  Maybe GH.ReqsText ->
  m (UserInfo, Maybe UTCTime, [HTTP.Header])
userInfoFromAuthHook logger manager hook reqHeaders reqs = do
  resp <- (`onLeft` logAndThrow) =<< try performHTTPRequest
  let status = resp ^. Wreq.responseStatus
      respBody = resp ^. Wreq.responseBody
      cookieHeaders = filter (\(headerName, _) -> headerName == "Set-Cookie") (resp ^. Wreq.responseHeaders)

  mkUserInfoFromResp logger (ahUrl hook) (hookMethod hook) status respBody cookieHeaders
  where
    -- Build the complete webhook request (method, headers, body) first,
    -- then hand it to the tracer and send exactly the request the tracer
    -- passes back. Previously these were set inside the traced callback:
    -- the GET branch replaced the traced request's whole header list with
    -- 'set', and the POST branch ignored the traced request entirely, so
    -- anything 'Tracing.tracedHttpRequest' added to the request was lost
    -- and the tracer only ever saw the default GET request without a
    -- body. (Presumably the tracer injects trace-propagation headers and
    -- names the span after the method -- see Hasura.Tracing.)
    performHTTPRequest :: m (Wreq.Response BL.ByteString)
    performHTTPRequest = do
      -- 'HTTP.mkRequestThrow' takes the URL as 'Text' directly.
      baseReq <- liftIO $ HTTP.mkRequestThrow $ ahUrl hook
      let req = case ahType hook of
            AHTGet -> mkGetRequest baseReq
            AHTPost -> mkPostRequest baseReq
      Tracing.tracedHttpRequest req \tracedReq ->
        liftIO $ HTTP.performRequest tracedReq manager

    -- GET: forward the client's headers, minus those in
    -- 'commonClientHeadersIgnored', as the webhook request's headers.
    mkGetRequest :: HTTP.Request -> HTTP.Request
    mkGetRequest baseReq =
      let isCommonHeader = (`elem` commonClientHeadersIgnored)
          filteredHeaders = filter (not . isCommonHeader . fst) reqHeaders
       in baseReq & set HTTP.headers (addDefaultHeaders filteredHeaders)

    -- POST: send the client's headers and the optional GraphQL request
    -- batch as a JSON body of the form {"headers": {...}, "request": ...}.
    mkPostRequest :: HTTP.Request -> HTTP.Request
    mkPostRequest baseReq =
      let contentType = ("Content-Type", "application/json")
          headersPayload = J.toJSON $ Map.fromList $ hdrsToText reqHeaders
       in baseReq
            & set HTTP.method "POST"
            & set HTTP.headers (addDefaultHeaders [contentType])
            & set HTTP.body (Just $ J.encode $ object ["headers" J..= headersPayload, "request" J..= reqs])

    -- Log a failed webhook call (no response status is available) and
    -- abort with a generic 500 error.
    logAndThrow :: HTTP.HttpException -> m a
    logAndThrow err = do
      unLogger logger $
        WebHookLog
          LevelError
          Nothing
          (ahUrl hook)
          (hookMethod hook)
          (Just $ HttpException err)
          Nothing
          Nothing
      throw500 "webhook authentication request failed"

-- | Interpret the authentication webhook's HTTP response.
--
--   * @200@: the body is decoded as a JSON object of text values and turned
--     into session variables via 'mkSessionVariablesText' / 'mkUserInfo'.
--     An expiry is taken from a @Cache-Control@ (max-age) entry of that
--     object, falling back to an @Expires@ entry; an entry that fails to
--     parse is logged as a warning and ignored. On success the given
--     response headers are returned unchanged.
--   * @200@ with an undecodable body: 500 error.
--   * @401@: 401 error.
--   * any other status: 500 error.
--
--   Every failure path logs the raw response body at error level; the
--   success path logs at info level without the body.
mkUserInfoFromResp ::
  (MonadIO m, MonadError QErr m) =>
  Logger Hasura ->
  Text ->
  HTTP.StdMethod ->
  HTTP.Status ->
  BL.ByteString ->
  [HTTP.Header] ->
  m (UserInfo, Maybe UTCTime, [HTTP.Header])
mkUserInfoFromResp (Logger logger) url method statusCode respBody respHdrs
  | statusCode == HTTP.status200 =
    either onDecodeFailure onSessionObject (eitherDecode respBody)
  | statusCode == HTTP.status401 =
    logError *> throw401 "Authentication hook unauthorized this request"
  | otherwise =
    logError *> throw500 "Invalid response from authorization hook"
  where
    -- A 200 response whose body is not the expected JSON object.
    onDecodeFailure decodeErr = do
      logError
      throw500 $ "Invalid response from authorization hook: " <> T.pack decodeErr

    -- A 200 response with a decoded object: build the user info, then
    -- look for an optional expiry in the same object.
    onSessionObject sessionObject = do
      userInfo <-
        mkUserInfo URBFromSessionVariables UAdminSecretNotSent $
          mkSessionVariablesText sessionObject
      logWebHookResp LevelInfo Nothing Nothing
      expiration <-
        runMaybeT $
          timeFromCacheControl sessionObject <|> timeFromExpires sessionObject
      pure (userInfo, expiration, respHdrs)

    logWebHookResp :: MonadIO m => LogLevel -> Maybe BL.ByteString -> Maybe Text -> m ()
    logWebHookResp level mBody mMessage =
      logger $
        WebHookLog
          level
          (Just statusCode)
          url
          method
          Nothing
          (fmap (bsToTxt . BL.toStrict) mBody)
          mMessage

    logWarn warning = logWebHookResp LevelWarn (Just respBody) (Just warning)

    logError = logWebHookResp LevelError (Just respBody) Nothing

    -- Expiry = now + max-age, if a "Cache-Control" entry is present and
    -- parses; a parse failure is logged and yields no expiry.
    timeFromCacheControl sessionObject = do
      cacheControl <- afold $ Map.lookup "Cache-Control" sessionObject
      maxAge <-
        either (\err -> logWarn (T.pack err) *> empty) pure (parseMaxAge cacheControl)
      now <- liftIO getCurrentTime
      pure $ addUTCTime (fromInteger maxAge) now

    -- Expiry taken verbatim from an "Expires" entry, if present and
    -- parseable; a parse failure is logged and yields no expiry.
    timeFromExpires sessionObject = do
      expires <- afold $ Map.lookup "Expires" sessionObject
      either (\err -> logWarn (T.pack err) *> empty) pure (parseExpirationTime expires)