module Database.MSSQL.Transaction
  ( TxET (..),
    MSSQLTxError (..),
    TxT,
    TxE,
    runTx,
    runTxE,
    unitQuery,
    unitQueryE,
    singleRowQuery,
    singleRowQueryE,
    multiRowQuery,
    multiRowQueryE,
    forJsonQueryE,
    buildGenericQueryTxE,
    withTxET,
  )
where

import Control.Exception (try)
import Control.Monad.Morph (MFunctor (hoist))
import Control.Monad.Trans.Control (MonadBaseControl)
import Database.MSSQL.Pool
import Database.ODBC.SQLServer (FromRow)
import Database.ODBC.SQLServer qualified as ODBC
import Hasura.Prelude

-- | The transaction command to run, parameterised over:
--
-- * @e@ - the error type (usually 'MSSQLTxError')
-- * @m@ - the underlying monad (usually some 'MonadIO')
-- * @a@ - the successful result type
--
-- Internally this is a reader over the 'ODBC.Connection' the command runs
-- on, layered on top of 'ExceptT' for error reporting.
newtype TxET e m a = TxET
  { txHandler :: ReaderT ODBC.Connection (ExceptT e m) a
  }
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadError e,
      MonadIO,
      MonadReader ODBC.Connection,
      MonadFix
    )

-- | Changes the base monad of a 'TxET' by hoisting the given natural
-- transformation through both the 'ReaderT' and the 'ExceptT' layer.
instance MFunctor (TxET e) where
  hoist f = TxET . hoist (hoist f) . txHandler

-- | Lifts an action of the base monad into 'TxET': once through 'ExceptT',
-- once through 'ReaderT', and then wrapped in the 'TxET' constructor.
instance MonadTrans (TxET e) where
  lift = TxET . lift . lift

-- | Error type generally used in 'TxET'.
data MSSQLTxError
  = -- | A query failed; carries the query together with the ODBC exception
    -- describing the failure.
    MSSQLQueryError !ODBC.Query !ODBC.ODBCException
  | -- | An ODBC exception that is not tied to a particular query, e.g. one
    -- reported by the connection pool.
    MSSQLConnError !ODBC.ODBCException
  | -- | An internal error described by a plain message, e.g. an unexpected
    -- transaction state on commit.
    MSSQLInternal !Text
  deriving (Eq, Show)

-- | A 'TxET' transaction command specialised to run directly in 'IO', with
-- error type @e@.
type TxE e a = TxET e IO a

-- | The transaction command to run, returning an 'MSSQLTxError' or the result.
type TxT m a = TxET MSSQLTxError m a

-- | Run a command on a connection taken from the given pool, wrapped in a
-- transaction.
--
-- This is 'runTxE' with the error type left as 'MSSQLTxError'.
-- See 'runTxE' if you need to map the error type as well.
runTx ::
  (MonadIO m, MonadBaseControl IO m) =>
  TxT m a ->
  MSSQLPool ->
  ExceptT MSSQLTxError m a
runTx = runTxE id

-- | Run a command on a connection taken from the given pool, wrapped in a
-- transaction.
--
-- The first argument maps every 'MSSQLTxError' raised while beginning,
-- committing or rolling back the transaction into the caller's error type
-- @e@. An 'ODBC.ODBCException' reported by 'withMSSQLPool' itself is wrapped
-- in 'MSSQLConnError' before being mapped the same way.
runTxE ::
  (MonadIO m, MonadBaseControl IO m) =>
  (MSSQLTxError -> e) ->
  TxET e m a ->
  MSSQLPool ->
  ExceptT e m a
runTxE ef tx pool =
  withMSSQLPool pool (asTransaction ef (`execTx` tx))
    >>= hoistEither . mapLeft (ef . MSSQLConnError)

-- | Useful for building transactions which return no data.
--
-- @
-- insertId :: TxT m ()
-- insertId = unitQuery "INSERT INTO some_table VALUES (1, \"hello\")"
-- @
--
-- See 'unitQueryE' if you need to map the error type as well.
unitQuery :: (MonadIO m) => ODBC.Query -> TxT m ()
unitQuery = unitQueryE id

