module Database.MSSQL.Transaction
  ( TxET (..),
    MSSQLTxError (..),
    TxIsolation (..),
    TxT,
    TxE,
    runTx,
    runTxE,
    unitQuery,
    unitQueryE,
    singleRowQuery,
    singleRowQueryE,
    multiRowQuery,
    multiRowQueryE,
    forJsonQueryE,
    buildGenericQueryTxE,
    withTxET,
  )
where

import Autodocodec (HasCodec (codec), bimapCodec, textCodec, (<?>))
import Autodocodec.Aeson qualified as AC
import Control.Exception (try)
import Control.Monad.Morph (MFunctor (hoist))
import Control.Monad.Trans.Control (MonadBaseControl)
import Data.Aeson qualified as J
import Data.Text qualified as T
import Database.MSSQL.Pool
import Database.ODBC.SQLServer (FromRow)
import Database.ODBC.SQLServer qualified as ODBC
import Hasura.Prelude

-- | A transaction action: a computation that can read the 'ODBC.Connection'
-- it is run on and can fail with an error of type @e@.
--
-- Type parameters:
--
-- * @e@ - the error type (usually 'MSSQLTxError')
-- * @m@ - the base monad (usually some 'MonadIO')
-- * @a@ - the successful result type
newtype TxET e m a = TxET
  { -- | Unwrap to the underlying reader-over-'ExceptT' stack.
    txHandler :: ReaderT ODBC.Connection (ExceptT e m) a
  }
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadError e,
      MonadIO,
      MonadReader ODBC.Connection,
      MonadFix
    )

-- | Change the base monad of a 'TxET' by applying the natural transformation
-- underneath both the 'ReaderT' and the 'ExceptT' layers.
instance MFunctor (TxET e) where
  hoist f = TxET . hoist (hoist f) . txHandler

-- | Lift a base-monad action through the 'ExceptT' and 'ReaderT' layers.
instance MonadTrans (TxET e) where
  lift = TxET . lift . lift

-- | Error type generally used in 'TxET'.
data MSSQLTxError
  = -- | A query failed: the offending query together with the ODBC exception.
    MSSQLQueryError !ODBC.Query !ODBC.ODBCException
  | -- | An ODBC exception that is not tied to a particular query; 'runTxE'
    -- uses this for exceptions reported by the connection pool.
    MSSQLConnError !ODBC.ODBCException
  | -- | An internal error described by a text message.
    MSSQLInternal !Text
  deriving (Eq, Show)

-- | A 'TxET' whose base monad is fixed to 'IO'; the error type @e@ stays open.
type TxE e a = TxET e IO a

-- | A 'TxET' whose error type is fixed to 'MSSQLTxError'; the base monad @m@
-- stays open.
type TxT m a = TxET MSSQLTxError m a

-- | Run a command on a connection taken from the given pool, wrapped in a
-- transaction with the given isolation level.
--
-- This is 'runTxE' with the error type left as 'MSSQLTxError'; see 'runTxE'
-- if you need to map the error type as well.
runTx ::
  (MonadIO m, MonadBaseControl IO m) =>
  TxIsolation ->
  TxT m a ->
  MSSQLPool ->
  ExceptT MSSQLTxError m a
runTx = runTxE id

-- | Run a command on a connection taken from the given pool, wrapped in a
-- transaction with the given isolation level.
--
-- The first argument converts an 'MSSQLTxError' into the caller's error type
-- @e@. It is handed to 'asTransaction' and is also used here to report an
-- 'ODBC.ODBCException' returned by 'withMSSQLPool' as an 'MSSQLConnError'.
runTxE ::
  (MonadIO m, MonadBaseControl IO m) =>
  (MSSQLTxError -> e) ->
  TxIsolation ->
  TxET e m a ->
  MSSQLPool ->
  ExceptT e m a
runTxE ef txIsolation tx pool = do
  -- 'withMSSQLPool' yields @Left@ with the ODBC exception it caught, or
  -- @Right@ with the transaction's result.
  poolResult <- withMSSQLPool pool (asTransaction ef txIsolation (`execTx` tx))
  hoistEither $ mapLeft (ef . MSSQLConnError) poolResult

