module Hasura.Server.Middleware
( corsMiddleware,
)
where
import Control.Applicative
import Data.ByteString qualified as B
import Data.CaseInsensitive qualified as CI
import Data.Text.Encoding qualified as TE
import Hasura.Prelude
import Hasura.Server.Cors
import Hasura.Server.Utils
import Network.HTTP.Types qualified as HTTP
import Network.Wai
-- | WAI middleware enforcing the server's CORS policy.
--
-- Requests without an @Origin@ header are handed to the wrapped application
-- unchanged. When an @Origin@ is present, 'cpConfig' decides what happens:
--
-- * 'CCDisabled' — the request is forwarded untouched (no CORS headers).
-- * 'CCAllowAll' — every origin is accepted.
-- * 'CCAllowedOrigins' — the origin is accepted if it is listed in
--   'dmFqdns', or if 'inWildcardList' matches it; otherwise the request is
--   forwarded untouched.
--
-- For an accepted origin:
--
-- * An @OPTIONS@ request is answered by the middleware itself with an empty
--   @204 No Content@ preflight reply. The application is not called.
-- * Any other method is forwarded to the application, and the CORS headers
--   are prepended to its response.
--
-- The accepted origin is echoed back verbatim in
-- @Access-Control-Allow-Origin@, along with
-- @Access-Control-Allow-Credentials: true@.
corsMiddleware :: CorsPolicy -> Middleware
corsMiddleware policy app req sendResp =
  maybe (app req sendResp) handleCors $
    getRequestHeader "Origin" (requestHeaders req)
  where
    -- Decide, for a request carrying an @Origin@, whether CORS applies.
    handleCors :: B.ByteString -> IO ResponseReceived
    handleCors origin = case cpConfig policy of
      CCDisabled _ -> app req sendResp
      CCAllowAll -> sendCors origin
      CCAllowedOrigins ds
        | bsToTxt origin `elem` dmFqdns ds -> sendCors origin
        | inWildcardList ds (bsToTxt origin) -> sendCors origin
        | otherwise -> app req sendResp

    -- Origin accepted: short-circuit preflights, decorate everything else.
    sendCors :: B.ByteString -> IO ResponseReceived
    sendCors origin = case requestMethod req of
      "OPTIONS" -> sendResp $ respondPreFlight origin
      _ -> app req $ sendResp . injectCorsHeaders origin

    -- Empty 204 reply carrying both the preflight and the general CORS headers.
    respondPreFlight :: B.ByteString -> Response
    respondPreFlight origin =
      setHeaders (mkPreFlightHeaders requestedHeaders) $
        injectCorsHeaders origin emptyResponse

    emptyResponse :: Response
    emptyResponse = responseLBS HTTP.status204 [] ""

    -- The headers the client asked permission for are echoed back unchanged
    -- (empty when the client sent no Access-Control-Request-Headers).
    requestedHeaders :: B.ByteString
    requestedHeaders =
      fromMaybe "" $
        getRequestHeader "Access-Control-Request-Headers" (requestHeaders req)

    injectCorsHeaders :: B.ByteString -> Response -> Response
    injectCorsHeaders origin = setHeaders (mkCorsHeaders origin)

    -- Headers that only appear on a preflight reply. No Content-Length is
    -- sent: RFC 9110 §8.6 forbids it on 204 responses, and the body is empty.
    -- The Content-Type parameter must be separated from the media type by ';'.
    mkPreFlightHeaders :: B.ByteString -> [(B.ByteString, B.ByteString)]
    mkPreFlightHeaders allowReqHdrs =
      [ ("Access-Control-Max-Age", "1728000"),
        ("Access-Control-Allow-Headers", allowReqHdrs),
        ("Content-Type", "text/plain; charset=UTF-8")
      ]

    -- Headers added to every response for an accepted origin.
    mkCorsHeaders :: B.ByteString -> [(B.ByteString, B.ByteString)]
    mkCorsHeaders origin =
      [ ("Access-Control-Allow-Origin", origin),
        ("Access-Control-Allow-Credentials", "true"),
        ( "Access-Control-Allow-Methods",
          B.intercalate "," $ TE.encodeUtf8 <$> cpMethods policy
        )
      ]

    -- Prepend the given headers to whatever the response already carries.
    setHeaders :: [(B.ByteString, B.ByteString)] -> Response -> Response
    setHeaders hdrs = mapResponseHeaders (mkRespHdrs hdrs ++)

    -- Header names are case-insensitive on the wire.
    mkRespHdrs :: [(B.ByteString, B.ByteString)] -> [HTTP.Header]
    mkRespHdrs = map (\(k, v) -> (CI.mk k, v))