module Network.Wai.Extended
  ( module Wai,
    getSourceFromFallback,
    IpAddress (..),
    showIPAddress,
  )
where

import Data.Bits (shift, (.&.))
import Data.ByteString.Char8 (ByteString)
import Data.ByteString.Char8 qualified as BS
import Data.List (find, intercalate)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Text.Encoding qualified as TE
import Data.Text.Encoding.Error qualified as TE
import Data.Word (Word32)
import Network.Socket (SockAddr (..))
import Network.Wai as Wai
import System.ByteOrder (ByteOrder (..), byteOrder)
import Text.Printf (printf)
import Prelude

-- | A client IP address kept as raw bytes: either the textual form rendered
--   from the peer socket, or the value taken from a proxy request header.
newtype IpAddress = IpAddress {unIpAddress :: ByteString}
  deriving (Eq, Show)

-- | Render an 'IpAddress' as 'Text'.
--
--   The stored bytes are decoded as UTF-8 using 'TE.lenientDecode', so a
--   malformed byte sequence (e.g. from a forged header) is replaced with
--   U+FFFD rather than raising an exception.
showIPAddress :: IpAddress -> Text
showIPAddress (IpAddress rawBytes) = TE.decodeUtf8With TE.lenientDecode rawBytes

-- | Client address of the directly connected peer ('Wai.remoteHost'),
--   rendered as text by 'showSockAddr' and stored as bytes.
getSourceFromSocket :: Wai.Request -> IpAddress
getSourceFromSocket request =
  let peer = Wai.remoteHost request
   in IpAddress (BS.pack (showSockAddr peer))

-- | Best-effort client address: the proxy-header address found by
--   'getSource' when there is one, otherwise the socket peer address from
--   'getSourceFromSocket'.
getSourceFromFallback :: Wai.Request -> IpAddress
getSourceFromFallback request =
  case getSource request of
    Just headerAddress -> headerAddress
    Nothing -> getSourceFromSocket request

-- | Client address as reported by a reverse proxy, if any.
--
--   Scans the request headers in order for @X-Real-IP@ or @X-Forwarded-For@.
--   It returns the address from the first such header that yields a
--   non-empty value.
--
--   @X-Forwarded-For@ may carry a comma-separated hop list
--   (@client, proxy1, proxy2@). Only the leftmost entry (the claimed
--   originating client) is kept, with surrounding whitespace stripped.
--   Headers whose value is empty after this are skipped, so callers such as
--   'getSourceFromFallback' can fall back to the socket address.
--
--   NOTE(review): both headers are client-controlled unless a trusted
--   reverse proxy overwrites them; only rely on this behind such a proxy.
getSource :: Wai.Request -> Maybe IpAddress
getSource req = IpAddress <$> find (not . BS.null) candidates
  where
    -- Normalised values of the proxy headers, in request order.
    candidates :: [ByteString]
    candidates =
      [ firstHop value
      | (name, value) <- Wai.requestHeaders req
      , name `elem` ["x-real-ip", "x-forwarded-for"]
      ]

    -- Leftmost element of a comma-separated list, whitespace-trimmed.
    -- A single address (the X-Real-IP case) is returned trimmed.
    firstHop :: ByteString -> ByteString
    firstHop = BS.strip . BS.takeWhile (/= ',')

-- | An IP address written in numeric text form, such as @192.0.2.1@ or
--   @2001:db8::1@.
type NumericAddress = [Char]

-- | Render a 32-bit IPv4 address in dotted-decimal form.
--
--   The flag selects which byte of the 'Word32' is printed first:
--
--   * 'True'  – least significant byte first. 'showSockAddr' uses this for
--     a network-byte-order 'HostAddress' on a little-endian host.
--   * 'False' – most significant byte first.
showIPv4 :: Word32 -> Bool -> NumericAddress
showIPv4 addr lowByteFirst =
  joinDots (if lowByteFirst then lowToHigh else reverse lowToHigh)
  where
    -- The four octets, ordered from least to most significant byte.
    lowToHigh :: [Word32]
    lowToHigh = [shift addr (negate (8 * i)) .&. 0x000000ff | i <- [0 .. 3 :: Int]]

    -- Decimal octets separated by '.', with no leading or trailing dot.
    joinDots :: [Word32] -> String
    joinDots octets = concat (zipWith (++) ("" : repeat ".") (map show octets))

-- | Render an IPv6 address in the canonical text form of RFC 5952:
--
--   * each 16-bit group in lowercase hex without leading zeros;
--   * the longest run of two or more all-zero groups replaced by @::@,
--     choosing the leftmost run on a tie; a lone zero group stays @0@.
--
--   The tuple is read as four 32-bit words, most significant first. Each
--   word holds two groups, high half first.
showIPv6 :: (Word32, Word32, Word32, Word32) -> String
showIPv6 (w1, w2, w3, w4) =
  case longestZeroRun of
    Nothing -> joinGroups groups
    Just (start, len) ->
      let (before, rest) = splitAt start groups
       in joinGroups before ++ "::" ++ joinGroups (drop len rest)
  where
    -- The eight 16-bit groups, most significant first.
    groups :: [Word32]
    groups = concatMap split16 [w1, w2, w3, w4]

    split16 :: Word32 -> [Word32]
    split16 w = [shift w (-16) .&. 0x0000ffff, w .&. 0x0000ffff]

    joinGroups :: [Word32] -> String
    joinGroups = intercalate ":" . map (printf "%x")

    -- (start index, length) of every maximal run of zero groups.
    zeroRuns :: Int -> [Word32] -> [(Int, Int)]
    zeroRuns _ [] = []
    zeroRuns i xs@(x : rest)
      | x == 0 =
          let n = length (takeWhile (== 0) xs)
           in (i, n) : zeroRuns (i + n) (drop n xs)
      | otherwise = zeroRuns (i + 1) rest

    -- Run to compress, if any. 'foldr' combines later runs first, so
    -- preferring the current run unless a later one is strictly longer
    -- keeps the leftmost run on a tie (RFC 5952 section 4.2.3).
    longestZeroRun :: Maybe (Int, Int)
    longestZeroRun = foldr keepLonger Nothing (zeroRuns 0 groups)

    keepLonger :: (Int, Int) -> Maybe (Int, Int) -> Maybe (Int, Int)
    keepLonger run@(_, n) best
      | n < 2 = best -- a single zero group is never compressed (4.2.2)
      | Just (_, m) <- best, m > n = best
      | otherwise = Just run

-- | Render a 'SockAddr' as a 'NumericAddress'.
--
--   An IPv4-mapped IPv6 address (@::ffff:a.b.c.d@) is rendered as its
--   embedded IPv4 address, and the IPv6 loopback as @::1@. Any other address
--   family yields the literal @"unknownSocket"@.
showSockAddr :: SockAddr -> NumericAddress
showSockAddr sockAddr =
  case sockAddr of
    -- HostAddress is network byte order, so the byte printed first
    -- depends on the host's endianness.
    SockAddrInet _ hostAddr ->
      showIPv4 hostAddr (byteOrder == LittleEndian)
    -- HostAddress6 is host byte order: the mapped IPv4 word is
    -- printed most significant byte first.
    SockAddrInet6 _ _ (0, 0, 0x0000ffff, mappedV4) _ ->
      showIPv4 mappedV4 False
    SockAddrInet6 _ _ (0, 0, 0, 1) _ ->
      "::1"
    SockAddrInet6 _ _ hostAddr6 _ ->
      showIPv6 hostAddr6
    _ ->
      "unknownSocket"