-- | Useful for building transactions which return no data.
--
-- @
-- insertId :: TxT m ()
-- insertId = unitQuery "INSERT INTO some_table VALUES (1, 'hello')"
-- @
--
-- The transaction fails if the query returns any rows.
--
-- See 'unitQueryE' if you need to map the error type as well.
unitQuery :: (MonadIO m) => ODBC.Query -> TxT m ()
unitQuery = unitQueryE id

-- | Useful for building transactions which return no data, with the
-- 'MSSQLTxError' mapped into the caller's error type.
--
-- The transaction fails (via 'rawQueryE') if the query returns any rows.
unitQueryE :: (MonadIO m) => (MSSQLTxError -> e) -> ODBC.Query -> TxET e m ()
unitQueryE ef = rawQueryE ef emptyResult
  where
    -- Accept only a result set with no rows.
    emptyResult :: MSSQLResult -> Either String ()
    emptyResult (MSSQLResult []) = Right ()
    emptyResult (MSSQLResult _) = Left "expecting no data for ()"

-- | Useful for building query transactions which return exactly one row.
--
-- @
-- returnOne :: TxT m Int
-- returnOne = singleRowQuery "SELECT 1"
-- @
--
-- See 'singleRowQueryE' if you need to map the error type as well.
singleRowQuery :: forall a m. (MonadIO m, FromRow a) => ODBC.Query -> TxT m a
singleRowQuery = singleRowQueryE id

-- | Useful for building query transactions which return exactly one row,
-- with the 'MSSQLTxError' mapped into the caller's error type.
--
-- The transaction fails (via 'rawQueryE') if the query returns zero rows or
-- more than one row, or if the row cannot be decoded with 'ODBC.fromRow'.
singleRowQueryE ::
  forall m a e.
  (MonadIO m, FromRow a) =>
  (MSSQLTxError -> e) ->
  ODBC.Query ->
  TxET e m a
singleRowQueryE ef = rawQueryE ef singleRowResult
  where
    -- Decode the row only when the result set holds exactly one.
    singleRowResult :: MSSQLResult -> Either String a
    singleRowResult (MSSQLResult [row]) = ODBC.fromRow row
    singleRowResult (MSSQLResult _) = Left "expecting single row"

-- | MSSQL splits up results that have a @SELECT .. FOR JSON@ at the top-level
-- into multiple rows with a single column, see
-- https://docs.microsoft.com/en-us/sql/relational-databases/json/format-query-results-as-json-with-for-json-sql-server?view=sql-server-ver15#output-of-the-for-json-clause
--
-- This function concatenates each single-column row, in order, into one long
-- 'Text' string. An empty result set yields the empty 'Text'. If the first
-- row does not have exactly one column the transaction fails.
forJsonQueryE ::
  forall m e.
  (MonadIO m) =>
  (MSSQLTxError -> e) ->
  ODBC.Query ->
  TxET e m Text
forJsonQueryE ef = rawQueryE ef concatRowResult
  where
    concatRowResult :: MSSQLResult -> Either String Text
    concatRowResult (MSSQLResult rows) = case rows of
      [] -> pure mempty
      firstRow : _
        -- Only the first row's width is inspected here; every row is then
        -- decoded individually by 'ODBC.fromRow'.
        | [_] <- firstRow -> mconcat <$> traverse ODBC.fromRow rows
        | otherwise ->
            Left
              $ "forJsonQueryE: Expected single-column results, but got "
              <> show (length firstRow)
              <> " columns"

-- | Useful for building query transactions which return multiple rows.
--
-- @
-- selectIds :: TxT m [Int]
-- selectIds = multiRowQuery "SELECT id FROM author"
-- @
--
-- See 'multiRowQueryE' if you need to map the error type as well.
multiRowQuery :: forall a m. (MonadIO m, FromRow a) => ODBC.Query -> TxT m [a]
multiRowQuery = multiRowQueryE id

-- | Useful for building query transactions which return multiple rows, with
-- the 'MSSQLTxError' mapped into the caller's error type.
--
-- Every row is decoded with 'ODBC.fromRow'; the first row that fails to
-- decode fails the whole transaction (via 'rawQueryE').
multiRowQueryE ::
  forall m a e.
  (MonadIO m, FromRow a) =>
  (MSSQLTxError -> e) ->
  ODBC.Query ->
  TxET e m [a]
