-- | MSSQL Connection Pooling
module Database.MSSQL.Pool
  ( -- * Types
    ConnectionString (..),
    ConnectionOptions (..),
    MSSQLPool (..),

    -- * Functions
    initMSSQLPool,
    drainMSSQLPool,
    withMSSQLPool,
    resizePool,
    getInUseConnections,
  )
where

import Autodocodec (HasCodec (codec), dimapCodec)
import Control.Exception.Lifted
import Control.Monad.Trans.Control
import Data.Aeson
import Data.Pool qualified as Pool
import Database.ODBC.SQLServer qualified as ODBC
import Hasura.Prelude (Generic, Text)
import Prelude

-- | ODBC connection string for MSSQL server
--
-- A thin wrapper around 'Text'; use 'unConnectionString' to get the raw
-- string back.
newtype ConnectionString = ConnectionString {unConnectionString :: Text}
  deriving (Show, Eq, ToJSON, FromJSON, Generic)

-- | The codec is the underlying 'Text' codec, mapped through the newtype
-- constructor on decode and through 'unConnectionString' on encode.
instance HasCodec ConnectionString where
  codec = dimapCodec ConnectionString unConnectionString codec

-- | How connections to the server should be managed: through a pool with
-- the given settings, or without any pooling.
data ConnectionOptions
  = ConnectionOptions
      { -- | Connection count handed to 'Pool.createPool' as its final
        -- argument.
        -- NOTE(review): whether this is per stripe or in total depends on the
        -- pool library in use - confirm.
        _coConnections :: Int,
        -- | Stripe count handed to 'Pool.createPool'.
        _coStripes :: Int,
        -- | Idle time handed to 'Pool.createPool' after conversion with
        -- 'fromIntegral'.
        _coIdleTime :: Int
      }
  | -- | Do not pool connections.
    ConnectionOptionsNoPool
  deriving (Show, Eq)

-- | ODBC connection pool
--
-- Either an actual pool of connections, or - when pooling is disabled - just
-- the action that opens a connection.
data MSSQLPool
  = MSSQLPool (Pool.Pool ODBC.Connection)
  -- ^ Connections are managed by a 'Pool.Pool'.
  | MSSQLNoPool (IO ODBC.Connection)
  -- ^ No pooling: carries the 'IO' action used to open a connection.

-- | Initialize an MSSQL pool with given connection configuration
--
-- * With 'ConnectionOptionsNoPool' nothing is connected here: the result only
--   stores the 'ODBC.connect' action for the given connection string.
-- * With 'ConnectionOptions' a pool is created via 'Pool.createPool', using
--   'ODBC.connect' to open and 'ODBC.close' to dispose of connections.
initMSSQLPool ::
  ConnectionString ->
  ConnectionOptions ->
  IO MSSQLPool
initMSSQLPool (ConnectionString connString) options =
  case options of
    ConnectionOptionsNoPool ->
      pure $ MSSQLNoPool openConnection
    ConnectionOptions
      { _coConnections = maxConnections,
        _coStripes = stripes,
        _coIdleTime = idleTime
      } ->
        MSSQLPool
          <$> Pool.createPool
            openConnection
            ODBC.close
            stripes
            -- The pool expects a time-difference value, the options carry an Int.
            (fromIntegral idleTime)
            maxConnections
  where
    -- Action that opens a fresh connection for this connection string.
    openConnection :: IO ODBC.Connection
    openConnection = ODBC.connect connString

-- | Destroy all pool resources
--
-- Delegates to 'Pool.destroyAllResources' for a real pool; there is nothing
-- to release when pooling is disabled, so that case is a no-op.
drainMSSQLPool :: MSSQLPool -> IO ()
drainMSSQLPool (MSSQLPool pool) = Pool.destroyAllResources pool
drainMSSQLPool MSSQLNoPool {} = pure ()

-- | Run an action with a connection.
--
-- * For a real pool the connection is borrowed with 'Pool.withResource'.
-- * Without a pool a connection is opened with the stored action and closed
--   with 'ODBC.close' once the action finishes; 'bracket' ensures the close
--   also runs when the action throws.
--
-- An 'ODBC.ODBCException' raised anywhere inside (including while acquiring
-- the connection) is caught and returned as 'Left'; other exceptions are not
-- caught here.
withMSSQLPool ::
  (MonadBaseControl IO m) =>
  MSSQLPool ->
  (ODBC.Connection -> m a) ->
  m (Either ODBC.ODBCException a)
withMSSQLPool (MSSQLPool pool) action =
  try $ Pool.withResource pool action
withMSSQLPool (MSSQLNoPool connect) action =
  try $
    bracket
      -- 'liftBaseWith' with 'const' lifts the plain IO actions into @m@.
      (liftBaseWith (const connect))
      (\conn -> liftBaseWith (const (ODBC.close conn)))
      action

-- | Resize a pool
--
-- Has no effect when pooling is disabled.
resizePool :: MSSQLPool -> Int -> IO ()
resizePool (MSSQLPool pool) resizeTo = do
  -- Resize the pool max resources
  Pool.resizePool pool resizeTo
  -- Trim pool by destroying excess resources, if any
  Pool.tryTrimPool pool
resizePool MSSQLNoPool {} _ = pure ()

-- | Number of connections currently in use, as reported by
-- 'Pool.getInUseResourceCount'. Always 0 when pooling is disabled.
getInUseConnections :: MSSQLPool -> IO Int
getInUseConnections (MSSQLPool pool) = Pool.getInUseResourceCount pool
getInUseConnections MSSQLNoPool {} = pure 0