-- | Restricted `ManagerSettings` for <https://haskell-lang.org/library/http-client>
-- -
-- - Portions from http-client-tls Copyright (c) 2013 Michael Snoyman
-- - Portions from http-client-restricted Copyright 2018 Joey Hess <id@joeyh.name>
-- -
-- - License: MIT
module Network.HTTP.Client.Restricted
  ( Decision (..),
    Restriction,
    mkRestrictedManagerSettings,
    ConnectionRestricted (..),
  )
where

import Control.Exception
import Data.Default
import Data.Maybe
import Data.Typeable
import Hasura.Prelude (onNothing)
import Network.BSD (getProtocolNumber)
import Network.Connection qualified as NC
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Client.Internal qualified as HTTP
import Network.HTTP.Client.TLS qualified as HTTP
import Network.Socket
import Prelude

-- | Verdict that a 'Restriction' gives for a single resolved address.
data Decision
  = -- | A socket may be opened to the address.
    Allow
  | -- | The address is refused; the connection attempt fails with
    -- 'ConnectionRestricted'.
    Deny

-- | A policy consulted with a resolved address before a socket is opened
-- to it; the returned 'Decision' says whether that address may be used.
type Restriction = AddrInfo -> Decision

-- | Blocked requests raise this exception. By the time it leaves the
-- manager it is wrapped as an @InternalException@ inside an
-- @HttpExceptionRequest@.
data ConnectionRestricted = ConnectionRestricted
  { -- | Host name the connection was requested for, after
    -- @strippedHostName@ has been applied to it.
    crHostName :: String,
    -- | The resolved address that the 'Restriction' denied.
    crAddress :: AddrInfo
  }
  deriving (Show, Typeable)

-- | Relies on the default 'Exception' methods, so values can be thrown
-- with 'throwIO' and recovered with 'fromException'.
instance Exception ConnectionRestricted

-- | Adjusts a ManagerSettings to enforce a Restriction. The restriction
-- will be checked each time a Request is made, and for each redirect
-- followed.
--
-- This overrides the `managerRawConnection`
-- and `managerTlsConnection` with its own implementations that check
-- the Restriction. They should otherwise behave the same as the
-- ones provided by http-client-tls. `managerWrapException` is overridden
-- as well, so that a 'ConnectionRestricted' is reported as an
-- @HttpException@; the wrapper found in the base settings is still
-- applied around it.
--
-- This function is not exported, because using it with a ManagerSettings
-- produced by something other than http-client-tls would result in
-- surprising behavior, since its connection methods would not be used.
restrictManagerSettings ::
  Maybe NC.ConnectionContext ->
  Maybe NC.TLSSettings ->
  Restriction ->
  HTTP.ManagerSettings ->
  HTTP.ManagerSettings
restrictManagerSettings mcontext mtls cfg base =
  base
    { HTTP.managerRawConnection = restrictedRawConnection cfg,
      HTTP.managerTlsConnection = restrictedTlsConnection mcontext mtls cfg,
      HTTP.managerWrapException = wrapOurExceptions base
    }

-- | Makes a TLS-capable ManagerSettings with a Restriction applied to it.
--
-- The Restriction will be checked each time a Request is made, and for
-- each redirect followed.
--
-- Aside from checking the Restriction, it should behave the same as
-- `Network.HTTP.Client.TLS.mkManagerSettingsContext`
-- from http-client-tls.
--
-- > main = do
-- >   manager <- newManager $ mkRestrictedManagerSettings myRestriction Nothing Nothing
-- >   request <- parseRequest "http://httpbin.org/get"
-- >   response <- httpLbs request manager
-- >   print $ responseBody response
--
-- See `mkManagerSettingsContext` for why
-- it can be useful to provide a `NC.ConnectionContext`.
--
-- When no `NC.TLSSettings` are given, the `Default` instance's 'def'
-- is used.
--
-- Note that SOCKS is not supported.
mkRestrictedManagerSettings ::
  Restriction ->
  Maybe NC.ConnectionContext ->
  Maybe NC.TLSSettings ->
  HTTP.ManagerSettings
mkRestrictedManagerSettings cfg mcontext mtls =
  restrictManagerSettings mcontext mtls cfg $
    -- The final 'Nothing' is the SOCKS settings argument.
    HTTP.mkManagerSettingsContext mcontext (fromMaybe def mtls) Nothing

-- | Runs an action so that a 'ConnectionRestricted' escaping from it is
-- rethrown as an @HttpExceptionRequest@ for the given request, carrying
-- the original exception in an @InternalException@. The
-- @managerWrapException@ of the base settings is applied around the
-- result, so its own wrapping still takes place.
--
-- Only 'ConnectionRestricted' is caught. Every other exception,
-- asynchronous ones included, passes straight through rather than being
-- caught and rethrown.
wrapOurExceptions :: HTTP.ManagerSettings -> HTTP.Request -> IO a -> IO a
wrapOurExceptions base req action =
  HTTP.managerWrapException base req (handle rewrap action)
  where
    rewrap :: ConnectionRestricted -> IO b
    rewrap restricted =
      throwIO $
        HTTP.HttpExceptionRequest req $
          HTTP.InternalException (toException restricted)