multiRowQueryE ef = rawQueryE ef multiRowResult
  where
    multiRowResult :: MSSQLResult -> Either String [a]
    multiRowResult (MSSQLResult rows) = traverse ODBC.fromRow rows

-- | Build a generic transaction out of an IO action.
--
-- The IO action is run through 'execQuery', so an 'ODBC.ODBCException' it
-- throws is turned into an 'MSSQLQueryError' (tagged with the converted
-- query) and then mapped into the caller's error type.
buildGenericQueryTxE ::
  (MonadIO m) =>
  -- | map 'MSSQLTxError' to some other type
  (MSSQLTxError -> e) ->
  -- | query to run
  query ->
  -- | how to map a query to a 'ODBC.Query'
  (query -> ODBC.Query) ->
  -- | run the query on a provided 'ODBC.Connection'
  (ODBC.Connection -> query -> IO a) ->
  TxET e m a
buildGenericQueryTxE errorF query convertQ runQuery =
  TxET $ ReaderT $ \conn ->
    withExceptT errorF $ execQuery query convertQ (runQuery conn)

-- | Map the error type for a 'TxET', leaving successful results untouched.
withTxET :: (Monad m) => (e1 -> e2) -> TxET e1 m a -> TxET e2 m a
withTxET f (TxET inner) = TxET $ hoist (withExceptT f) inner

-- | A successful result from a query: a list of rows, where each row is a
-- list of column values.
newtype MSSQLResult = MSSQLResult [[ODBC.Value]]
  deriving (Eq, Show)

-- | Packs a query, along with result and error converters into a 'TxET'.
--
-- The query is run with 'ODBC.query'; the returned rows are wrapped in an
-- 'MSSQLResult' and passed to the result modifier. If the modifier returns
-- @Left msg@, the transaction fails with an 'MSSQLQueryError' for this query
-- carrying @'ODBC.DataRetrievalError' msg@, mapped through the error modifier.
--
-- Used by 'unitQueryE', 'singleRowQueryE', 'multiRowQueryE' and
-- 'forJsonQueryE'.
rawQueryE ::
  (MonadIO m) =>
  -- | Error modifier
  (MSSQLTxError -> e) ->
  -- | Result modifier with a failure
  (MSSQLResult -> Either String a) ->
  -- | Query to run
  ODBC.Query ->
  TxET e m a
rawQueryE ef rf q = do
  rows <- buildGenericQueryTxE ef q id ODBC.query
  liftEither
    $ mapLeft (ef . MSSQLQueryError q . ODBC.DataRetrievalError)
    $ rf (MSSQLResult rows)

-- | Combinator for abstracting over the query type and ensuring we catch
-- exceptions.
--
-- Runs the given IO action on the query. An 'ODBC.ODBCException' thrown by
-- the action is caught and returned as an 'MSSQLQueryError' tagged with the
-- query converted to an 'ODBC.Query'. Exceptions of other types are not
-- caught here.
--
-- Used by 'buildGenericQueryTxE'.
execQuery ::
  forall m a query.
  (MonadIO m) =>
  query ->
  (query -> ODBC.Query) ->
  (query -> IO a) ->
  ExceptT MSSQLTxError m a
execQuery query toODBCQuery runQuery = do
  -- The annotation pins the exception type that 'try' catches.
  result :: Either ODBC.ODBCException a <- liftIO $ try $ runQuery query
  withExceptT (MSSQLQueryError $ toODBCQuery query) $ hoistEither result