-- | Useful for building transactions which return no data.
--
-- The query is expected to produce no rows at all; if it does return rows the
-- transaction fails with a data retrieval error (mapped through the given
-- error modifier).
unitQueryE :: (MonadIO m) => (MSSQLTxError -> e) -> ODBC.Query -> TxET e m ()
unitQueryE ef = rawQueryE ef emptyResult
  where
    -- Accept only an empty result set.
    emptyResult :: MSSQLResult -> Either String ()
    emptyResult (MSSQLResult []) = Right ()
    emptyResult (MSSQLResult _) = Left "expecting no data for ()"

-- | Useful for building query transactions which return exactly one row.
--
-- @
-- returnOne :: TxT m Int
-- returnOne = singleRowQuery "SELECT 1"
-- @
--
-- See 'singleRowQueryE' if you need to map the error type as well.
singleRowQuery :: forall a m. (MonadIO m, FromRow a) => ODBC.Query -> TxT m a
singleRowQuery = singleRowQueryE id

-- | Useful for building query transactions which return exactly one row.
--
-- The single row is decoded with 'ODBC.fromRow'. A result set with zero rows
-- or more than one row fails with a data retrieval error (mapped through the
-- given error modifier).
singleRowQueryE ::
  forall m a e.
  (MonadIO m, FromRow a) =>
  (MSSQLTxError -> e) ->
  ODBC.Query ->
  TxET e m a
singleRowQueryE ef = rawQueryE ef singleRowResult
  where
    -- Decode the row only when the result set holds exactly one.
    singleRowResult :: MSSQLResult -> Either String a
    singleRowResult (MSSQLResult [row]) = ODBC.fromRow row
    singleRowResult (MSSQLResult _) = Left "expecting single row"

-- | MSSQL splits up results that have a @SELECT .. FOR JSON@ at the top-level
-- into multiple rows with a single column, see
-- https://docs.microsoft.com/en-us/sql/relational-databases/json/format-query-results-as-json-with-for-json-sql-server?view=sql-server-ver15#output-of-the-for-json-clause
--
-- This function simply concatenates each single-column row into one long 'Text' string.
--
-- An empty result set yields the empty 'Text'. If the first row does not have
-- exactly one column, the transaction fails with a data retrieval error
-- (mapped through the given error modifier).
forJsonQueryE ::
  forall m e.
  MonadIO m =>
  (MSSQLTxError -> e) ->
  ODBC.Query ->
  TxET e m Text
forJsonQueryE ef = rawQueryE ef concatRowResult
  where
    concatRowResult :: MSSQLResult -> Either String Text
    concatRowResult (MSSQLResult []) = pure mempty
    concatRowResult (MSSQLResult rows@(r1 : _))
      -- Only the first row's width is inspected here; every row is still
      -- decoded individually by 'ODBC.fromRow'.
      | length r1 == 1 = mconcat <$> mapM ODBC.fromRow rows
      | otherwise =
        Left $
          "forJsonQueryE: Expected single-column results, but got "
            <> show (length r1)
            <> " columns"

-- | Useful for building query transactions which return multiple rows.
--
-- @
-- selectIds :: TxT m [Int]
-- selectIds = multiRowQuery "SELECT id FROM author"
-- @
--
-- See 'multiRowQueryE' if you need to map the error type as well.
multiRowQuery :: forall a m. (MonadIO m, FromRow a) => ODBC.Query -> TxT m [a]
multiRowQuery = multiRowQueryE id

-- | Useful for building query transactions which return multiple rows.
--
-- Every row of the result set is decoded with 'ODBC.fromRow'; the first row
-- that fails to decode fails the whole transaction with a data retrieval
-- error (mapped through the given error modifier).
multiRowQueryE ::
  forall m a e.
  (MonadIO m, FromRow a) =>
  (MSSQLTxError -> e) ->
  ODBC.Query ->
  TxET e m [a]
