module Network.HTTP.Client.DynamicTlsPermissions
  ( dynamicTlsSettings,
  )
where

import Control.Exception.Safe (Exception, Typeable, impureThrow)
import Data.ByteString.Char8 qualified as BC
import Data.Char qualified as Char
import Data.Default.Class qualified as HTTP
import Data.X509 qualified as HTTP
import Data.X509.CertificateStore qualified as HTTP
import Data.X509.Validation qualified as HTTP
import GHC.Exception (Exception (displayException))
import Hasura.Prelude
import Hasura.RQL.Types.Network (TlsAllow (TlsAllow), TlsPermission (SelfSigned))
import Network.Connection qualified as HTTP
import Network.TLS qualified as HTTP
import Network.TLS.Extra qualified as TLS
import System.X509 qualified as HTTP

-- | Exception carrying a free-form message; thrown from pure code by
-- 'errorE' in this module.
newtype TlsServiceDefinitionError = TlsServiceDefinitionError
  { -- | The message describing what went wrong.
    tlsServiceDefinitionError :: String
  }
  deriving (Typeable, Show)

-- | 'displayException' prefixes the type name to the wrapped message. The
-- message is rendered with 'show', so it appears quoted and escaped.
instance Exception TlsServiceDefinitionError where
  displayException err =
    let rendered = show (tlsServiceDefinitionError err)
     in "TlsServiceDefinitionError: " <> rendered

-- | Wrap a message in a 'TlsServiceDefinitionError' and throw it from pure
-- code via 'impureThrow'. Forcing the result raises the exception, which is
-- why the result type can be anything.
errorE :: String -> c
errorE message = impureThrow (TlsServiceDefinitionError message)

-- | Build TLS settings backed by the system certificate store, whose
-- server-certificate hook can relax validation through an allow-list.
--
-- On every handshake the default X.509 validation runs first. If it reports
-- failures, @currentAllow@ is executed, and the failures are discarded when
-- some allow-list entry matches the server and permits every reported
-- failure. Because @currentAllow@ runs inside the hook, allow-list changes
-- take effect without rebuilding these settings.
dynamicTlsSettings :: IO [TlsAllow] -> IO HTTP.TLSSettings
dynamicTlsSettings currentAllow =
  tlsSettingsComplex <$> HTTP.getSystemCertificateStore
  where
    tlsSettingsComplex :: HTTP.CertificateStore -> HTTP.TLSSettings
    tlsSettingsComplex = HTTP.TLSSettings . clientParams

    clientParams :: HTTP.CertificateStore -> HTTP.ClientParams
    clientParams systemStore =
      (HTTP.defaultParamsClient hostName serviceIdBlob)
        { -- supportedCiphers defaults to [] and must be set explicitly;
          -- ciphersuite_default is often a good choice.
          -- https://hackage.haskell.org/package/tls-1.5.5/docs/Network-TLS.html#t:Cipher
          HTTP.clientSupported = HTTP.def {HTTP.supportedCiphers = TLS.ciphersuite_default},
          HTTP.clientShared = HTTP.def {HTTP.sharedCAStore = systemStore},
          HTTP.clientHooks = HTTP.def {HTTP.onServerCertificate = certValidation}
        }

    -- Run the default validation. When it reports failures, suppress them if
    -- some allow-list entry covers this server and every failure reason.
    -- A certificate that already validated never consults the allow-list.
    certValidation :: HTTP.CertificateStore -> HTTP.ValidationCache -> HTTP.ServiceID -> HTTP.CertificateChain -> IO [HTTP.FailedReason]
    certValidation certStore validationCache sid chain = do
      failures <- HTTP.onServerCertificate HTTP.def certStore validationCache sid chain
      if null failures
        then pure []
        else do
          allowList <- currentAllow
          pure $
            if any (allowed sid failures) allowList
              then []
              else failures

    -- Placeholders for the server identification given to
    -- 'HTTP.defaultParamsClient'. These always seem to be overwritten when a
    -- connection is established; forcing either one throws, so that any
    -- violation of this assumption surfaces as an error.
    -- TODO: Is there any way to define this in terms of a pure exception?
    hostName :: String
    hostName = errorE "hostname in HTTP client defaultParamsClient accessed - this should never happen"

    serviceIdBlob :: BC.ByteString
    serviceIdBlob = errorE "serviceIdBlob in HTTP client defaultParamsClient accessed - this should never happen"

    -- An allow-list entry covers a connection when all of these hold:
    --   * its host equals the server's host (case-insensitively, as DNS names are);
    --   * its port, if one is given, equals the server's port; and
    --   * every validation failure is tolerated by at least one of its
    --     permissions (defaulting to [SelfSigned] when none are listed).
    allowed :: (String, BC.ByteString) -> [HTTP.FailedReason] -> TlsAllow -> Bool
    allowed (sHost, sPort) failures (TlsAllow aHost aPort aPermit) =
      sameHost sHost aHost
        && maybe True (== BC.unpack sPort) aPort
        && all (\failure -> any (`permitted` failure) permissions) failures
      where
        permissions = fromMaybe [SelfSigned] aPermit

    -- Host names are compared case-insensitively (RFC 4343).
    sameHost :: String -> String -> Bool
    sameHost a b = map Char.toLower a == map Char.toLower b

    -- Translates a high-level permission into the specific X.509 validation
    -- failures it tolerates. Failure descriptions are taken from
    -- https://hackage.haskell.org/package/x509-validation-1.4.7/docs/src/Data-X509-Validation.html
    permitted :: TlsPermission -> HTTP.FailedReason -> Bool
    permitted SelfSigned HTTP.SelfSigned = True -- Certificate is self signed
    permitted SelfSigned (HTTP.NameMismatch _) = True -- Connection name and certificate do not match
    permitted SelfSigned HTTP.LeafNotV3 = True -- Only authorized an X509.V3 certificate as leaf certificate.
    permitted SelfSigned _ = False