-- Potential improvement (TODO): change 'WS.RequestHead' to have IP address or even the
-- 'Wai.Request' object in it

-- | Taken from the wai-websockets package and customized to pass the client's IP address to the WebSocket handler
-- http://hackage.haskell.org/package/wai-websockets-3.0.1.2/docs/Network-Wai-Handler-WebSockets.html
module Network.Wai.Handler.WebSockets.Custom
  ( websocketsOr,
    websocketsApp,
    isWebSocketsReq,
    getRequestHead,
    runWebSockets,
  )
where

import Control.Exception (bracket, tryJust)
import Data.ByteString (ByteString)
import Data.ByteString.Char8 qualified as BC
import Data.ByteString.Lazy qualified as BL
import Data.CaseInsensitive qualified as CI
import Network.HTTP.Types (status500)
import Network.Wai.Extended qualified as Wai
import Network.WebSockets qualified as WS
import Network.WebSockets.Connection qualified as WS
import Network.WebSockets.Stream qualified as WS
import Prelude

--------------------------------------------------------------------------------

-- | Decide whether a 'Wai.Request' asks to be upgraded to the WebSocket
-- protocol.
--
-- The @Upgrade@ header is looked up by its case-insensitive name. Its value
-- must equal @websocket@, also compared case-insensitively. A request without
-- an @Upgrade@ header is not a WebSocket request.
isWebSocketsReq :: Wai.Request -> Bool
isWebSocketsReq req =
  case lookup "upgrade" (Wai.requestHeaders req) of
    Just upgradeValue -> CI.mk upgradeValue == "websocket"
    Nothing -> False

--------------------------------------------------------------------------------

-- | Upgrade an IP-aware @websockets@ handler to a @wai@ 'Wai.Application'.
--
-- The handler is like a 'WS.ServerApp', but it also receives the client's
-- 'Wai.IpAddress' (as computed by 'getRequestHead'). The given backup
-- 'Wai.Application' handles 'Wai.Request's that are not WebSocket requests.
--
-- @
-- websocketsOr opts ws_app backup_app = \\req respond ->
--     __case__ 'websocketsApp' opts ws_app req __of__
--         'Nothing'  -> backup_app req respond
--         'Just' res -> respond res
-- @
--
-- For example, below is an 'Wai.Application' that sends @"Hello, client!"@ to
-- each connected client.
--
-- @
-- app :: 'Wai.Application'
-- app = 'websocketsOr' 'WS.defaultConnectionOptions' wsApp backupApp
--   __where__
--     wsApp :: 'Wai.IpAddress' -> 'WS.PendingConnection' -> IO ()
--     wsApp _clientIp pending_conn = do
--         conn <- 'WS.acceptRequest' pending_conn
--         'WS.sendTextData' conn ("Hello, client!" :: 'Data.Text.Text')
--
--     backupApp :: 'Wai.Application'
--     backupApp _ respond = respond $ 'Wai.responseLBS' 'Network.HTTP.Types.status400' [] "Not a WebSocket request"
-- @
websocketsOr ::
  WS.ConnectionOptions ->
  (Wai.IpAddress -> WS.PendingConnection -> IO ()) ->
  Wai.Application ->
  Wai.Application
websocketsOr opts app backup req sendResponse =
  case websocketsApp opts app req of
    -- Not a WebSocket upgrade: let the ordinary HTTP application answer.
    Nothing -> backup req sendResponse
    Just res -> sendResponse res

--------------------------------------------------------------------------------

-- | Turn a single @wai@ 'Wai.Request' into a 'Wai.Response' that runs the
-- given WebSocket handler.
--
-- The handler receives the client's 'Wai.IpAddress' (from 'getRequestHead')
-- together with the 'WS.PendingConnection'.
--
-- Returns 'Nothing' when 'isWebSocketsReq' rejects the request. Otherwise it
-- returns 'Just' a raw response. The fallback part of that raw response is a
-- plain-text HTTP 500 served by WAI handlers that cannot do raw responses.
--
-- 'websocketsOr' is usually the more convenient entry point.
websocketsApp ::
  WS.ConnectionOptions ->
  (Wai.IpAddress -> WS.PendingConnection -> IO ()) ->
  Wai.Request ->
  Maybe Wai.Response