multiRowQueryE ef = rawQueryE ef multiRowResult
  where
    multiRowResult :: MSSQLResult -> Either String [a]
    multiRowResult (MSSQLResult rows) = traverse ODBC.fromRow rows

-- | Build a generic transaction out of an IO action.
--
-- The IO action is run against the connection held by the 'TxET' reader. An
-- 'ODBC.ODBCException' thrown by it is caught (see 'execQuery'), reported as
-- an 'MSSQLQueryError' for the converted query, and finally mapped through
-- the given error modifier.
buildGenericQueryTxE ::
  (MonadIO m) =>
  -- | map 'MSSQLTxError' to some other type
  (MSSQLTxError -> e) ->
  -- | query to run
  query ->
  -- | how to map a query to a 'ODBC.Query'
  (query -> ODBC.Query) ->
  -- | run the query on a provided 'ODBC.Connection'
  (ODBC.Connection -> query -> IO a) ->
  TxET e m a
buildGenericQueryTxE errorF query convertQ runQuery =
  TxET $ ReaderT $ withExceptT errorF . execQuery query convertQ . runQuery

-- | Map the error type for a 'TxET'.
--
-- The function is applied to the 'ExceptT' layer underneath the connection
-- reader; successful results are left untouched.
withTxET :: Monad m => (e1 -> e2) -> TxET e1 m a -> TxET e2 m a
withTxET f (TxET m) = TxET $ hoist (withExceptT f) m

-- | A successful result from a query is a list of rows, where each row
-- contains a list of column values.
newtype MSSQLResult = MSSQLResult [[ODBC.Value]]
  deriving (Eq, Show)

-- | Packs a query, along with result and error converters into a 'TxET'.
--
-- The query is run with 'ODBC.query' and the raw rows are handed to the
-- result modifier. When the result modifier fails, its message is reported as
-- an 'ODBC.DataRetrievalError' inside an 'MSSQLQueryError' for the query, and
-- then mapped through the error modifier.
--
-- Used by 'unitQueryE', 'singleRowQueryE', 'multiRowQueryE' and
-- 'forJsonQueryE'.
rawQueryE ::
  (MonadIO m) =>
  -- | Error modifier
  (MSSQLTxError -> e) ->
  -- | Result modifier with a failure
  (MSSQLResult -> Either String a) ->
  -- | Query to run
  ODBC.Query ->
  TxET e m a
rawQueryE ef rf q = do
  rows <- buildGenericQueryTxE ef q id ODBC.query
  liftEither $
    mapLeft (ef . MSSQLQueryError q . ODBC.DataRetrievalError) $
      rf (MSSQLResult rows)

-- | Combinator for abstracting over the query type and ensuring we catch exceptions.
--
-- Only exceptions of type 'ODBC.ODBCException' are caught (via 'try'); they
-- are reported as an 'MSSQLQueryError' carrying the query converted with the
-- supplied function.
--
-- Used by 'buildGenericQueryTxE'.
execQuery ::
  forall m a query.
  (MonadIO m) =>
  query ->
  (query -> ODBC.Query) ->
  (query -> IO a) ->
  ExceptT MSSQLTxError m a
execQuery query toODBCQuery runQuery = do
  -- The annotation fixes the exception type that 'try' catches.
  result :: Either ODBC.ODBCException a <- liftIO $ try $ runQuery query
  withExceptT (MSSQLQueryError $ toODBCQuery query) $ hoistEither result

