module Network.HTTP.Client.Blocklisting
  ( block,
    Blocklist (..),
  )
where

import Hasura.Prelude
import Net.IPv4 qualified as IPv4
import Net.IPv6 qualified as IPv6
import Network.HTTP.Client.Restricted qualified as Restricted
import Network.Socket

-- | IP ranges whose addresses should be refused when opening outgoing
-- connections (see 'block').
data Blocklist = Blocklist
  { -- | IPv4 ranges; an IPv4 address inside any of them is denied.
    ipv4Blocklist :: [IPv4.IPv4Range],
    -- | IPv6 ranges; an IPv6 address inside any of them is denied.
    ipv6Blocklist :: [IPv6.IPv6Range]
  }
  deriving (Show, Generic)

-- | Combining two blocklists concatenates their IPv4 and IPv6 range lists
-- respectively (left operand's ranges first), so the result blocks every
-- address that either operand blocks.
instance Semigroup Blocklist where
  Blocklist ipv4Lst1 ipv6Lst1 <> Blocklist ipv4Lst2 ipv6Lst2 =
    Blocklist (ipv4Lst1 <> ipv4Lst2) (ipv6Lst1 <> ipv6Lst2)

-- | The empty blocklist contains no ranges and therefore blocks nothing.
instance Monoid Blocklist where
  mempty = Blocklist [] []

-- | Determine whether the given address is blocked by the given blocklist.
--
-- IPv4 socket addresses are checked against 'ipv4Blocklist' and IPv6 socket
-- addresses against 'ipv6Blocklist'. An IPv4-mapped IPv6 address
-- (@::ffff:a.b.c.d@) is additionally checked against 'ipv4Blocklist': on
-- dual-stack hosts a connection to such an address is carried over IPv4 to the
-- embedded address, so without this check it would bypass the IPv4 blocklist.
--
-- NOTE: Only restricts IPv4 and IPv6 addresses. Other address families are
-- not restricted.
block :: Blocklist -> AddrInfo -> Restricted.Decision
block blocklist addr =
  if sockAddrInBlocklist (addrAddress addr)
    then Restricted.Deny
    else Restricted.Allow
  where
    sockAddrInBlocklist :: SockAddr -> Bool
    sockAddrInBlocklist = \case
      SockAddrInet _ hostAddr -> ipv4Blocked (ipv4Addr hostAddr)
      SockAddrInet6 _ _ hostAddr6 _ ->
        any (IPv6.member (ipv6Addr hostAddr6)) (ipv6Blocklist blocklist)
          -- Close the IPv4-mapped bypass: also test the embedded IPv4 address.
          || maybe False ipv4Blocked (ipv4Mapped hostAddr6)
      -- Unix sockets and any other families are deliberately not restricted.
      _ -> False

    -- True when the IPv4 address lies in any range of the IPv4 blocklist.
    ipv4Blocked :: IPv4.IPv4 -> Bool
    ipv4Blocked v4 = any (IPv4.member v4) (ipv4Blocklist blocklist)

    ipv4Addr :: HostAddress -> IPv4.IPv4
    ipv4Addr = IPv4.fromTupleOctets . hostAddressToTuple

    ipv6Addr :: HostAddress6 -> IPv6.IPv6
    ipv6Addr = IPv6.fromTupleWord32s

    -- The IPv4 address embedded in an IPv4-mapped IPv6 address
    -- (::ffff:0:0/96, RFC 4291 section 2.5.5.2), or Nothing for any other
    -- IPv6 address.
    ipv4Mapped :: HostAddress6 -> Maybe IPv4.IPv4
    ipv4Mapped hostAddr6 = case hostAddress6ToTuple hostAddr6 of
      (0, 0, 0, 0, 0, 0xffff, hi, lo) ->
        let (o1, o2) = hi `quotRem` 256
            (o3, o4) = lo `quotRem` 256
         in Just $
              IPv4.fromOctets
                (fromIntegral o1)
                (fromIntegral o2)
                (fromIntegral o3)
                (fromIntegral o4)
      _ -> Nothing