{-# LANGUAGE TemplateHaskell #-}

-- |
-- Module      : Hasura.Server.Auth.JWT
-- Description : Implements JWT Configuration and Validation Logic.
-- Copyright   : Hasura
--
-- This module implements the bulk of Hasura's JWT capabilities and interactions.
-- Its main point of non-testing invocation is `Hasura.Server.Auth`.
--
-- It exports both `processJwt` and `processJwt_`. `processJwt_` contains the
-- majority of the implementation and takes the JWT token processing function
-- as an argument, so that the processing can be mocked in test code.
--
-- In `processJwt_`, prior to validation of the token, first the token locations
-- and issuers are reconciled. Locations are either specified as auth or
-- cookie (with cookie name) or assumed to be auth. Issuers can be omitted or
-- specified, where an omitted configured issuer can match any issuer specified by
-- a request.
--
-- If none match, the request is treated as a no-auth request; if exactly one
-- matches, normal token auth is performed; and if multiple match, this is
-- considered an ambiguity error.
module Hasura.Server.Auth.JWT
  ( processJwt,
    RawJWT,
    StringOrURI (..),
    JWTConfig (..),
    JWTCtx (..),
    Jose.JWKSet (..),
    JWTClaimsFormat (..),
    JWTClaims (..),
    JwkFetchError (..),
    JWTHeader (..),
    JWTNamespace (..),
    JWTCustomClaimsMapDefaultRole,
    JWTCustomClaimsMapAllowedRoles,
    JWTCustomClaimsMapValue,
    ClaimsMap,
    updateJwkRef,
    jwkRefreshCtrl,
    defaultClaimsFormat,
    defaultClaimsNamespace,

    -- * Exposed for testing
    processJwt_,
    tokenIssuer,
    allowedRolesClaim,
    defaultRoleClaim,
    parseClaimsMap,
    JWTCustomClaimsMapValueG (..),
    JWTCustomClaimsMap (..),
    determineJwkExpiryLifetime,
  )
where

import Control.Concurrent.Extended qualified as C
import Control.Exception.Lifted (try)
import Control.Lens
import Control.Monad.Trans.Control (MonadBaseControl)
import Crypto.JWT qualified as Jose
import Data.Aeson (JSONPath)
import Data.Aeson qualified as J
import Data.Aeson.Casing qualified as J
import Data.Aeson.Key qualified as K
import Data.Aeson.KeyMap qualified as KM
import Data.Aeson.TH qualified as J
import Data.ByteArray.Encoding qualified as BAE
import Data.ByteString.Char8 qualified as BC
import Data.ByteString.Internal qualified as B
import Data.ByteString.Lazy qualified as BL
import Data.ByteString.Lazy.Char8 qualified as BLC
import Data.CaseInsensitive qualified as CI
import Data.HashMap.Strict qualified as HM
import Data.Hashable
import Data.IORef (IORef, readIORef, writeIORef)
import Data.Map.Strict qualified as M
import Data.Parser.CacheControl
import Data.Parser.Expires
import Data.Parser.JSONPath (encodeJSONPath, parseJSONPath)
import Data.Text qualified as T
import Data.Text.Encoding qualified as T
import Data.Time.Clock
  ( NominalDiffTime,
    UTCTime,
    diffUTCTime,
    getCurrentTime,
  )
import GHC.AssertNF.CPP
import Hasura.Base.Error
import Hasura.HTTP
import Hasura.Logging (Hasura, LogLevel (..), Logger (..))
import Hasura.Prelude
import Hasura.Server.Auth.JWT.Internal (parseEdDSAKey, parseHmacKey, parseRsaKey)
import Hasura.Server.Auth.JWT.Logging
import Hasura.Server.Utils
  ( executeJSONPath,
    getRequestHeader,
    isSessionVariable,
    userRoleHeader,
  )
import Hasura.Session
import Hasura.Tracing qualified as Tracing
import Network.HTTP.Client.Transformable qualified as HTTP
import Network.HTTP.Types as N
import Network.URI (URI)
import Network.Wreq qualified as Wreq
import Web.Spock.Internal.Cookies qualified as Spock

-- | Wrapper around the raw bytes of a JWT, held as a lazy 'BL.ByteString'.
-- NOTE(review): presumably the token exactly as received, before decoding or
-- verification — confirm at the call sites.
newtype RawJWT = RawJWT BL.ByteString

-- | How the claims map is encoded inside the token.
-- NOTE(review): the exact decoding for each format lives outside this
-- declaration; the constructor notes below are inferred from the names.
data JWTClaimsFormat
  = -- | Claims presumably embedded as a JSON value.
    JCFJson
  | -- | Claims presumably embedded as a string containing serialized JSON.
    JCFStringifiedJson
  deriving (Show, Eq)

-- JSON instances for 'JWTClaimsFormat'. The three-character "JCF" prefix is
-- stripped from each constructor name and the rest is snake_cased, so
-- 'JCFStringifiedJson' is tagged "stringified_json". Because every constructor
-- is nullary, aeson's default 'allNullaryToStringTag' encodes them as bare
-- strings.
$( let tagModifier = J.snakeCase . drop 3
    in J.deriveJSON
         J.defaultOptions
           { J.constructorTagModifier = tagModifier,
             J.sumEncoding = J.ObjectWithSingleField
           }
         ''JWTClaimsFormat
 )

-- | Where in an incoming request the JWT is looked for.
data JWTHeader
  = -- | The @Authorization@ request header.
    JHAuthorization
  | -- | A cookie; the 'Text' is the cookie name.
    JHCookie Text
  deriving (Show, Eq, Generic)

-- | No methods are defined here, so hashable's default (Generic-based)
-- implementations are used.
instance Hashable JWTHeader

-- | Parses the JWT location from an object of the form
-- @{"type": "Authorization"}@ or @{"type": "Cookie", "name": <cookie name>}@.
--
-- The @type@ value is compared case-insensitively (via 'CI.mk'). Any other
-- @type@ fails the parse, as does a missing @name@ for cookies.
instance J.FromJSON JWTHeader where
  parseJSON = J.withObject "JWTHeader" $ \o -> do
    hdrType <- CI.mk @Text <$> o J..: "type"
    if
        | hdrType == "Authorization" -> pure JHAuthorization
        | hdrType == "Cookie" -> JHCookie <$> o J..: "name"
        | otherwise -> fail "expected 'type' is 'Authorization' or 'Cookie'"

-- | Encodes a 'JWTHeader' as an object with a @type@ field, either
-- @"Authorization"@ or @"Cookie"@. Cookie locations also carry a @name@ field.
instance J.ToJSON JWTHeader where
  toJSON JHAuthorization = J.object ["type" J..= ("Authorization" :: String)]
  toJSON (JHCookie name) =
    J.object
      [ "type" J..= ("Cookie" :: String),
        "name" J..= name
      ]

-- | The default 'JWTClaimsFormat', namely 'JCFJson'.
defaultClaimsFormat :: JWTClaimsFormat
defaultClaimsFormat = JCFJson

-- | The @x-hasura-allowed-roles@ session variable. In a custom claims map this
-- key is mandatory, and its entry is parsed as a list of role names.
allowedRolesClaim :: SessionVariable
allowedRolesClaim = mkSessionVariable "x-hasura-allowed-roles"

-- | The @x-hasura-default-role@ session variable. In a custom claims map this
-- key is mandatory, and its entry is parsed as a single role name.
defaultRoleClaim :: SessionVariable
defaultRoleClaim = mkSessionVariable "x-hasura-default-role"

-- | The default claims namespace, @https://hasura.io/jwt/claims@.
-- NOTE(review): presumably the key under which claims are looked up when no
-- namespace is configured — confirm against the config parsing code.
defaultClaimsNamespace :: Text
defaultClaimsNamespace = "https://hasura.io/jwt/claims"

-- | 'JWTCustomClaimsMapValueG' represents a single value of a
-- 'JWTCustomClaimsMap'. It is either a JSON object or the literal value of
-- the claim.
--
-- A JSON object must contain a @path@ key, which holds the JSON path to the
-- claim value in the JWT. It may also contain a @default@ key. The default is
-- used when the lookup at that path fails because the key does not exist.
data JWTCustomClaimsMapValueG v
  = -- | JSON path to the claim value inside the token, plus an optional default
    -- used when nothing exists at that path.
    JWTCustomClaimsMapJSONPath !J.JSONPath !(Maybe v)
  | -- | The claim value, given literally in the claims map.
    JWTCustomClaimsMapStatic !v
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- | Any JSON object is parsed as 'JWTCustomClaimsMapJSONPath':
--
-- * @path@ is required and must be a valid JSON path string. A
--   'parseJSONPath' error is reported through 'fail'.
-- * @default@ is optional. A missing key or a @null@ value both yield
--   'Nothing'.
--
-- Every other JSON value is parsed directly as a 'JWTCustomClaimsMapStatic'
-- literal.
instance (J.FromJSON v) => J.FromJSON (JWTCustomClaimsMapValueG v) where
  parseJSON (J.Object obj) = do
    path <- obj J..: "path" >>= (either (fail . T.unpack) pure . parseJSONPath)
    -- '.:?' already returns the parser we need; no further bind is required.
    defaultVal <- obj J..:? "default"
    pure $ JWTCustomClaimsMapJSONPath path defaultVal
  parseJSON v = JWTCustomClaimsMapStatic <$> J.parseJSON v

-- | Encodes 'JWTCustomClaimsMapJSONPath' as an object with a @path@ key (the
-- 'encodeJSONPath' rendering of the path). A @default@ key is added only when
-- a default is present. 'JWTCustomClaimsMapStatic' is encoded as the wrapped
-- value itself.
instance (J.ToJSON v) => J.ToJSON (JWTCustomClaimsMapValueG v) where
  toJSON (JWTCustomClaimsMapJSONPath jsonPath mDefVal) =
    J.object $
      ("path" J..= encodeJSONPath jsonPath)
        : maybe [] (\defVal -> ["default" J..= defVal]) mDefVal
  toJSON (JWTCustomClaimsMapStatic v) = J.toJSON v

-- | Claims-map entry for the @x-hasura-default-role@ claim: a single role name.
type JWTCustomClaimsMapDefaultRole = JWTCustomClaimsMapValueG RoleName

-- | Claims-map entry for the @x-hasura-allowed-roles@ claim: a list of role
-- names.
type JWTCustomClaimsMapAllowedRoles = JWTCustomClaimsMapValueG [RoleName]

-- | Claims-map entry for any other session variable, such as
-- @x-hasura-user-id@.
type JWTCustomClaimsMapValue = JWTCustomClaimsMapValueG SessionVariableValue

-- | Claims-map entries for all session variables other than the default and
-- allowed roles, keyed by session variable.
type CustomClaimsMap = HM.HashMap SessionVariable JWTCustomClaimsMapValue

-- | 'JWTCustomClaimsMap' is an optional custom JWT claims map. It is specified
-- in the @claims_map@ field of @HASURA_GRAPHQL_JWT_SECRET@.
--
-- When specified, it must contain two mandatory fields:
-- @x-hasura-allowed-roles@ and @x-hasura-default-role@. Other claims may also
-- be provided.
data JWTCustomClaimsMap = JWTCustomClaimsMap
  { -- | Entry for the mandatory @x-hasura-default-role@ claim.
    jcmDefaultRole :: !JWTCustomClaimsMapDefaultRole,
    -- | Entry for the mandatory @x-hasura-allowed-roles@ claim.
    jcmAllowedRoles :: !JWTCustomClaimsMapAllowedRoles,
    -- | Entries for every other session variable in the claims map.
    jcmCustomClaims :: !CustomClaimsMap
  }
  deriving (Show, Eq)

-- | Encodes the claims map as a single JSON object keyed by the text of each
-- session variable. The object holds the default role, the allowed roles and
-- every custom claim.
--
-- NOTE(review): 'KM.fromList' keeps the last occurrence of a duplicated key.
-- A custom claim that reuses a mandatory key would therefore override it.
instance J.ToJSON JWTCustomClaimsMap where
  toJSON (JWTCustomClaimsMap defaultRole allowedRoles customClaims) =
    J.Object $
      KM.fromList $
        map (first (K.fromText . sessionVariableToText)) $
          [ (defaultRoleClaim, J.toJSON defaultRole),
            (allowedRolesClaim, J.toJSON allowedRoles)
          ]
            <> map (second J.toJSON) (HM.toList customClaims)

-- | Parses a claims map object.
--
-- The keys for 'allowedRolesClaim' and 'defaultRoleClaim' are mandatory. If
-- either is absent, parsing fails with "<claim> is expected but not found".
-- Every remaining key is converted with 'mkSessionVariable', and its value is
-- parsed as a 'JWTCustomClaimsMapValue' into 'jcmCustomClaims'.
instance J.FromJSON JWTCustomClaimsMap where
  parseJSON = J.withObject "JWTClaimsMap" $ \obj -> do
    -- Look up a mandatory key by its session-variable text, failing with a
    -- descriptive message when it is missing.
    let withNotFoundError sessionVariable =
          let sessionVarText = sessionVariableToText sessionVariable
              errorMsg = T.unpack $ sessionVarText <> " is expected but not found"
           in KM.lookup (K.fromText sessionVarText) obj
                `onNothing` fail errorMsg

    allowedRoles <- withNotFoundError allowedRolesClaim >>= J.parseJSON
    defaultRole <- withNotFoundError defaultRoleClaim >>= J.parseJSON
    -- NOTE(review): the mandatory keys above are found by exact key text,
    -- whereas the keys below go through 'mkSessionVariable' before the two
    -- mandatory claims are deleted. Confirm whether differently-cased spellings
    -- of the mandatory keys are meant to be accepted.
    let filteredClaims =
          HM.delete allowedRolesClaim $
            HM.delete defaultRoleClaim $
              HM.fromList $
                map (first (mkSessionVariable . K.toText)) $
                  KM.toList obj
    -- The key is not needed to parse a value, so a plain 'traverse' suffices
    -- (equivalent to 'HM.traverseWithKey' with the key ignored).
    customClaims <- traverse J.parseJSON filteredClaims
    pure $ JWTCustomClaimsMap defaultRole allowedRoles customClaims

-- | Where the Hasura claims map lives inside a JWT.
--
-- The location is given either as a JSON path into the token's claims
-- ('ClaimNsPath') or as the name of a key in the token ('ClaimNs').
data JWTNamespace
  = ClaimNsPath JSONPath
  | ClaimNs Text
  deriving stock (Eq, Show)

-- | Both namespace forms serialise to a JSON string: a path is rendered with
-- 'encodeJSONPath', a key name is emitted unchanged.
instance J.ToJSON JWTNamespace where
  toJSON = \case
    ClaimNsPath nsPath -> J.String (encodeJSONPath nsPath)
    ClaimNs ns -> J.String ns

-- | How the Hasura claims are found in a token: either under a
-- 'JWTNamespace' and decoded according to a 'JWTClaimsFormat'
-- ('JCNamespace'), or through a user-supplied custom claims map ('JCMap').
data JWTClaims
  = JCNamespace !JWTNamespace !JWTClaimsFormat
  | JCMap !JWTCustomClaimsMap
  deriving stock (Eq, Show)

-- | Wrapper around 'Jose.StringOrURI' that adds a 'Hashable' instance, so it
-- can be used as the key of a 'HashMap' of 'JWTConfig's.
--
-- 'Eq', 'Show' and the JSON instances are inherited unchanged from the
-- wrapped type.
newtype StringOrURI = StringOrURI {unStringOrURI :: Jose.StringOrURI}
  deriving newtype (Eq, Show, J.FromJSON, J.ToJSON)

-- Allow 'StringOrURI' (bare or optional) as a JSON map key. These instances
-- use aeson's default methods, which are built on the 'J.ToJSON' /
-- 'J.FromJSON' instances of the key type.
instance J.ToJSONKey StringOrURI

instance J.ToJSONKey (Maybe StringOrURI)

instance J.FromJSONKey StringOrURI

instance J.FromJSONKey (Maybe StringOrURI)

-- | Hashes the JSON encoding of the wrapped value.
instance Hashable StringOrURI where
  hashWithSalt salt value = hashWithSalt salt (J.encode value)

-- | The JWT configuration we got from the user.
--
-- 'JWTCtx' is the validated runtime counterpart of this configuration.
data JWTConfig = JWTConfig
  { -- | An inline JWK, or a URI (presumably the endpoint the JWK set is
    -- fetched from by 'updateJwkRef' — confirm against 'mkJwtCtx').
    jcKeyOrUrl :: !(Either Jose.JWK URI),
    jcAudience :: !(Maybe Jose.Audience),
    -- | Expected issuer. Per the module header, an omitted issuer can match
    -- any issuer specified by a request.
    jcIssuer :: !(Maybe Jose.StringOrURI),
    jcClaims :: !JWTClaims,
    jcAllowedSkew :: !(Maybe NominalDiffTime),
    -- | Token location. Per the module header, an unspecified location is
    -- assumed to be auth.
    jcHeader :: !(Maybe JWTHeader)
  }
  deriving stock (Eq, Show)

-- | The validated runtime JWT configuration returned by 'mkJwtCtx' in 'setupAuthMode'.
--
-- This is also evidence that the 'jwkRefreshCtrl' thread is running, if an
-- expiration schedule could be determined.
--
-- The derived 'Eq' compares 'jcxKey' by 'IORef' identity, not by the keys it
-- currently holds.
data JWTCtx = JWTCtx
  { -- | This needs to be a mutable variable for 'updateJwkRef'.
    jcxKey :: !(IORef Jose.JWKSet),
    jcxAudience :: !(Maybe Jose.Audience),
    jcxIssuer :: !(Maybe Jose.StringOrURI),
    jcxClaims :: !JWTClaims,
    jcxAllowedSkew :: !(Maybe NominalDiffTime),
    jcxHeader :: !JWTHeader
  }
  deriving stock (Eq)

-- | The key set sits behind an 'IORef' and cannot be shown, so a fixed
-- placeholder stands in for it. Every other field is shown individually, and
-- the resulting list of strings is itself passed to 'show'.
instance Show JWTCtx where
  show ctx =
    show
      [ "<IORef JWKSet>",
        show (jcxAudience ctx),
        show (jcxIssuer ctx),
        show (jcxClaims ctx),
        show (jcxAllowedSkew ctx),
        show (jcxHeader ctx)
      ]

-- | Role information carried in a token's Hasura claims: the list of allowed
-- roles and the default role.
--
-- The JSON instances are generated below using the 'hasuraJSON' options,
-- which decide how these field names map to JSON keys.
data HasuraClaims = HasuraClaims
  { _cmAllowedRoles :: ![RoleName],
    _cmDefaultRole :: !RoleName
  }
  deriving stock (Eq, Show)

$(J.deriveJSON hasuraJSON ''HasuraClaims)

-- | An action that refreshes the JWK at intervals in an infinite loop.
--
-- After an initial wait of @time@, it repeatedly calls 'updateJwkRef' to
-- refetch the JWK set from @url@ into @ref@. Each fetch runs inside its own
-- @"jwk refresh"@ trace.
--
-- The wait before the next fetch is the lifetime returned by 'updateJwkRef'.
-- It falls back to one minute when no lifetime is available, including when
-- the fetch failed, and is never shorter than one second.
--
-- The sleep between refreshes happens outside the traced action. Whatever
-- the trace measures therefore covers only the fetch, not the wait that
-- follows it.
jwkRefreshCtrl ::
  (MonadIO m, MonadBaseControl IO m, Tracing.HasReporter m) =>
  Logger Hasura ->
  HTTP.Manager ->
  URI ->
  IORef Jose.JWKSet ->
  DiffTime ->
  m void
jwkRefreshCtrl logger manager url ref time = do
  liftIO $ C.sleep time
  forever do
    mTime <- Tracing.runTraceT "jwk refresh" do
      res <- runExceptT $ updateJwkRef logger manager url ref
      -- 'updateJwkRef' logs its own failures; here we only note the retry
      -- and fall back to the default delay.
      onLeft res (const $ logNotice >> return Nothing)
    -- if can't parse time from header, defaults to 1 min
    -- and never use a smaller delay than one second to avoid a tight loop
    let delay = max (seconds 1) $ maybe (minutes 1) convertDuration mTime
    liftIO $ C.sleep delay
  where
    -- The "60 secs" in this message matches the one-minute fallback delay above.
    logNotice :: MonadIO n => n ()
    logNotice = do
      let err = JwkRefreshLog LevelInfo (Just "retrying again in 60 secs") Nothing
      liftIO $ unLogger logger err
err

-- | Given a JWK url, fetch JWK from it and update the IORef
--
-- In order, this:
--
--   1. logs at info level that a refresh from @url@ is starting;
--   2. performs a traced HTTP request to @url@ with the default headers added;
--   3. rejects any status code outside 200–299;
--   4. decodes the body as a 'Jose.JWKSet' and writes it, fully evaluated,
--      into @jwkRef@;
--   5. returns the cache lifetime that 'determineJwkExpiryLifetime' derives
--      from the response headers.
--
-- Failures in steps 2–4 are logged at level @critical@ and then raised with
-- 'throwError':
--
--   * step 2 raises 'JFEHttpException';
--   * step 3 raises 'JFEHttpError';
--   * step 4 raises 'JFEJwkParseError'.
--
-- Only 'HTTP.HttpException's are caught around the request; other exceptions
-- propagate unchanged.
updateJwkRef ::
  ( MonadIO m,
    MonadBaseControl IO m,
    MonadError JwkFetchError m,
    Tracing.MonadTrace m
  ) =>
  Logger Hasura ->
  HTTP.Manager ->
  URI ->
  IORef Jose.JWKSet ->
  m (Maybe NominalDiffTime)
updateJwkRef (Logger logger) manager url jwkRef = do
  liftIO $ logger $ JwkRefreshLog LevelInfo (Just infoMsg) Nothing

  fetched <- try do
    -- NOTE(review): 'tshow' renders the URI through its 'Show' instance.
    -- In network-uri that instance masks a password in the user-info part
    -- (as ":...@"). If JWK URLs may carry credentials, the request URL should
    -- be built with @uriToString id@ instead. Confirm.
    rawReq <- liftIO . HTTP.mkRequestThrow $ tshow url
    let reqWithDefaults = over HTTP.headers addDefaultHeaders rawReq
    Tracing.tracedHttpRequest reqWithDefaults \tracedReq ->
      liftIO $ HTTP.performRequest tracedReq manager
  response <- either logAndThrowHttp pure fetched

  let httpStatus = response ^. Wreq.responseStatus
      payload = response ^. Wreq.responseBody
      httpCode = httpStatus ^. Wreq.statusCode

  unless (200 <= httpCode && httpCode < 300) $
    logAndThrow $
      JFEHttpError url httpStatus payload $
        "Non-2xx response on fetching JWK from: " <> urlT

  !jwkset <- either (logAndThrow . parseErr) pure $ J.eitherDecode' payload
  liftIO do
    $assertNFHere jwkset -- so we don't write thunks to mutable vars
    writeIORef jwkRef jwkset

  determineJwkExpiryLifetime
    (liftIO getCurrentTime)
    (Logger logger)
    (response ^. Wreq.responseHeaders)
  where
    urlT = tshow url

    infoMsg = "refreshing JWK from endpoint: " <> urlT

    parseErr e = JFEJwkParseError (T.pack e) $ "Error parsing JWK from url: " <> urlT

    -- Log at "critical" level, then raise the error through MonadError.
    logAndThrow :: (MonadIO m, MonadError JwkFetchError m) => JwkFetchError -> m a
    logAndThrow fetchErr = do
      liftIO $ logger $ JwkRefreshLog (LevelOther "critical") Nothing (Just fetchErr)
      throwError fetchErr

    -- Wrap a transport-level exception in 'JFEHttpException' and raise it.
    logAndThrowHttp :: (MonadIO m, MonadError JwkFetchError m) => HTTP.HttpException -> m a
    logAndThrowHttp httpEx =
      logAndThrow $
        JFEHttpException (HttpException httpEx) $
          "Error fetching JWK: " <> T.pack (getHttpExceptionMsg httpEx)

    getHttpExceptionMsg (HTTP.HttpExceptionRequest _ reason) = show reason
    getHttpExceptionMsg (HTTP.InvalidUrlException _ reason) = show reason

-- | First check for Cache-Control header, if not found, look for Expires header
--
-- Cache-Control directives are applied in this order:
--
--   * @must-revalidate@: use @max-age@ if present, otherwise expire
--     immediately (0).
--   * @no-cache@ or @no-store@: do not cache the JWK; expire immediately (0).
--   * anything else: use @max-age@ if present. If it is absent, fall through
--     to the @Expires@ header.
--
-- With @Expires@, the lifetime is the expiry time minus the current time
-- from @getCurrentTime'@. It is negative if the header lies in the past.
-- Returns 'Nothing' when neither header yields a lifetime.
--
-- A header that is present but unparsable is logged at info level and
-- raised as 'JFEExpiryParseError' through 'MonadError'. It does not fall
-- back to the other header.
determineJwkExpiryLifetime ::
  forall m.
  (MonadIO m, MonadError JwkFetchError m) =>
  m UTCTime ->
  Logger Hasura ->
  ResponseHeaders ->
  m (Maybe NominalDiffTime)
determineJwkExpiryLifetime getCurrentTime' (Logger logger) responseHeaders =
  runMaybeT (timeFromCacheControl <|> timeFromExpires)
  where
    -- Look up a header and decode it as Text. A missing header yields
    -- 'Nothing' in the surrounding 'MaybeT'.
    headerText name = afold $ bsToTxt <$> lookup name responseHeaders

    parseCacheControlErr :: Text -> JwkFetchError
    parseCacheControlErr e =
      JFEExpiryParseError
        (Just e)
        "Failed parsing Cache-Control header from JWK response"

    parseTimeErr :: JwkFetchError
    parseTimeErr =
      JFEExpiryParseError
        Nothing
        "Failed parsing Expires header from JWK response. Value of header is not a valid timestamp"

    cacheControlFailure :: String -> MaybeT m a
    cacheControlFailure = logAndThrowInfo . parseCacheControlErr . T.pack

    timeFromCacheControl :: MaybeT m NominalDiffTime
    timeFromCacheControl = do
      header <- headerText "Cache-Control"
      cacheControl <- either cacheControlFailure pure (parseCacheControl header)
      maxAge <- either cacheControlFailure (pure . fmap fromInteger) (findMaxAge cacheControl)
      lifetimeFor cacheControl maxAge

    -- Choose the lifetime from the parsed directives and the optional max-age.
    lifetimeFor cacheControl maxAge
      -- must-revalidate: honour max-age if given, otherwise expire immediately
      | mustRevalidateExists cacheControl = pure (fromMaybe 0 maxAge)
      -- we don't want to cache the JWK in these cases, so expire immediately
      | noCacheExists cacheControl || noStoreExists cacheControl = pure 0
      -- use max-age if it exists; 'Nothing' lets the Expires header decide
      | otherwise = hoistMaybe maxAge

    timeFromExpires :: MaybeT m NominalDiffTime
    timeFromExpires = do
      header <- headerText "Expires"
      expiry <- either (const $ logAndThrowInfo parseTimeErr) pure (parseExpirationTime header)
      now <- lift getCurrentTime'
      pure (expiry `diffUTCTime` now)

    -- Log at info level, then raise the error through MonadError.
    logAndThrowInfo :: (MonadIO m1, MonadError JwkFetchError m1) => JwkFetchError -> m1 a
    logAndThrowInfo expiryErr = do
      liftIO $ logger $ JwkRefreshLog LevelInfo Nothing (Just expiryErr)
      throwError expiryErr

-- | Claims taken from a JWT, keyed by 'SessionVariable' and holding each
-- claim's raw JSON value.
type ClaimsMap = HM.HashMap SessionVariable J.Value

-- | Decode the claims of a compact JWT /without/ verifying its signature.
--
-- The token is split on @.@ and the second segment (the payload) is decoded
-- as unpadded base64url, then parsed as a JSON 'Jose.ClaimsSet'. Yields
-- 'Nothing' when the token has fewer than three segments, when the payload is
-- not valid unpadded base64url, or when it is not a JSON claims set. Any
-- segments after the third are ignored.
decodeClaimsSet :: RawJWT -> Maybe Jose.ClaimsSet
decodeClaimsSet (RawJWT jwt) =
  case BL.splitWith (== B.c2w '.') jwt of
    (_header : payload : _signature : _) ->
      either (const Nothing) (J.decode . BL.fromStrict) $
        BAE.convertFromBase BAE.Base64URLUnpadded (BL.toStrict payload)
    _ -> Nothing

-- | Extract the issuer (@iss@ claim) from a bearer token /without/ verifying
-- it. Yields 'Nothing' when the token's claims cannot be decoded or when it
-- carries no issuer.
tokenIssuer :: RawJWT -> Maybe StringOrURI
tokenIssuer rawJwt = do
  claims <- decodeClaimsSet rawJwt
  coerce (claims ^. Jose.claimIss)

-- | Authenticate a request in JWT mode from its headers, producing the
-- 'UserInfo' and the token's expiry time (if any).
--
-- Each 'JWTCtx' says where its token is expected ('jcxHeader'): the
-- "Authorization" header or a named cookie. Tokens are verified with
-- 'processHeaderSimple', and their issuers are read with 'tokenIssuer'. See
-- 'processJwt_' for how tokens and contexts are matched up.
--
-- If the request carries no token in any location a configured 'JWTCtx' looks
-- at, we fall back to the unauthenticated user role [1], if one was configured
-- at server start. Otherwise the request is rejected.
--
-- When the request does not name a role ('userRoleHeader'), the default role
-- from the JWT claims [2] is used.
--
-- [1]: https://hasura.io/docs/latest/graphql/core/auth/authentication/unauthenticated-access.html
-- [2]: https://hasura.io/docs/latest/graphql/core/auth/authentication/jwt.html#the-spec
processJwt ::
  ( MonadIO m,
    MonadError QErr m
  ) =>
  [JWTCtx] ->
  HTTP.RequestHeaders ->
  Maybe RoleName ->
  m (UserInfo, Maybe UTCTime, [N.Header])
processJwt ctxs requestHeaders mUnAuthRole =
  processJwt_
    processHeaderSimple
    tokenIssuer
    jcxHeader
    ctxs
    requestHeaders
    mUnAuthRole

-- | Where a JWT is carried in a request: the @Authorization@ header
-- ('JHAuthorization') or a named cookie ('JHCookie'). An alias of 'JWTHeader'.
type AuthTokenLocation = JWTHeader

-- | The implementation behind 'processJwt'. It is broken out so that tests can
-- substitute mocks for token verification, issuer decoding and token-location
-- lookup.
--
-- Every token found in the request headers is paired with every 'JWTCtx' that
-- expects a token in the same location (the @Authorization@ header or a
-- particular cookie). Each pair is then checked for an issuer match. A
-- 'JWTCtx' with no configured issuer, or a token whose issuer cannot be
-- decoded, matches anything. Then:
--
--   * no pairs at all: fall back to the unauthenticated role, if configured;
--   * only issuer mismatches: reject with @JWTNotInIssuer@;
--   * exactly one match: verify that token and build the 'UserInfo';
--   * more than one match: reject as ambiguous.
processJwt_ ::
  (MonadError QErr m) =>
  -- | Verify a bare token against a 'JWTCtx' and extract its claims
  -- (mock for 'processHeaderSimple')
  (JWTCtx -> BLC.ByteString -> m (ClaimsMap, Maybe UTCTime)) ->
  -- | Read a token's issuer without verifying it (e.g. 'tokenIssuer')
  (RawJWT -> Maybe StringOrURI) ->
  -- | Where a 'JWTCtx' expects to find its token (e.g. 'jcxHeader')
  (JWTCtx -> JWTHeader) ->
  [JWTCtx] ->
  HTTP.RequestHeaders ->
  Maybe RoleName ->
  m (UserInfo, Maybe UTCTime, [N.Header])
processJwt_ processJwtBytes decodeIssuer fGetHeaderType jwtCtxs headers mUnAuthRole = do
  -- Pair each token found in the headers with every JWTCtx that expects a
  -- token in that same location, then check the issuer of each pair.
  issuerMatches <-
    traverse issuerMatch $
      intersectKeys (keyCtxOnAuthTypes jwtCtxs) (keyTokensOnAuthTypes headers)

  case (lefts issuerMatches, rights issuerMatches) of
    -- No token in any location we were configured to look at.
    ([], []) -> withoutAuthZ
    -- Tokens were found, but none matched a configured issuer.
    (_ : _, []) -> jwtNotIssuerError
    -- Exactly one (context, token) pair matched: perform normal auth.
    (_, [(ctx, val)]) -> withAuthZ val ctx
    -- Several pairs matched; we cannot tell which one to use.
    _ -> throw400 InvalidHeaders "Could not verify JWT: Multiple JWTs found"
  where
    -- For each key present in both maps, every combination of their values.
    intersectKeys :: (Hashable a, Eq a) => HM.HashMap a [b] -> HM.HashMap a [c] -> [(b, c)]
    intersectKeys m n = concatMap (uncurry cartesianProduct) $ HM.elems $ HM.intersectionWith (,) m n

    -- Strip the scheme from an Authorization header value (cookies carry the
    -- bare token), then compare the token's issuer with the context's issuer.
    -- 'Left' records a mismatch; 'Right' is a usable (context, token) pair.
    issuerMatch (j, b) = do
      b'' <- case b of
        (JHCookie _, b') -> pure b'
        (JHAuthorization, b') ->
          case BC.words b' of
            -- RFC 7235 §2.1: the authentication scheme is case-insensitive,
            -- so "bearer <token>" must be accepted as well as "Bearer <token>".
            [scheme, jwt] | CI.mk scheme == CI.mk "Bearer" -> pure jwt
            _ -> throw400 InvalidHeaders "Malformed Authorization header"

      case (StringOrURI <$> jcxIssuer j, decodeIssuer $ RawJWT $ BLC.fromStrict b'') of
        (Nothing, _) -> pure $ Right (j, b'')
        (_, Nothing) -> pure $ Right (j, b'')
        (ci, ji)
          | ci == ji -> pure $ Right (j, b'')
          | otherwise -> pure $ Left (ci, ji, j, b'')

    cartesianProduct :: [a] -> [b] -> [(a, b)]
    cartesianProduct as bs = [(a, b) | a <- as, b <- bs]

    -- Group the configured contexts by the location they read their token from.
    keyCtxOnAuthTypes :: [JWTCtx] -> HM.HashMap AuthTokenLocation [JWTCtx]
    keyCtxOnAuthTypes = HM.fromListWith (++) . fmap (fGetHeaderType &&& pure)

    -- Group the tokens found in the request headers by their location.
    keyTokensOnAuthTypes :: [HTTP.Header] -> HM.HashMap AuthTokenLocation [(AuthTokenLocation, B.ByteString)]
    keyTokensOnAuthTypes = HM.fromListWith (++) . map (fst &&& pure) . concatMap findTokensInHeader

    -- An Authorization header yields its whole value. A Cookie header yields
    -- one entry per cookie, keyed by cookie name.
    findTokensInHeader :: Header -> [(AuthTokenLocation, B.ByteString)]
    findTokensInHeader (key, val)
      | key == CI.mk "Authorization" = [(JHAuthorization, val)]
      | key == CI.mk "Cookie" = bimap JHCookie T.encodeUtf8 <$> Spock.parseCookies val
      | otherwise = []

    -- Verify the single matching token and build the session from its claims.
    withAuthZ authzHeader jwtCtx = do
      (claimsMap, expTimeM) <- processJwtBytes jwtCtx $ BL.fromStrict authzHeader

      HasuraClaims allowedRoles defaultRole <- parseHasuraClaims claimsMap
      -- See if there is a x-hasura-role header, or else pick the default role.
      -- The role returned is unauthenticated at this point:
      let requestedRole =
            fromMaybe defaultRole $
              getRequestHeader userRoleHeader headers >>= mkRoleName . bsToTxt

      when (requestedRole `notElem` allowedRoles) $
        throw400 AccessDenied "Your requested role is not in allowed roles"

      -- The role claims have been consumed above; the remaining claims become
      -- session variables.
      let finalClaims =
            HM.delete defaultRoleClaim . HM.delete allowedRolesClaim $ claimsMap
          finalClaimsObject =
            KM.fromList $
              map (first (K.fromText . sessionVariableToText)) $
                HM.toList finalClaims
      metadata <- parseJwtClaim (J.Object finalClaimsObject) "x-hasura-* claims"
      userInfo <-
        mkUserInfo (URBPreDetermined requestedRole) UAdminSecretNotSent $
          mkSessionVariablesText metadata
      pure (userInfo, expTimeM, [])

    -- No token present: use the configured unauthenticated role, or reject.
    withoutAuthZ = do
      unAuthRole <- onNothing mUnAuthRole (throw400 InvalidHeaders "Missing 'Authorization' or 'Cookie' header in JWT authentication mode")
      userInfo <-
        mkUserInfo (URBPreDetermined unAuthRole) UAdminSecretNotSent $
          mkSessionVariablesHeaders headers
      pure (userInfo, Nothing, [])

    jwtNotIssuerError = throw400 JWTInvalid "Could not verify JWT: JWTNotInIssuer"

-- | Verify a bare token (any @Bearer @ prefix already stripped) against a
-- 'JWTCtx' and extract its claims.
--
-- Returns the claims map produced by 'parseClaimsMap' under the context's
-- claims configuration ('jcxClaims'), together with the token's expiry (its
-- @exp@ claim) if it has one. A verification failure is thrown as a 400
-- 'JWTInvalid' error.
processHeaderSimple ::
  ( MonadIO m,
    MonadError QErr m
  ) =>
  JWTCtx ->
  BLC.ByteString ->
  m (ClaimsMap, Maybe UTCTime)
processHeaderSimple jwtCtx jwt = do
  -- Verify the token with 'verifyJwt', turning any 'JWTError' into a 400
  -- JWTInvalid error.
  claims <- liftJWTError invalidJWTError $ verifyJwt jwtCtx $ RawJWT jwt

  let expTimeM = (\(Jose.NumericDate t) -> t) <$> claims ^. Jose.claimExp

  claimsObject <- parseClaimsMap claims claimsConfig

  pure (claimsObject, expTimeM)
  where
    claimsConfig = jcxClaims jwtCtx

    -- Run an 'ExceptT' computation, rethrowing its error (mapped through @ef@)
    -- in the underlying monad.
    liftJWTError :: (MonadError e' m) => (e -> e') -> ExceptT e m a -> m a
    liftJWTError ef action = runExceptT action >>= either (throwError . ef) pure

    invalidJWTError e = err400 JWTInvalid $ "Could not verify JWT: " <> tshow e

-- | parse the claims map from the JWT token or custom claims from the JWT config
parseClaimsMap ::
  MonadError QErr m =>
  -- | Unregistered JWT claims
  Jose.ClaimsSet ->
  -- | Claims config
  JWTClaims ->
  -- | Hasura claims and other claims
  m ClaimsMap
parseClaimsMap :: ClaimsSet -> JWTClaims -> m (HashMap SessionVariable Value)
parseClaimsMap ClaimsSet
claimsSet JWTClaims
jcxClaims = do
  let claimsJSON :: Value
claimsJSON = ClaimsSet -> Value
forall a. ToJSON a => a -> Value
J.toJSON ClaimsSet
claimsSet
      unregisteredClaims :: Map Text Value
unregisteredClaims = ClaimsSet
claimsSet ClaimsSet
-> Getting (Map Text Value) ClaimsSet (Map Text Value)
-> Map Text Value
forall s a. s -> Getting a s a -> a
^. Getting (Map Text Value) ClaimsSet (Map Text Value)
Lens' ClaimsSet (Map Text Value)
Jose.unregisteredClaims
  case JWTClaims
jcxClaims of
    -- when the user specifies the namespace of the hasura claims map,
    -- the hasura claims map *must* be specified in the unregistered claims
    JCNamespace JWTNamespace
namespace JWTClaimsFormat
claimsFormat -> do
      Value
claimsV <- (Maybe Value -> m Value -> m Value)
-> m Value -> Maybe Value -> m Value
forall a b c. (a -> b -> c) -> b -> a -> c
flip Maybe Value -> m Value -> m Value
forall (m :: * -> *) a. Applicative m => Maybe a -> m a -> m a
onNothing (JWTNamespace -> m Value
forall (m :: * -> *) a. MonadError QErr m => JWTNamespace -> m a
claimsNotFound JWTNamespace
namespace) (Maybe Value -> m Value) -> Maybe Value -> m Value
forall a b. (a -> b) -> a -> b
$ case JWTNamespace
namespace of
        ClaimNs Text
k -> Text -> Map Text Value -> Maybe Value
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Text
k Map Text Value
unregisteredClaims
        ClaimNsPath JSONPath
path -> IResult Value -> Maybe Value
forall a. IResult a -> Maybe a
iResultToMaybe (IResult Value -> Maybe Value) -> IResult Value -> Maybe Value
forall a b. (a -> b) -> a -> b
$ JSONPath -> Value -> IResult Value
executeJSONPath JSONPath
path (Map Text Value -> Value
forall a. ToJSON a => a -> Value
J.toJSON Map Text Value
unregisteredClaims)
      -- get hasura claims value as an object. parse from string possibly
      Object
claimsObject <- JWTNamespace -> JWTClaimsFormat -> Value -> m Object
forall (m :: * -> *).
MonadError QErr m =>
JWTNamespace -> JWTClaimsFormat -> Value -> m Object
parseObjectFromString JWTNamespace
namespace JWTClaimsFormat
claimsFormat Value
claimsV

      -- filter only x-hasura claims
      let claimsMap :: HashMap SessionVariable Value
claimsMap =
            [(SessionVariable, Value)] -> HashMap SessionVariable Value
forall k v. (Eq k, Hashable k) => [(k, v)] -> HashMap k v
HM.fromList ([(SessionVariable, Value)] -> HashMap SessionVariable Value)
-> [(SessionVariable, Value)] -> HashMap SessionVariable Value
forall a b. (a -> b) -> a -> b
$
              ((Text, Value) -> (SessionVariable, Value))
-> [(Text, Value)] -> [(SessionVariable, Value)]
forall a b. (a -> b) -> [a] -> [b]
map ((Text -> SessionVariable)
-> (Text, Value) -> (SessionVariable, Value)
forall (a :: * -> * -> *) b c d.
Arrow a =>
a b c -> a (b, d) (c, d)
first Text -> SessionVariable
mkSessionVariable) ([(Text, Value)] -> [(SessionVariable, Value)])
-> [(Text, Value)] -> [(SessionVariable, Value)]
forall a b. (a -> b) -> a -> b
$
                ((Text, Value) -> Bool) -> [(Text, Value)] -> [(Text, Value)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Text -> Bool
isSessionVariable (Text -> Bool) -> ((Text, Value) -> Text) -> (Text, Value) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Text, Value) -> Text
forall a b. (a, b) -> a
fst) ([(Text, Value)] -> [(Text, Value)])
-> [(Text, Value)] -> [(Text, Value)]
forall a b. (a -> b) -> a -> b
$
                  (Pair -> (Text, Value)) -> [Pair] -> [(Text, Value)]
forall a b. (a -> b) -> [a] -> [b]
map ((Key -> Text) -> Pair -> (Text, Value)
forall (a :: * -> * -> *) b c d.
Arrow a =>
a b c -> a (b, d) (c, d)
first Key -> Text
K.toText) ([Pair] -> [(Text, Value)]) -> [Pair] -> [(Text, Value)]
forall a b. (a -> b) -> a -> b
$
                    Object -> [Pair]
forall v. KeyMap v -> [(Key, v)]
KM.toList Object
claimsObject

      HashMap SessionVariable Value -> m (HashMap SessionVariable Value)
forall (f :: * -> *) a. Applicative f => a -> f a
pure HashMap SessionVariable Value
claimsMap
    JCMap JWTCustomClaimsMap
claimsConfig -> do
      let JWTCustomClaimsMap JWTCustomClaimsMapDefaultRole
defaultRoleClaimsMap JWTCustomClaimsMapAllowedRoles
allowedRolesClaimsMap CustomClaimsMap
otherClaimsMap = JWTCustomClaimsMap
claimsConfig

      [RoleName]
allowedRoles <- case JWTCustomClaimsMapAllowedRoles
allowedRolesClaimsMap of
        JWTCustomClaimsMapJSONPath JSONPath
allowedRolesJsonPath Maybe [RoleName]
defaultVal ->
          Maybe [RoleName] -> Maybe Value -> m [RoleName]
forall (m :: * -> *) a.
(MonadError QErr m, FromJSON a) =>
Maybe a -> Maybe Value -> m a
parseAllowedRolesClaim Maybe [RoleName]
defaultVal (Maybe Value -> m [RoleName]) -> Maybe Value -> m [RoleName]
forall a b. (a -> b) -> a -> b
$ IResult Value -> Maybe Value
forall a. IResult a -> Maybe a
iResultToMaybe (IResult Value -> Maybe Value) -> IResult Value -> Maybe Value
forall a b. (a -> b) -> a -> b
$ JSONPath -> Value -> IResult Value
executeJSONPath JSONPath
allowedRolesJsonPath Value
claimsJSON
        JWTCustomClaimsMapStatic [RoleName]
staticAllowedRoles -> [RoleName] -> m [RoleName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [RoleName]
staticAllowedRoles

      RoleName
defaultRole <- case JWTCustomClaimsMapDefaultRole
defaultRoleClaimsMap of
        JWTCustomClaimsMapJSONPath JSONPath
defaultRoleJsonPath Maybe RoleName
defaultVal ->
          Maybe RoleName -> Maybe Value -> m RoleName
forall (m :: * -> *) a.
(MonadError QErr m, FromJSON a) =>
Maybe a -> Maybe Value -> m a
parseDefaultRoleClaim Maybe RoleName
defaultVal (Maybe Value -> m RoleName) -> Maybe Value -> m RoleName
forall a b. (a -> b) -> a -> b
$
            IResult Value -> Maybe Value
forall a. IResult a -> Maybe a
iResultToMaybe (IResult Value -> Maybe Value) -> IResult Value -> Maybe Value
forall a b. (a -> b) -> a -> b
$
              JSONPath -> Value -> IResult Value
executeJSONPath JSONPath
defaultRoleJsonPath Value
claimsJSON
        JWTCustomClaimsMapStatic RoleName
staticDefaultRole -> RoleName -> m RoleName
forall (f :: * -> *) a. Applicative f => a -> f a
pure RoleName
staticDefaultRole

      HashMap SessionVariable Value
otherClaims <- ((SessionVariable -> JWTCustomClaimsMapValue -> m Value)
 -> CustomClaimsMap -> m (HashMap SessionVariable Value))
-> CustomClaimsMap
-> (SessionVariable -> JWTCustomClaimsMapValue -> m Value)
-> m (HashMap SessionVariable Value)
forall a b c. (a -> b -> c) -> b -> a -> c
flip (SessionVariable -> JWTCustomClaimsMapValue -> m Value)
-> CustomClaimsMap -> m (HashMap SessionVariable Value)
forall (f :: * -> *) k v1 v2.
Applicative f =>
(k -> v1 -> f v2) -> HashMap k v1 -> f (HashMap k v2)
HM.traverseWithKey CustomClaimsMap
otherClaimsMap ((SessionVariable -> JWTCustomClaimsMapValue -> m Value)
 -> m (HashMap SessionVariable Value))
-> (SessionVariable -> JWTCustomClaimsMapValue -> m Value)
-> m (HashMap SessionVariable Value)
forall a b. (a -> b) -> a -> b
$ \SessionVariable
k JWTCustomClaimsMapValue
claimObj -> do
        let throwClaimErr :: m Value
throwClaimErr =
              Code -> Text -> m Value
forall (m :: * -> *) a. QErrM m => Code -> Text -> m a
throw400 Code
JWTInvalidClaims (Text -> m Value) -> Text -> m Value
forall a b. (a -> b) -> a -> b
$
                Text
"JWT claim from claims_map, "
                  Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> SessionVariable -> Text
sessionVariableToText SessionVariable
k
                  Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
" not found"
        case JWTCustomClaimsMapValue
claimObj of
          JWTCustomClaimsMapJSONPath JSONPath
path Maybe Text
defaultVal ->
            IResult Value -> Maybe Value
forall a. IResult a -> Maybe a
iResultToMaybe (JSONPath -> Value -> IResult Value
executeJSONPath JSONPath
path Value
claimsJSON)
              Maybe Value -> Maybe Value -> Maybe Value
forall (m :: * -> *) a. Applicative m => Maybe a -> m a -> m a
`onNothing` (Text -> Value
J.String (Text -> Value) -> Maybe Text -> Maybe Value
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe Text
defaultVal)
              Maybe Value -> m Value -> m Value
forall (m :: * -> *) a. Applicative m => Maybe a -> m a -> m a
`onNothing` m Value
throwClaimErr
          JWTCustomClaimsMapStatic Text
claimStaticValue -> Value -> m Value
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Value -> m Value) -> Value -> m Value
forall a b. (a -> b) -> a -> b
$ Text -> Value
J.String Text
claimStaticValue

      HashMap SessionVariable Value -> m (HashMap SessionVariable Value)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (HashMap SessionVariable Value
 -> m (HashMap SessionVariable Value))
-> HashMap SessionVariable Value
-> m (HashMap SessionVariable Value)
forall a b. (a -> b) -> a -> b
$
        [(SessionVariable, Value)] -> HashMap SessionVariable Value
forall k v. (Eq k, Hashable k) => [(k, v)] -> HashMap k v
HM.fromList
          [ (SessionVariable
allowedRolesClaim, [RoleName] -> Value
forall a. ToJSON a => a -> Value
J.toJSON [RoleName]
allowedRoles),
            (SessionVariable
defaultRoleClaim, RoleName -> Value
forall a. ToJSON a => a -> Value
J.toJSON RoleName
defaultRole)
          ]
          HashMap SessionVariable Value
-> HashMap SessionVariable Value -> HashMap SessionVariable Value
forall a. Semigroup a => a -> a -> a
<> HashMap SessionVariable Value
otherClaims
  where
    parseAllowedRolesClaim :: Maybe a -> Maybe Value -> m a
parseAllowedRolesClaim Maybe a
defaultVal = \case
      Maybe Value
Nothing ->
        Maybe a -> m a -> m a
forall (m :: * -> *) a. Applicative m => Maybe a -> m a -> m a
onNothing Maybe a
defaultVal (m a -> m a) -> m a -> m a
forall a b. (a -> b) -> a -> b
$
          Code -> Text -> m a
forall (m :: * -> *) a. QErrM m => Code -> Text -> m a
throw400 Code
JWTRoleClaimMissing (Text -> m a) -> Text -> m a
forall a b. (a -> b) -> a -> b
$ Text
"JWT claim does not contain " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> SessionVariable -> Text
sessionVariableToText SessionVariable
allowedRolesClaim
      Just Value
v ->
        Value -> Text -> m a
forall a (m :: * -> *).
(FromJSON a, MonadError QErr m) =>
Value -> Text -> m a
parseJwtClaim Value
v (Text -> m a) -> Text -> m a
forall a b. (a -> b) -> a -> b
$
          Text
"invalid " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> SessionVariable -> Text
sessionVariableToText SessionVariable
allowedRolesClaim
            Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"; should be a list of roles"

    parseDefaultRoleClaim :: Maybe a -> Maybe Value -> m a
parseDefaultRoleClaim Maybe a
defaultVal = \case
      Maybe Value
Nothing ->
        Maybe a -> m a -> m a
forall (m :: * -> *) a. Applicative m => Maybe a -> m a -> m a
onNothing Maybe a
defaultVal (m a -> m a) -> m a -> m a
forall a b. (a -> b) -> a -> b
$
          Code -> Text -> m a
forall (m :: * -> *) a. QErrM m => Code -> Text -> m a
throw400 Code
JWTRoleClaimMissing (Text -> m a) -> Text -> m a
forall a b. (a -> b) -> a -> b
$ Text
"JWT claim does not contain " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> SessionVariable -> Text
sessionVariableToText SessionVariable
defaultRoleClaim
      Just Value
v ->
        Value -> Text -> m a
forall a (m :: * -> *).
(FromJSON a, MonadError QErr m) =>
Value -> Text -> m a
parseJwtClaim Value
v (Text -> m a) -> Text -> m a
forall a b. (a -> b) -> a -> b
$
          Text
"invalid " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> SessionVariable -> Text
sessionVariableToText SessionVariable
defaultRoleClaim
            Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"; should be a role"

    claimsNotFound :: JWTNamespace -> m a
claimsNotFound JWTNamespace
namespace =
      Code -> Text -> m a
forall (m :: * -> *) a. QErrM m => Code -> Text -> m a
throw400 Code
JWTInvalidClaims (Text -> m a) -> Text -> m a
forall a b. (a -> b) -> a -> b
$ case JWTNamespace
namespace of
        ClaimNsPath JSONPath
path ->
          Text
"claims not found at claims_namespace_path: '"
            Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> JSONPath -> Text
encodeJSONPath JSONPath
path
            Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"'"
        ClaimNs Text
ns -> Text
"claims key: '" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
ns Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"' not found"

    parseObjectFromString :: JWTNamespace -> JWTClaimsFormat -> Value -> m Object
parseObjectFromString JWTNamespace
namespace JWTClaimsFormat
claimsFmt Value
jVal =
      case (JWTClaimsFormat
claimsFmt, Value
jVal) of
        (JWTClaimsFormat
JCFStringifiedJson, J.String Text
v) ->
          Either String Object -> (String -> m Object) -> m Object
forall (m :: * -> *) e a.
Applicative m =>
Either e a -> (e -> m a) -> m a
onLeft (ByteString -> Either String Object
forall a. FromJSON a => ByteString -> Either String a
J.eitherDecodeStrict (ByteString -> Either String Object)
-> ByteString -> Either String Object
forall a b. (a -> b) -> a -> b
$ Text -> ByteString
T.encodeUtf8 Text
v) (m Object -> String -> m Object
forall a b. a -> b -> a
const (m Object -> String -> m Object) -> m Object -> String -> m Object
forall a b. (a -> b) -> a -> b
$ Text -> m Object
forall a. Text -> m a
claimsErr (Text -> m Object) -> Text -> m Object
forall a b. (a -> b) -> a -> b
$ Text -> Text
strngfyErr Text
v)
        (JWTClaimsFormat
JCFStringifiedJson, Value
_) ->
          Text -> m Object
forall a. Text -> m a
claimsErr Text
"expecting a string when claims_format is stringified_json"
        (JWTClaimsFormat
JCFJson, J.Object Object
o) -> Object -> m Object
forall (m :: * -> *) a. Monad m => a -> m a
return Object
o
        (JWTClaimsFormat
JCFJson, Value
_) ->
          Text -> m Object
forall a. Text -> m a
claimsErr Text
"expecting a json object when claims_format is json"
      where
        strngfyErr :: Text -> Text
strngfyErr Text
v =
          let claimsLocation :: Text
claimsLocation = case JWTNamespace
namespace of
                ClaimNsPath JSONPath
path -> Text
"claims_namespace_path " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> JSONPath -> Text
encodeJSONPath JSONPath
path
                ClaimNs Text
ns -> Text
"claims_namespace " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
ns
           in Text
"expecting stringified json at: '"
                Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
claimsLocation
                Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"', but found: "
                Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
v

        claimsErr :: Text -> m a
claimsErr = Code -> Text -> m a
forall (m :: * -> *) a. QErrM m => Code -> Text -> m a
throw400 Code
JWTInvalidClaims

-- | Verify a raw JWT against the JWK set currently stored in the context.
--
-- The key set is read from the context's 'IORef' first. The compact
-- serialisation is then decoded, and signature plus claims are checked at the
-- current wall-clock time. Validation settings are derived from the context's
-- audience, issuer and allowed-skew configuration.
verifyJwt ::
  ( MonadError Jose.JWTError m,
    MonadIO m
  ) =>
  JWTCtx ->
  RawJWT ->
  m Jose.ClaimsSet
verifyJwt ctx (RawJWT rawJWT) = do
  jwkSet <- liftIO (readIORef (jcxKey ctx))
  signedJwt <- Jose.decodeCompact rawJWT
  now <- liftIO getCurrentTime
  Jose.verifyClaimsAt settings jwkSet now signedJwt
  where
    -- Audience predicate plus, when configured, an explicit clock skew.
    -- Without a configured skew, jose's default (an allowed skew of 0) is kept.
    baseSettings =
      maybe
        defaults
        (\skew -> defaults & Jose.allowedSkew .~ skew)
        (jcxAllowedSkew ctx)
      where
        defaults = Jose.defaultJWTValidationSettings audienceMatches

    -- The issuer is only constrained when one is configured.
    settings =
      maybe
        baseSettings
        (\issuer -> baseSettings & Jose.issuerPredicate .~ (== issuer))
        (jcxIssuer ctx)

    -- Don't perform the check if there are no audiences in the config.
    audienceMatches aud =
      case jcxAudience ctx of
        Nothing -> True
        Just (Jose.Audience allowed) -> aud `elem` allowed

-- | Serialise a 'JWTConfig' back to JSON.
--
-- Inline key material is never emitted: both @type@ and @key@ are replaced
-- by redaction placeholders. A @jwk_url@ is echoed as-is. @audience@,
-- @issuer@ and @header@ are always present (possibly @null@), while
-- @allowed_skew@ only appears when a skew is configured.
instance J.ToJSON JWTConfig where
  toJSON (JWTConfig keyOrUrl aud iss claims allowedSkew jwtHeader) =
    J.object $
      concat
        [ keySourcePairs,
          [ "audience" J..= aud,
            "issuer" J..= iss,
            "header" J..= jwtHeader
          ],
          claimPairs,
          skewPairs
        ]
    where
      -- Where the verification key comes from; inline keys are redacted.
      keySourcePairs =
        either (const redactedKey) (\url -> ["jwk_url" J..= url]) keyOrUrl

      redactedKey =
        [ "type" J..= J.String "<TYPE REDACTED>",
          "key" J..= J.String "<JWK REDACTED>"
        ]

      -- Either the explicit claims map, or the namespace plus its format.
      claimPairs = case claims of
        JCMap claimsMap -> ["claims_map" J..= claimsMap]
        JCNamespace namespace claimsFormat ->
          namespacePair namespace : ["claims_format" J..= claimsFormat]

      namespacePair (ClaimNsPath nsPath) =
        "claims_namespace_path" J..= encodeJSONPath nsPath
      namespacePair (ClaimNs ns) =
        "claims_namespace" J..= J.String ns

      -- Empty when no skew is configured.
      skewPairs = foldMap (\skew -> ["allowed_skew" J..= skew]) allowedSkew

-- | Parse a 'JWTConfig' from a JSON object, for example:
--
-- > {"type": "RS256", "key": "<PEM-encoded-public-key-or-X509-cert>"}
--
-- Accepted fields:
--
-- * exactly one of @key@ or @jwk_url@. A @key@ additionally requires
--   @type@, one of @HS256@, @HS384@, @HS512@, @RS256@, @RS384@, @RS512@ or
--   @Ed25519@;
-- * at most one of @claims_namespace@ and @claims_namespace_path@. When
--   neither is given, 'defaultClaimsNamespace' is used;
-- * optional @claims_format@ (defaults to 'defaultClaimsFormat'),
--   @claims_map@, @audience@, @issuer@, @allowed_skew@ and @header@.
--
-- When @claims_map@ is present it is used instead of the namespace and
-- @claims_format@ settings. Those settings are still parsed and validated.
instance J.FromJSON JWTConfig where
  parseJSON = J.withObject "JWTConfig" $ \o -> do
    mRawKey <- o J..:? "key"
    claimsNs <- o J..:? "claims_namespace"
    claimsNsPath <- o J..:? "claims_namespace_path"
    aud <- o J..:? "audience"
    iss <- o J..:? "issuer"
    jwkUrl <- o J..:? "jwk_url"
    claimsFormat <- o J..:? "claims_format" J..!= defaultClaimsFormat
    claimsMap <- o J..:? "claims_map"
    allowedSkew <- o J..:? "allowed_skew"
    jwtHeader <- o J..:? "header"

    -- The two namespace settings are mutually exclusive.
    hasuraClaimsNs <-
      case (claimsNsPath, claimsNs) of
        (Nothing, Nothing) -> pure $ ClaimNs defaultClaimsNamespace
        (Just nsPath, Nothing) ->
          either failJSONPathParsing (pure . ClaimNsPath) . parseJSONPath $ nsPath
        (Nothing, Just ns) -> pure $ ClaimNs ns
        (Just _, Just _) -> fail "claims_namespace and claims_namespace_path both cannot be set"

    -- Exactly one key source must be configured.
    keyOrUrl <- case (mRawKey, jwkUrl) of
      (Nothing, Nothing) -> fail "key and jwk_url both cannot be empty"
      (Just _, Just _) -> fail "key, jwk_url both cannot be present"
      (Just rawKey, Nothing) -> do
        -- "type" is only required (and only read) when an inline key is given.
        keyType <- o J..: "type"
        key <- parseKey keyType rawKey
        pure $ Left key
      (Nothing, Just url) -> pure $ Right url

    -- A claims_map, when given, takes precedence over the namespace settings.
    let jwtClaims = maybe (JCNamespace hasuraClaimsNs claimsFormat) JCMap claimsMap

    pure $ JWTConfig keyOrUrl aud iss jwtClaims allowedSkew jwtHeader
    where
      -- Build a JWK from the raw "key" text according to the declared "type".
      parseKey keyType rawKey =
        case keyType of
          "HS256" -> runEither $ parseHmacKey rawKey 256
          "HS384" -> runEither $ parseHmacKey rawKey 384
          "HS512" -> runEither $ parseHmacKey rawKey 512
          "RS256" -> runEither $ parseRsaKey rawKey
          "RS384" -> runEither $ parseRsaKey rawKey
          "RS512" -> runEither $ parseRsaKey rawKey
          "Ed25519" -> runEither $ parseEdDSAKey rawKey
          -- TODO(from master): support ES256, ES384, ES512, PS256, PS384, Ed448 (JOSE doesn't support it as of now)
          _ -> invalidJwk ("Key type: " <> T.unpack keyType <> " is not supported")

      -- Lift a key-parsing result into the parser, reporting errors as an invalid JWK.
      runEither = either (invalidJwk . T.unpack) return

      invalidJwk :: (MonadFail m) => String -> m a
      invalidJwk msg = fail ("Invalid JWK: " <> msg)

      failJSONPathParsing :: (MonadFail m) => Text -> m a
      failJSONPathParsing err =
        fail . T.unpack $ "invalid JSON path claims_namespace_path error: " <> err

-- | Extract x-hasura-allowed-roles and x-hasura-default-role from JWT claims.
--
-- Both claims are mandatory, and the allowed roles are looked up first.
-- A missing claim raises a 400 'JWTRoleClaimMissing' error. A present claim
-- is decoded with 'parseJwtClaim', using a message of the form
-- @"invalid <claim>; <hint>"@.
parseHasuraClaims :: forall m. (MonadError QErr m) => ClaimsMap -> m HasuraClaims
parseHasuraClaims claimsMap = do
  allowedRoles <- lookupAndParse allowedRolesClaim "should be a list of roles"
  defaultRole <- lookupAndParse defaultRoleClaim "should be a single role name"
  pure (HasuraClaims allowedRoles defaultRole)
  where
    lookupAndParse :: J.FromJSON a => SessionVariable -> Text -> m a
    lookupAndParse claim hint =
      case HM.lookup claim claimsMap of
        Nothing ->
          throw400 JWTRoleClaimMissing $ "JWT claim does not contain " <> claimName
        Just claimValue ->
          parseJwtClaim claimValue $ "invalid " <> claimName <> "; " <> hint
      where
        claimName = sessionVariableToText claim

-- | Decode a JSON claim value into the requested type.
--
-- If decoding fails, throws a 400 'JWTInvalidClaims' error. Its message is
-- the given prefix, then @": "@, then aeson's own decoding error.
parseJwtClaim :: (J.FromJSON a, MonadError QErr m) => J.Value -> Text -> m a
parseJwtClaim claimValue errMsg =
  case J.fromJSON claimValue of
    J.Error decodeErr ->
      throw400 JWTInvalidClaims $ errMsg <> ": " <> T.pack decodeErr
    J.Success decoded ->
      pure decoded