module Hasura.Server.Middleware
  ( corsMiddleware,
  )
where

import Control.Applicative
import Data.ByteString qualified as B
import Data.CaseInsensitive qualified as CI
import Data.Text.Encoding qualified as TE
import Hasura.Prelude
import Hasura.Server.Cors
import Hasura.Server.Utils
import Network.HTTP.Types qualified as HTTP
import Network.Wai

-- | WAI middleware that applies a 'CorsPolicy' to incoming requests.
--
-- Behaviour, in order:
--
-- * A request without an @Origin@ header is passed to the wrapped
--   application unchanged.
-- * When the policy's config is 'CCDisabled', or the origin matches neither
--   the exact domain set nor the wildcard list of 'CCAllowedOrigins', the
--   request is passed through unchanged (no CORS headers are added).
-- * Otherwise (config is 'CCAllowAll' or the origin is allowed) an @OPTIONS@
--   request is answered directly with a 204 preflight response, and any other
--   method is forwarded to the application with the CORS headers prepended
--   to its response.
corsMiddleware :: CorsPolicy -> Middleware
corsMiddleware policy app req sendResp =
  maybe passThrough handleCors $
    getRequestHeader "Origin" (requestHeaders req)
  where
    -- Run the wrapped application without touching request or response.
    passThrough :: IO ResponseReceived
    passThrough = app req sendResp

    -- Decide, for a request carrying the given @Origin@ value, whether CORS
    -- headers should be sent.
    handleCors :: B.ByteString -> IO ResponseReceived
    handleCors origin = case cpConfig policy of
      CCDisabled _ -> passThrough
      CCAllowAll -> sendCors origin
      CCAllowedOrigins ds
        -- if the origin is in our cors domains, send cors headers
        | originTxt `elem` dmFqdns ds -> sendCors origin
        -- if current origin is part of wildcard domain list, send cors
        | inWildcardList ds originTxt -> sendCors origin
        -- otherwise don't send cors headers
        | otherwise -> passThrough
      where
        -- Converted once; both guards above compare against the same text.
        originTxt = bsToTxt origin

    -- Preflight (@OPTIONS@) requests are answered here and never reach the
    -- application; every other method runs the application and has the CORS
    -- headers added to whatever response it produces.
    sendCors :: B.ByteString -> IO ResponseReceived
    sendCors origin =
      case requestMethod req of
        "OPTIONS" -> sendResp $ respondPreFlight origin
        _ -> app req $ sendResp . injectCorsHeaders origin

    -- The complete preflight response: an empty 204 carrying the CORS headers
    -- with the preflight-specific headers placed in front of them.
    respondPreFlight :: B.ByteString -> Response
    respondPreFlight origin =
      setHeaders (mkPreFlightHeaders requestedHeaders) $
        injectCorsHeaders origin emptyResponse

    -- A 204 response with no headers and an empty body.
    emptyResponse :: Response
    emptyResponse = responseLBS HTTP.status204 [] ""

    -- Value of the request's @Access-Control-Request-Headers@ header, or the
    -- empty string when the header is absent.
    requestedHeaders :: B.ByteString
    requestedHeaders =
      fromMaybe "" $
        getRequestHeader "Access-Control-Request-Headers" $
          requestHeaders req

    -- Prepend the CORS headers for the given origin to a response.
    injectCorsHeaders :: B.ByteString -> Response -> Response
    injectCorsHeaders origin = setHeaders (mkCorsHeaders origin)

    -- Headers specific to a preflight response. The argument is echoed back
    -- as the value of @Access-Control-Allow-Headers@.
    mkPreFlightHeaders :: B.ByteString -> [(B.ByteString, B.ByteString)]
    mkPreFlightHeaders allowReqHdrs =
      [ ("Access-Control-Max-Age", "1728000"),
        ("Access-Control-Allow-Headers", allowReqHdrs),
        ("Content-Length", "0"),
        -- A media type and its parameters are separated by ';'.
        ("Content-Type", "text/plain; charset=UTF-8")
      ]

    -- Headers sent on every CORS-enabled response: the request's origin is
    -- echoed back, credentials are allowed, and the allowed methods are the
    -- policy's 'cpMethods' joined with commas.
    mkCorsHeaders :: B.ByteString -> [(B.ByteString, B.ByteString)]
    mkCorsHeaders origin =
      [ ("Access-Control-Allow-Origin", origin),
        ("Access-Control-Allow-Credentials", "true"),
        ( "Access-Control-Allow-Methods",
          B.intercalate "," $ TE.encodeUtf8 <$> cpMethods policy
        )
      ]

    -- Prepend the given headers to the response's existing header list.
    -- Existing headers are kept; nothing is replaced or de-duplicated.
    setHeaders :: [(B.ByteString, B.ByteString)] -> Response -> Response
    setHeaders hdrs = mapResponseHeaders (mkRespHdrs hdrs ++)

    -- Turn raw header names into case-insensitive header names.
    mkRespHdrs :: [(B.ByteString, B.ByteString)] -> [HTTP.Header]
    mkRespHdrs = map (\(k, v) -> (CI.mk k, v))