-- | Run a 'TxET' with the given connection.
--
-- This only unwraps the reader layer; it does not begin, commit or roll back
-- anything by itself.
--
-- Used by 'runTxE' (to run the user's command) and by 'asTransaction' (to run
-- the begin, commit and rollback statements).
execTx :: ODBC.Connection -> TxET e m a -> ExceptT e m a
execTx conn tx = runReaderT (txHandler tx) conn
{-# INLINE execTx #-}

-- | The transaction state of the current connection, as decoded from the
-- value of @XACT_STATE()@ by 'getTransactionState'.
data TransactionState
  = -- | Has an active transaction.
    TSActive
  | -- | Has no active transaction.
    TSNoActive
  | -- | An error occurred that caused the transaction to be uncommittable.
    -- We cannot commit or rollback to a savepoint; we can only do a full
    -- rollback of the transaction.
    TSUncommittable

-- | Wraps an action in a transaction. Rolls back on errors.
--
-- The steps are:
--
-- 1. Begin a transaction. A failure here is reported without a rollback.
-- 2. Run the action and commit. An error raised by either of them triggers a
--    rollback, after which the error is re-thrown.
--
-- Errors from the begin, commit and rollback statements are
-- 'MSSQLTxError's mapped through the given error modifier. Note that only
-- errors in the 'ExceptT' layer are handled; nothing here catches exceptions
-- thrown in the base monad.
asTransaction ::
  forall e a m.
  MonadIO m =>
  (MSSQLTxError -> e) ->
  (ODBC.Connection -> ExceptT e m a) ->
  ODBC.Connection ->
  ExceptT e m a
asTransaction ef action conn = do
  -- Begin the transaction. If there is an error, do not rollback.
  withExceptT ef $ execTx conn beginTx
  -- Run the transaction and commit. If there is an error, rollback.
  flip catchError rollbackAndThrow do
    result <- action conn
    withExceptT ef $ execTx conn commitTx
    pure result
  where
    -- Rollback and throw error. If the rollback itself fails, that failure is
    -- reported instead of the original error.
    rollbackAndThrow :: e -> ExceptT e m b
    rollbackAndThrow err = do
      withExceptT ef $ execTx conn rollbackTx
      throwError err

-- | Start a transaction on the current connection by issuing
-- @BEGIN TRANSACTION@.
beginTx :: MonadIO m => TxT m ()
beginTx = unitQuery "BEGIN TRANSACTION"

-- | Commit the transaction on the current connection.
--
-- The transaction state is checked first: @COMMIT TRANSACTION@ is only issued
-- when a transaction is active. An uncommittable or missing transaction is
-- reported as an 'MSSQLInternal' error instead.
commitTx :: MonadIO m => TxT m ()
commitTx =
  getTransactionState >>= \case
    TSActive ->
      unitQuery "COMMIT TRANSACTION"
    TSUncommittable ->
      throwError $ MSSQLInternal "Transaction is uncommittable"
    TSNoActive ->
      throwError $ MSSQLInternal "No active transaction exist; cannot commit"

-- | Roll back the transaction on the current connection, if there is one.
--
-- @ROLLBACK TRANSACTION@ is issued for both active and uncommittable
-- transactions; when no transaction is active this is a no-op.
rollbackTx :: MonadIO m => TxT m ()
rollbackTx =
  let rollback = unitQuery "ROLLBACK TRANSACTION"
   in getTransactionState >>= \case
        TSActive -> rollback
        TSUncommittable -> rollback
        TSNoActive ->
          -- Some query exceptions result in an auto-rollback of the transaction.
          -- For eg. Creating a table with already existing table name (See https://github.com/hasura/graphql-engine-mono/issues/3046)
          -- In such cases, we shouldn't rollback the transaction again.
          pure ()

-- | Get the 'TransactionState' of the current connection by querying
-- @XACT_STATE()@.
--
-- The result is decoded as: @1@ - active, @0@ - no active transaction,
-- @-1@ - uncommittable. Any other value is reported as an 'MSSQLQueryError'.
--
-- For more details, refer to https://docs.microsoft.com/en-us/sql/t-sql/functions/xact-state-transact-sql?view=sql-server-ver15
getTransactionState :: (MonadIO m) => TxT m TransactionState
getTransactionState =
  let query = "SELECT XACT_STATE()"
   in singleRowQuery @Int query
        >>= \case
          1 -> pure TSActive
          0 -> pure TSNoActive
          -1 -> pure TSUncommittable
          _ ->
            throwError $
              MSSQLQueryError query $
                ODBC.DataRetrievalError "Unexpected value for XACT_STATE"