module Hasura.Server.Middleware
  ( corsMiddleware,
  )
where

import Control.Applicative
import Data.ByteString qualified as B
import Data.CaseInsensitive qualified as CI
import Data.Text.Encoding qualified as TE
import Hasura.Prelude
import Hasura.Server.Cors
import Hasura.Server.Utils
import Network.HTTP.Types qualified as HTTP
import Network.Wai

-- | WAI middleware that applies the server's CORS policy to each request.
--
-- The policy is obtained by running the supplied action once per request.
-- Behaviour, in order:
--
-- * A request without an @Origin@ header is handed to the wrapped
--   application unchanged.
-- * When CORS is disabled in the policy, or the origin is not accepted by
--   the configured domain list, the request is also passed through
--   unchanged (no CORS headers are added).
-- * For an accepted origin, an @OPTIONS@ request is answered directly with
--   an empty @204@ pre-flight response (the wrapped application is not
--   called); any other method is forwarded to the application and the CORS
--   headers are prepended to its response.
corsMiddleware :: IO CorsPolicy -> Middleware
corsMiddleware getPolicy app req sendResp = do
  policy <- getPolicy
  case getRequestHeader "Origin" (requestHeaders req) of
    Nothing -> passThrough
    Just origin -> handleCors policy origin
  where
    -- Run the wrapped application without touching its response.
    passThrough :: IO ResponseReceived
    passThrough = app req sendResp

    -- Decide, from the policy configuration, whether the given origin gets
    -- CORS treatment.
    handleCors :: CorsPolicy -> B.ByteString -> IO ResponseReceived
    handleCors policy origin = case cpConfig policy of
      CCDisabled _ -> passThrough
      CCAllowAll -> sendCors origin policy
      CCAllowedOrigins ds
        -- if the origin is in our cors domains, send cors headers
        | originTxt `elem` dmFqdns ds -> sendCors origin policy
        -- if current origin is part of wildcard domain list, send cors
        | inWildcardList ds originTxt -> sendCors origin policy
        -- otherwise don't send cors headers
        | otherwise -> passThrough
      where
        originTxt = bsToTxt origin

    -- Answer a pre-flight request ourselves, or forward the request and
    -- decorate the application's response with the CORS headers.
    sendCors :: B.ByteString -> CorsPolicy -> IO ResponseReceived
    sendCors origin policy =
      case requestMethod req of
        "OPTIONS" -> sendResp (respondPreFlight origin policy)
        _ -> app req (sendResp . injectCorsHeaders origin policy)

    -- Empty 204 response carrying the pre-flight headers followed by the
    -- regular CORS headers.
    respondPreFlight :: B.ByteString -> CorsPolicy -> Response
    respondPreFlight origin policy =
      setHeaders (mkPreFlightHeaders requestedHeaders) $
        injectCorsHeaders origin policy emptyResponse

    emptyResponse :: Response
    emptyResponse = responseLBS HTTP.status204 [] ""

    -- Value of the request's @Access-Control-Request-Headers@ header, or the
    -- empty string when the header is absent; echoed back in the pre-flight
    -- response.
    requestedHeaders :: B.ByteString
    requestedHeaders =
      fromMaybe "" $
        getRequestHeader "Access-Control-Request-Headers" (requestHeaders req)

    injectCorsHeaders :: B.ByteString -> CorsPolicy -> Response -> Response
    injectCorsHeaders origin policy = setHeaders (mkCorsHeaders origin policy)

    -- Headers that are specific to the pre-flight response.
    mkPreFlightHeaders :: B.ByteString -> [(B.ByteString, B.ByteString)]
    mkPreFlightHeaders allowReqHdrs =
      [ ("Access-Control-Max-Age", "1728000"),
        ("Access-Control-Allow-Headers", allowReqHdrs),
        ("Content-Length", "0"),
        -- media-type parameters must be separated from the type by ';'
        ("Content-Type", "text/plain; charset=UTF-8")
      ]

    -- Headers added to every response sent to an accepted origin.
    mkCorsHeaders :: B.ByteString -> CorsPolicy -> [(B.ByteString, B.ByteString)]
    mkCorsHeaders origin policy =
      [ ("Access-Control-Allow-Origin", origin),
        ("Access-Control-Allow-Credentials", "true"),
        ("Access-Control-Allow-Methods", commaSeparated (cpMethods policy)),
        -- console requires this header to access the cache headers as HGE and console
        -- are hosted on different domains in production
        ("Access-Control-Expose-Headers", commaSeparated cacheExposedHeaders)
      ]

    -- UTF-8 encode each item and join the results with commas.
    commaSeparated = B.intercalate "," . map TE.encodeUtf8

    cacheExposedHeaders =
      [ "X-Hasura-Query-Cache-Key",
        "X-Hasura-Query-Family-Cache-Key",
        "Warning"
      ]

    -- Prepend the given headers to whatever headers the response already
    -- has; header names are wrapped with 'CI.mk' to obtain case-insensitive
    -- header names.
    setHeaders :: [(B.ByteString, B.ByteString)] -> Response -> Response
    setHeaders hdrs = mapResponseHeaders (newHdrs ++)
      where
        newHdrs = [(CI.mk k, v) | (k, v) <- hdrs]