-- | Connection opener for non-TLS requests that checks the Restriction.
--
-- No TLS settings and no connection context are passed on, so
-- 'getConnection' sets up a connection context of its own.
restrictedRawConnection :: Restriction -> IO (Maybe HostAddress -> String -> Int -> IO HTTP.Connection)
restrictedRawConnection cfg = getConnection cfg Nothing Nothing

-- | Connection opener for TLS requests that checks the Restriction.
--
-- TLS is always requested: when no settings are supplied, the `Default`
-- instance's 'def' is used. The optional connection context is handed to
-- 'getConnection' unchanged.
restrictedTlsConnection :: Maybe NC.ConnectionContext -> Maybe NC.TLSSettings -> Restriction -> IO (Maybe HostAddress -> String -> Int -> IO HTTP.Connection)
restrictedTlsConnection mcontext mtls cfg =
  getConnection cfg (Just (fromMaybe def mtls)) mcontext

-- Based on Network.HTTP.Client.TLS.getTlsConnection.
--
-- Checks the Restriction
--
-- Does not support SOCKS.
--
-- The returned function resolves the host name, walks the resolved
-- addresses in order and consults the Restriction before opening a socket
-- to each one. The first address whose connection succeeds is used.
--
-- Failure modes visible here:
--
-- * A denied address throws 'ConnectionRestricted'. That is not an
--   'IOException', so it is not caught by the retry logic and the
--   remaining addresses are not tried.
-- * An 'IOException' while connecting moves on to the next address; the
--   one raised for the last address is propagated.
-- * An empty address list throws @HostNotResolved@.
--
-- The @Maybe HostAddress@ argument of the returned function is ignored.
getConnection ::
  Restriction ->
  Maybe NC.TLSSettings ->
  Maybe NC.ConnectionContext ->
  IO (Maybe HostAddress -> String -> Int -> IO HTTP.Connection)
getConnection restriction tls mcontext = do
  -- Use the caller's context when there is one; otherwise initialise one
  -- here, once, for every connection the returned function opens.
  context <- onNothing mcontext NC.initConnectionContext
  return $ \_hostAddress hostName port ->
    -- If converting to an http-client connection fails, close the
    -- underlying connection instead of leaking it.
    bracketOnError
      (go context hostName port)
      NC.connectionClose
      convertConnection
  where
    go :: NC.ConnectionContext -> String -> Int -> IO NC.Connection
    go context hostName port = do
      let connparams =
            NC.ConnectionParams
              { NC.connectionHostname = host,
                NC.connectionPort = fromIntegral port,
                NC.connectionUseSecure = tls,
                NC.connectionUseSocks = Nothing -- unsupported
              }
      proto <- getProtocolNumber "tcp"
      let serv = show port
      let hints =
            defaultHints
              { addrFlags = [AI_ADDRCONFIG],
                addrProtocol = proto,
                addrSocketType = Stream
              }
      addrs <- getAddrInfo (Just hints) (Just host) (Just serv)
      -- The socket is closed if wrapping it in a connection throws.
      bracketOnError
        (firstSuccessful $ map tryToConnect addrs)
        close
        (\sock -> NC.connectFromSocket context sock connparams)
      where
        host :: String
        host = HTTP.strippedHostName hostName

        -- Open and connect a socket to one address, provided the
        -- Restriction allows it.
        tryToConnect :: AddrInfo -> IO Socket
        tryToConnect addr = case restriction addr of
          Allow ->
            -- Close the fresh socket again if 'connect' fails.
            bracketOnError
              (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
              close
              (\sock -> connect sock (addrAddress addr) >> return sock)
          Deny -> throwIO $ ConnectionRestricted host addr

        -- Run the candidates in order and return the first socket obtained.
        -- Only 'IOException's lead to the next candidate being tried; the
        -- last candidate runs without a handler, so its exception is what
        -- the caller sees.
        firstSuccessful :: [IO Socket] -> IO Socket
        firstSuccessful [] = throwIO $ NC.HostNotResolved host
        firstSuccessful [a] = a
        firstSuccessful (a : as) =
          a `catch` \(_ :: IOException) -> firstSuccessful as

-- Copied from Network.HTTP.Client.TLS, unfortunately not exported.
--
-- Wraps a connection from the connection library as an http-client
-- connection, built from its chunk reader, its writer and a close action.
convertConnection :: NC.Connection -> IO HTTP.Connection
convertConnection conn =
  HTTP.makeConnection
    (NC.connectionGetChunk conn)
    (NC.connectionPut conn)
    -- Closing an SSL connection gracefully involves writing/reading
    -- on the socket.  But when this is called the socket might be
    -- already closed, and we get a @ResourceVanished@.
    -- Any 'IOException' raised while closing is therefore discarded.
    (NC.connectionClose conn `Control.Exception.catch` \(_ :: IOException) -> return ())