-- | Run a 'TxET' with the given connection, exposing the underlying
-- 'ExceptT'.
--
-- Used by 'runTxE' and 'asTransaction'.
execTx :: ODBC.Connection -> TxET e m a -> ExceptT e m a
execTx conn tx = runReaderT (txHandler tx) conn
{-# INLINE execTx #-}

-- | The transaction state of the current connection.
data TransactionState
  = -- | Has an active transaction.
    TSActive
  | -- | Has no active transaction.
    TSNoActive
  | -- | An error occurred that caused the transaction to be uncommittable.
    -- We cannot commit or rollback to a savepoint; we can only do a full
    -- rollback of the transaction.
    TSUncommittable

-- | Transaction isolation levels, see
-- <https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql>
data TxIsolation
  = ReadUncommitted
  | ReadCommitted
  | RepeatableRead
  | Snapshot
  | Serializable
  deriving (Eq, Generic)

-- | Renders an isolation level as the upper-case keyword that
-- 'setTxIsoLevelTx' appends to @SET TRANSACTION ISOLATION LEVEL@.
-- This is SQL text, not a Haskell-syntax rendering of the constructor.
instance Show TxIsolation where
  show ReadUncommitted = "READ UNCOMMITTED"
  show ReadCommitted = "READ COMMITTED"
  show RepeatableRead = "REPEATABLE READ"
  show Snapshot = "SNAPSHOT"
  show Serializable = "SERIALIZABLE"

-- Empty instance body: relies entirely on the class's default methods
-- (presumably the 'Generic'-based defaults, given 'TxIsolation' derives
-- 'Generic' -- confirm against the hashable package).
instance Hashable TxIsolation

-- Empty instance body: relies entirely on the class's default methods
-- (presumably the 'Generic'-based default for 'rnf' -- confirm against the
-- deepseq package).
instance NFData TxIsolation

-- | JSON codec for 'TxIsolation': a JSON string holding the lower-case,
-- hyphenated name of the isolation level.
--
-- Built on 'textCodec' via 'bimapCodec', so decoding can fail with a
-- descriptive message while encoding is total.
instance HasCodec TxIsolation where
  codec =
    bimapCodec
      decode
      encode
      textCodec
      <?> "Isolation level"
    where
      -- Parse the textual form; any unrecognised value is rejected with a
      -- message listing every accepted spelling.
      decode :: Text -> Either String TxIsolation
      decode = \case
        "read-uncommitted" -> Right ReadUncommitted
        "read-committed" -> Right ReadCommitted
        "repeatable-read" -> Right RepeatableRead
        "snapshot" -> Right Snapshot
        "serializable" -> Right Serializable
        _ ->
          Left
            $ T.unpack
            $ "Unexpected options for isolation_level. Expected "
            -- The spellings listed here must match the patterns above
            -- exactly: users copy them from this message. (Previously the
            -- first option was misspelt as 'read-uncommited', which the
            -- decoder itself rejects.)
            <> "'read-uncommitted' | 'read-committed' | 'repeatable-read' | 'snapshot' | 'serializable'"

      -- Inverse of 'decode' on the accepted values.
      encode :: TxIsolation -> Text
      encode = \case
        ReadUncommitted -> "read-uncommitted"
        ReadCommitted -> "read-committed"
        RepeatableRead -> "repeatable-read"
        Snapshot -> "snapshot"
        Serializable -> "serializable"

-- | Aeson serialisation is delegated to the 'HasCodec' instance so that the
-- JSON representation is defined in exactly one place.
instance J.ToJSON TxIsolation where
  toJSON = AC.toJSONViaCodec
  toEncoding = AC.toEncodingViaCodec

-- | Aeson parsing is delegated to the 'HasCodec' instance, so the accepted
-- values and the failure message are those defined by the codec.
instance J.FromJSON TxIsolation where
  parseJSON = AC.parseJSONViaCodec

-- | Run an action inside a transaction on the given connection, rolling back
-- if the action (or the commit step) fails.
--
-- Sequence of events:
--
-- 1. Set the requested isolation level and begin the transaction. A failure
--    here is reported as-is; no rollback is attempted.
-- 2. Run the action, commit, and then set the isolation level back to
--    'ReadCommitted'.
-- 3. If anything in step 2 fails, run 'rollbackTx' and re-throw the error.
--    Should the rollback itself fail, the rollback's error is the one that
--    propagates.
--
-- Errors raised by the internal bookkeeping queries are 'MSSQLTxError's and
-- are converted to the caller's error type with the supplied function.
asTransaction ::
  forall e a m.
  (MonadIO m) =>
  (MSSQLTxError -> e) ->
  TxIsolation ->
  (ODBC.Connection -> ExceptT e m a) ->
  ODBC.Connection ->
  ExceptT e m a
asTransaction ef txIsolation action conn = do
  -- Begin the transaction. If there is an error, do not rollback.
  runInternal (setTxIsoLevelTx txIsolation >> beginTx)
  -- Run the transaction and commit. If there is an error, rollback.
  runAndCommit `catchError` rollbackAndThrow
  where
    -- Run one of our own bookkeeping transactions on the connection,
    -- translating its error into the caller's error type.
    runInternal :: TxET MSSQLTxError m () -> ExceptT e m ()
    runInternal = withExceptT ef . execTx conn

    -- Run the user's action, then commit. After committing, set the
    -- transaction isolation level back to the default, i.e. Read Committed.
    runAndCommit :: ExceptT e m a
    runAndCommit = do
      result <- action conn
      runInternal (commitTx >> setTxIsoLevelTx ReadCommitted)
      pure result

    -- Rollback and throw error.
    rollbackAndThrow :: e -> ExceptT e m b
    rollbackAndThrow err = runInternal rollbackTx >> throwError err

-- | Open a transaction on the current connection by issuing
-- @BEGIN TRANSACTION@; the statement's result is discarded.
beginTx :: (MonadIO m) => TxT m ()
beginTx = unitQuery "BEGIN TRANSACTION"

-- | Set the isolation level on the current connection.
--
-- The statement is assembled as raw, unescaped text. The only interpolated
-- part is the 'Show' rendering of 'TxIsolation', which is one of a fixed set
-- of keywords, so no caller-supplied text reaches the query.
setTxIsoLevelTx :: (MonadIO m) => TxIsolation -> TxT m ()
setTxIsoLevelTx txIso = unitQuery (ODBC.rawUnescapedText statement)
  where
    statement :: Text
    statement = "SET TRANSACTION ISOLATION LEVEL " <> tshow txIso <> ";"

-- | Commit the current transaction, after first checking the connection's
-- transaction state.
--
-- Only an active transaction is committed. An uncommittable transaction or
-- the absence of any transaction is reported as an 'MSSQLInternal' error
-- instead of sending @COMMIT TRANSACTION@ to the server.
commitTx :: (MonadIO m) => TxT m ()
commitTx = do
  txState <- getTransactionState
  case txState of
    TSActive -> unitQuery "COMMIT TRANSACTION"
    TSUncommittable -> throwError (MSSQLInternal "Transaction is uncommittable")
    TSNoActive -> throwError (MSSQLInternal "No active transaction exist; cannot commit")

-- | Roll back the current transaction if there is one.
--
-- Both active and uncommittable transactions are rolled back in full. When
-- no transaction is active this does nothing.
rollbackTx :: (MonadIO m) => TxT m ()
rollbackTx = do
  txState <- getTransactionState
  case txState of
    -- Some query exceptions result in an auto-rollback of the transaction.
    -- For eg. Creating a table with already existing table name (See https://github.com/hasura/graphql-engine-mono/issues/3046)
    -- In such cases, we shouldn't rollback the transaction again.
    TSNoActive -> pure ()
    TSActive -> issueRollback
    TSUncommittable -> issueRollback
  where
    issueRollback = unitQuery "ROLLBACK TRANSACTION"

-- | Get the 'TransactionState' of the current connection.
--
-- Runs @SELECT XACT_STATE()@ and maps the integer result: @1@ is an active
-- transaction, @0@ is no transaction, @-1@ is an uncommittable transaction.
-- Any other value is reported as an 'MSSQLQueryError' against that query.
--
-- For more details, refer to https://docs.microsoft.com/en-us/sql/t-sql/functions/xact-state-transact-sql?view=sql-server-ver15
getTransactionState :: (MonadIO m) => TxT m TransactionState
getTransactionState = do
  xactState <- singleRowQuery @Int xactStateQuery
  case xactState of
    1 -> pure TSActive
    0 -> pure TSNoActive
    -1 -> pure TSUncommittable
    _ -> throwError unexpectedState
  where
    xactStateQuery :: ODBC.Query
    xactStateQuery = "SELECT XACT_STATE()"

    -- Raised when the server returns a value outside the three handled above.
    unexpectedState :: MSSQLTxError
    unexpectedState =
      MSSQLQueryError xactStateQuery (ODBC.DataRetrievalError "Unexpected value for XACT_STATE")