{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TemplateHaskell #-}
module Hasura.Server.Cors
( CorsConfig (..),
CorsPolicy (..),
parseOrigin,
readCorsDomains,
mkDefaultCorsPolicy,
isCorsDisabled,
Domains (..),
inWildcardList,
)
where
import Control.Applicative (optional)
import Data.Aeson ((.:))
import Data.Aeson qualified as J
import Data.Aeson.TH qualified as J
import Data.Attoparsec.Text qualified as AT
import Data.Char qualified as C
import Data.HashSet qualified as Set
import Data.Text qualified as T
import Hasura.Prelude
import Hasura.Server.Utils (fmapL)
-- | The components of an origin: URI scheme (including its @"://"@
-- separator), host, and optional port.
--
-- For a configured wildcard entry such as @http://*.example.com:8080@ the
-- host holds only the part after the @*.@ prefix.
data DomainParts = DomainParts
  { -- | Scheme together with its separator, e.g. @"https://"@.
    wdScheme :: !Text,
    -- | Host name; for wildcard entries, the portion after @"*."@.
    wdHost :: !Text,
    -- | Explicit port, if one was given.
    wdPort :: !(Maybe Int)
  }
  deriving (Show, Eq, Generic, Hashable)

$(J.deriveJSON hasuraJSON ''DomainParts)
-- | The set of allowed origins, split into exact entries and wildcard
-- entries.
data Domains = Domains
  { -- | Exact (non-wildcard) origins, as text.
    dmFqdns :: !(Set.HashSet Text),
    -- | Wildcard origins (configured as @scheme://*.host[:port]@), with the
    -- @*.@ prefix stripped from the host.
    dmWildcards :: !(Set.HashSet DomainParts)
  }
  deriving (Show, Eq)

$(J.deriveJSON hasuraJSON ''Domains)
-- | How the server treats cross-origin requests.
data CorsConfig
  = -- | Every origin is allowed (configured as the literal @"*"@).
    CCAllowAll
  | -- | Only the listed exact and wildcard origins are allowed.
    CCAllowedOrigins Domains
  | -- | CORS handling is disabled. The flag is serialised under the
    -- @ws_read_cookie@ key; presumably it controls whether websocket
    -- connections read the cookie (NOTE(review): confirm at the use site).
    CCDisabled Bool
  deriving (Show, Eq)
-- | Every configuration is encoded as an object with the keys @disabled@,
-- @ws_read_cookie@ and @allowed_origins@. Keys that do not apply to a
-- constructor are emitted as @null@: @allowed_origins@ when disabled, and
-- @ws_read_cookie@ otherwise.
instance J.ToJSON CorsConfig where
  toJSON = \case
    CCDisabled readCookie -> encode True J.Null (Just readCookie)
    CCAllowAll -> encode False (J.String "*") Nothing
    CCAllowedOrigins domains -> encode False (J.toJSON domains) Nothing
    where
      encode :: Bool -> J.Value -> Maybe Bool -> J.Value
      encode disabled origins wsReadCookie =
        J.object
          [ "disabled" J..= disabled,
            "ws_read_cookie" J..= wsReadCookie,
            "allowed_origins" J..= origins
          ]
-- | Decodes the object produced by the 'J.ToJSON' instance.
--
-- When @disabled@ is true, only @ws_read_cookie@ is read. Otherwise
-- @allowed_origins@ is read; it must be either the string @"*"@ (allow every
-- origin) or a value that decodes as 'Domains'.
instance J.FromJSON CorsConfig where
  parseJSON = J.withObject "cors config" \obj -> do
    isDisabled <- obj .: "disabled"
    if isDisabled
      then CCDisabled <$> obj .: "ws_read_cookie"
      else obj .: "allowed_origins" >>= parseOrigins
    where
      -- A bare string is only accepted when it is "*"; anything else falls
      -- through to decoding an explicit 'Domains' value.
      parseOrigins originsVal =
        J.withText "origins" allowAllOnly originsVal
          <|> CCAllowedOrigins <$> J.parseJSON originsVal

      allowAllOnly txt
        | txt == "*" = pure CCAllowAll
        | otherwise = fail "unexpected string"
-- | 'True' exactly for the 'CCDisabled' constructor, regardless of its flag.
isCorsDisabled :: CorsConfig -> Bool
isCorsDisabled (CCDisabled _) = True
isCorsDisabled _ = False
-- | Parse a comma-separated list of origins into a 'CorsConfig'.
--
-- The exact string @"*"@ (no surrounding whitespace) means "allow every
-- origin". Otherwise each comma-separated entry is stripped of surrounding
-- whitespace and parsed as either a wildcard origin
-- (@scheme://*.host[:port]@) or an exact origin. Parsing stops at the first
-- entry that is neither, returning its error message in 'Left'.
readCorsDomains :: String -> Either String CorsConfig
readCorsDomains "*" = Right CCAllowAll
readCorsDomains str = do
  parsed <- traverse parseOptWildcardDomain entries
  let exact = Set.fromList (lefts parsed)
      wildcard = Set.fromList (rights parsed)
  Right (CCAllowedOrigins (Domains exact wildcard))
  where
    entries = T.strip <$> T.splitOn "," (T.pack str)
-- | A 'CorsConfig' bundled with the allowed methods and a max-age value.
-- Presumably these feed the @Access-Control-Allow-Methods@ and
-- @Access-Control-Max-Age@ response headers (NOTE(review): confirm at the
-- use site).
data CorsPolicy = CorsPolicy
  { -- | Which origins are allowed.
    cpConfig :: !CorsConfig,
    -- | HTTP methods to allow.
    cpMethods :: ![Text],
    -- | Max-age value; units are not fixed here (seconds if it is used for
    -- @Access-Control-Max-Age@ — TODO confirm).
    cpMaxAge :: !Int
  }
  deriving (Show, Eq)
-- | The default policy for a given 'CorsConfig'. It allows GET, POST, PUT,
-- PATCH, DELETE and OPTIONS, with a max age of 1728000 (20 days, if the value
-- is interpreted as seconds).
mkDefaultCorsPolicy :: CorsConfig -> CorsPolicy
mkDefaultCorsPolicy cfg = CorsPolicy cfg defaultMethods defaultMaxAge
  where
    defaultMethods :: [Text]
    defaultMethods = ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"]

    defaultMaxAge :: Int
    defaultMaxAge = 1728000
-- | Whether an origin matches one of the wildcard entries of 'Domains'.
--
-- The origin is parsed with 'parseOrigin', which drops its first host label,
-- and the result is looked up in 'dmWildcards'. Exact entries ('dmFqdns')
-- are not consulted, and an origin that fails to parse never matches.
inWildcardList :: Domains -> Text -> Bool
inWildcardList (Domains _ wildcards) origin =
  case parseOrigin origin of
    Left _ -> False
    Right parts -> parts `Set.member` wildcards
-- | Run an attoparsec parser on a 'Text', returning the parser's error
-- message on failure.
--
-- This is 'AT.parseOnly', so input left unconsumed after the parser succeeds
-- is /not/ treated as an error.
runParser :: AT.Parser a -> Text -> Either String a
runParser = AT.parseOnly
-- | Parse a request origin into the shape used for wildcard matching.
--
-- The first host label (everything up to the first @.@ after the scheme) is
-- discarded. For example, @http://app.example.com@ yields scheme
-- @"http://"@, host @"example.com"@ and no port.
parseOrigin :: Text -> Either String DomainParts
parseOrigin = runParser originParser
-- | Origin parser used by 'parseOrigin'. It reads the scheme, drops the
-- leading host label, then reads the host and optional port. It fails when
-- no @.@ follows the scheme.
originParser :: AT.Parser DomainParts
originParser = domainParser (Just dropFirstLabel)
  where
    -- Consumes everything up to and including the first '.'. The label is
    -- returned, but domainParser discards it.
    dropFirstLabel :: AT.Parser Text
    dropFirstLabel = AT.takeTill (== '.') <* AT.char '.'
-- | Parse one configured origin entry.
--
-- A wildcard entry (@scheme://*.host[:port]@) yields @Right@ its parts.
-- Otherwise an exact entry (@scheme://host[:port]@) yields @Left@ its
-- re-assembled text. On failure, the parser's own message is replaced by a
-- user-facing one that names the offending entry.
parseOptWildcardDomain :: Text -> Either String (Either Text DomainParts)
parseOptWildcardDomain d =
  fmapL (const errMsg) (runParser (wildcard <|> exact) d)
  where
    wildcard :: AT.Parser (Either Text DomainParts)
    wildcard = Right <$> wildcardDomainParser

    exact :: AT.Parser (Either Text DomainParts)
    exact = Left <$> fqdnParser

    errMsg :: String
    errMsg = "invalid domain: '" <> T.unpack d <> "'. " <> helpMsg

    helpMsg :: String
    helpMsg =
      "All domains should have scheme + (optional wildcard) host + "
        <> "(optional port)"
-- | Parser for a wildcard origin @scheme://*.host[:port]@. The @*.@ prefix
-- is required and is not included in the resulting 'wdHost'.
wildcardDomainParser :: AT.Parser DomainParts
wildcardDomainParser = domainParser (Just wildcardPrefix)
  where
    wildcardPrefix :: AT.Parser Text
    wildcardPrefix = do
      void (AT.string "*")
      AT.string "."
-- | Parser for an exact origin. It returns the origin re-assembled as
-- @scheme://host@ or @scheme://host:port@, so the result can differ from the
-- input text (for example, leading zeros in the port are dropped).
fqdnParser :: AT.Parser Text
fqdnParser = render <$> domainParser Nothing
  where
    render :: DomainParts -> Text
    render (DomainParts scheme host port) =
      scheme <> host <> maybe "" ((":" <>) . tshow) port
-- | Parse @scheme://[prefix]host[:port]@ into 'DomainParts'.
--
-- The optional @prefix@ parser runs right after the scheme, and its result
-- is discarded. Callers use it to skip a wildcard marker or a leading host
-- label.
--
-- The host must be non-empty. When no @:@ follows, the host is simply the
-- rest of the input (a path, if present, is included). Anything after a
-- port number is left unconsumed, and 'runParser' does not reject leftover
-- input.
domainParser :: Maybe (AT.Parser Text) -> AT.Parser DomainParts
domainParser parser = do
  scheme <- schemeParser
  -- Run the prefix parser, if any, purely for its input consumption.
  forM_ parser void
  host <- hostPortParser
  port <- optional portParser
  return $ DomainParts scheme host port
  where
    schemeParser :: AT.Parser Text
    schemeParser = do
      -- Accept any scheme made of alphanumerics and '+', '.', '-' (not only
      -- http/https); the "://" separator is kept in the result.
      scheme <- AT.takeWhile1 (\x -> C.isAlphaNum x || elem x ['+', '.', '-'])
      sep <- AT.string "://"
      return $ scheme <> sep

    -- A host followed by ':' (consumed, so that 'portParser' can read the
    -- port), or, failing that, a host with no ':' after it.
    hostPortParser :: AT.Parser Text
    hostPortParser = hostWithPortParser <|> hostOnlyParser

    hostWithPortParser :: AT.Parser Text
    hostWithPortParser = do
      h <- AT.takeWhile1 (/= ':')
      void $ AT.char ':'
      return h

    -- takeWhile1 rather than taking the remaining text unconditionally, so
    -- that an empty host (e.g. "http://" or "http://:8080") is rejected
    -- instead of being accepted as "" or ":8080". When hostWithPortParser
    -- has failed on a non-empty remainder with no ':', this consumes all of
    -- it, exactly as before.
    hostOnlyParser :: AT.Parser Text
    hostOnlyParser = AT.takeWhile1 (/= ':')

    portParser :: AT.Parser Int
    portParser = AT.decimal