websocketsApp opts app req
  | not (isWebSocketsReq req) = Nothing
  | otherwise = Just (Wai.responseRaw runRaw fallback)
  where
    (wsHead, clientIp) = getRequestHead req

    -- Drive the WebSocket session over the raw read/write actions supplied
    -- by the WAI handler.
    runRaw :: IO ByteString -> (ByteString -> IO ()) -> IO ()
    runRaw = runWebSockets opts wsHead clientIp app

    -- Served instead when the WAI handler does not support raw responses.
    fallback :: Wai.Response
    fallback =
      Wai.responseLBS
        status500
        [("Content-Type", "text/plain")]
        "The web application attempted to send a WebSockets response, but WebSockets are not supported by your WAI handler."

--------------------------------------------------------------------------------
-- | Build the @websockets@ 'WS.RequestHead' for a @wai@ 'Wai.Request'. The
-- result is paired with the client's 'Wai.IpAddress', as reported by
-- 'Wai.getSourceFromFallback'.
--
-- The request path is the raw path followed by the raw query string. The
-- request headers and the 'Wai.isSecure' flag are passed through unchanged.
getRequestHead :: Wai.Request -> (WS.RequestHead, Wai.IpAddress)
getRequestHead req = (wsHead, Wai.getSourceFromFallback req)
  where
    pathWithQuery :: ByteString
    pathWithQuery = Wai.rawPathInfo req <> Wai.rawQueryString req

    wsHead :: WS.RequestHead
    wsHead =
      WS.RequestHead
        pathWithQuery
        (Wai.requestHeaders req)
        (Wai.isSecure req)

--------------------------------------------------------------------------------

-- | Run a WebSocket session over a raw connection.
--
-- The two connection actions have the shape of the callback that
-- 'Wai.responseRaw' expects:
--
-- * @src@ reads the next chunk from the client. An empty 'ByteString' is
--   treated as end of input.
-- * @sink@ writes bytes back to the client.
--
-- They are wrapped into a 'WS.Stream'. A 'WS.PendingConnection' over that
-- stream is built from @opts@ and @req@ and handed to @app@, together with
-- the client's 'Wai.IpAddress'.
--
-- The stream is closed when @app@ returns or throws ('bracket'). A
-- 'WS.ConnectionClosed' raised while closing is swallowed, because the
-- connection is already gone. Any other 'WS.ConnectionException' from closing
-- propagates.
runWebSockets ::
  WS.ConnectionOptions ->
  WS.RequestHead ->
  Wai.IpAddress ->
  (Wai.IpAddress -> WS.PendingConnection -> IO a) ->
  IO ByteString ->
  (ByteString -> IO ()) ->
  IO a
runWebSockets opts req ipAddress app src sink =
  bracket mkStream ensureClose (app ipAddress . pc)
  where
    -- Close the stream, tolerating only the "already closed" exception.
    ensureClose :: WS.Stream -> IO (Either () ())
    ensureClose = tryJust onConnectionException . WS.close

    onConnectionException :: WS.ConnectionException -> Maybe ()
    onConnectionException WS.ConnectionClosed = Just ()
    onConnectionException _ = Nothing

    -- Adapt the raw src/sink pair to a @websockets@ 'WS.Stream'.
    mkStream :: IO WS.Stream
    mkStream =
      WS.makeStream
        ( do
            bs <- src
            -- An empty chunk is mapped to 'Nothing', i.e. end of input.
            pure $ if BC.null bs then Nothing else Just bs
        )
        ( \case
            -- 'sink' has no close operation, so a 'Nothing' write is a no-op.
            Nothing -> pure ()
            Just bl -> mapM_ sink (BL.toChunks bl)
        )

    pc :: WS.Stream -> WS.PendingConnection
    pc stream =
      WS.PendingConnection
        { WS.pendingOptions = opts,
          WS.pendingRequest = req,
          WS.pendingOnAccept = \_ -> pure (),
          WS.pendingStream = stream
        }