-- | How to construct and execute a call to a source for a remote join.
--
-- There are three steps required to do this:
--   - construct the execution step for that source join
--   - execute that query against the source (via a caller-supplied function)
--   - build a join index, keyed by join argument id, out of the response
--
-- This can be done as one function, but we also export the individual steps for
-- debugging / test purposes. We gather all intermediate state in the
-- 'SourceJoinCall' type.
module Hasura.GraphQL.Execute.RemoteJoin.Source
  ( -- * Executing a remote join
    makeSourceJoinCall,

    -- * Individual steps
    SourceJoinCall (..),
    buildSourceJoinCall,
    buildJoinIndex,
  )
where

import Data.Aeson qualified as J
import Data.Aeson.Key qualified as K
import Data.Aeson.KeyMap qualified as KM
import Data.Aeson.Ordered qualified as AO
import Data.Aeson.Ordered qualified as JO
import Data.Bifunctor (bimap)
import Data.ByteString.Lazy qualified as BL
import Data.HashMap.Strict.Extended qualified as Map
import Data.IntMap.Strict qualified as IntMap
import Data.List.NonEmpty qualified as NE
import Data.Scientific qualified as Scientific
import Data.Text qualified as T
import Data.Text.Read qualified as TR
import Hasura.Base.Error
import Hasura.GraphQL.Execute.Backend qualified as EB
import Hasura.GraphQL.Execute.Instances ()
import Hasura.GraphQL.Execute.RemoteJoin.Types
import Hasura.GraphQL.Namespace
import Hasura.GraphQL.Transport.Instances ()
import Hasura.Prelude
import Hasura.RQL.Types.Backend
import Hasura.RQL.Types.Common
import Hasura.SQL.AnyBackend qualified as AB
import Hasura.Session
import Language.GraphQL.Draft.Syntax qualified as G

-------------------------------------------------------------------------------
-- Executing a remote join

-- | Construct and execute a call to a source for a remote join.
--
-- Returns 'Nothing' when there are no join arguments: in that case there is
-- nothing to fetch, and the dispatch function is never called.
makeSourceJoinCall ::
  (EB.MonadQueryTags m, MonadError QErr m) =>
  -- | Function to dispatch a request to a source.
  (AB.AnyBackend SourceJoinCall -> m BL.ByteString) ->
  -- | User information.
  UserInfo ->
  -- | Remote join information.
  AB.AnyBackend RemoteSourceJoin ->
  -- | Name of the field from the join arguments.
  FieldName ->
  -- | Mapping from 'JoinArgumentId' to its corresponding 'JoinArgument'.
  IntMap.IntMap JoinArgument ->
  -- | The resulting join index (see 'buildJoinIndex') if any.
  m (Maybe (IntMap.IntMap AO.Value))
makeSourceJoinCall networkFunction userInfo remoteSourceJoin jaFieldName joinArguments = do
  -- step 1: create the SourceJoinCall; 'buildSourceJoinCall' yields 'Nothing'
  -- when there are no join arguments
  maybeSourceCall <-
    AB.dispatchAnyBackend @EB.BackendExecute remoteSourceJoin $
      buildSourceJoinCall userInfo jaFieldName joinArguments
  -- if there actually is a call to make:
  for maybeSourceCall \sourceCall -> do
    -- step 2: send this call to the source
    sourceResponse <- networkFunction sourceCall
    -- step 3: build the join index
    buildJoinIndex sourceResponse

-------------------------------------------------------------------------------
-- Internal representation

-- | Intermediate type that contains all the necessary information to perform a
-- call to a database to perform a join.
data SourceJoinCall b = SourceJoinCall
  { -- | Root field alias for this call, derived from the join field name
    -- (see 'buildSourceJoinCall').
    _sjcRootFieldAlias :: RootFieldAlias,
    -- | Configuration of the source being joined against.
    _sjcSourceConfig :: SourceConfig b,
    -- | Execution step produced by 'EB.mkDBRemoteRelationshipPlan'.
    _sjcStepInfo :: EB.DBStepInfo b
  }

-------------------------------------------------------------------------------
-- Step 1: building the source call

-- | Build the call to send to the source for a remote join.
--
-- Each join argument becomes one row: its column values (keyed by field
-- name), plus its 'JoinArgumentId' stored under the @__argument_id__@ key so
-- that results can be matched back to their argument (see 'buildJoinIndex').
-- The relationship selection is requested under the alias @f@.
--
-- Returns 'Nothing' when there are no join arguments, since there is then
-- nothing to query.
--
-- Throws a 500 if the join field name is not a valid GraphQL name.
buildSourceJoinCall ::
  (EB.BackendExecute b, EB.MonadQueryTags m, MonadError QErr m) =>
  UserInfo ->
  FieldName ->
  IntMap.IntMap JoinArgument ->
  RemoteSourceJoin b ->
  m (Maybe (AB.AnyBackend SourceJoinCall))
buildSourceJoinCall userInfo jaFieldName joinArguments remoteSourceJoin = do
  let rows =
        IntMap.toList joinArguments <&> \(argumentId, argument) ->
          KM.insert "__argument_id__" (J.toJSON argumentId) $
            KM.fromList $
              map (bimap (K.fromText . getFieldNameTxt) JO.fromOrdered) $
                Map.toList $
                  unJoinArgument argument
      rowSchema = fmap snd (_rsjJoinColumns remoteSourceJoin)
  for (NE.nonEmpty rows) $ \nonEmptyRows -> do
    let sourceConfig = _rsjSourceConfig remoteSourceJoin
    stepInfo <-
      EB.mkDBRemoteRelationshipPlan
        userInfo
        (_rsjSource remoteSourceJoin)
        sourceConfig
        nonEmptyRows
        rowSchema
        (FieldName "__argument_id__")
        (FieldName "f", _rsjRelationship remoteSourceJoin)
    -- This should never fail, as field names in remote relationships are
    -- validated when building the schema cache.
    fieldName <-
      G.mkName (getFieldNameTxt jaFieldName)
        `onNothing` throw500 ("'" <> getFieldNameTxt jaFieldName <> "' is not a valid GraphQL name")
    -- NOTE: We're making an assumption that the 'FieldName' propagated upwards
    -- from 'collectJoinArguments' is reasonable to use for logging.
    let rootFieldAlias = mkUnNamespacedRootFieldAlias fieldName
    pure $
      AB.mkAnyBackend $
        SourceJoinCall rootFieldAlias sourceConfig stepInfo

-------------------------------------------------------------------------------
-- Step 3: extracting the join index

-- | Construct a join index from the JSON response from the source.
--
-- Unlike with remote schemas, we can make assumptions about the shape of the
-- result, instead of having to keep track of the path within the answer. This
-- function therefore enforces that the answer has the shape we expect, and
-- throws a 'QErr' if it doesn't.
--
-- The response must be a JSON array of objects, each containing:
--
--   * @__argument_id__@: the row's argument id, either as a JSON number or as
--     a string of decimal digits; in both cases it must fit in an 'Int';
--   * @f@: the relationship data for that argument, stored as-is in the index.
buildJoinIndex :: (MonadError QErr m) => BL.ByteString -> m (IntMap.IntMap JO.Value)
buildJoinIndex response = do
  json <-
    JO.eitherDecode response `onLeft` \err ->
      throwInvalidJsonErr $ T.pack err
  case json of
    JO.Array arr -> fmap IntMap.fromList $ for (toList arr) \case
      JO.Object obj -> do
        argumentResult <-
          JO.lookup "f" obj
            `onNothing` throwMissingRelationshipDataErr
        argumentIdValue <-
          JO.lookup "__argument_id__" obj
            `onNothing` throwMissingArgumentIdErr
        argumentId <-
          case argumentIdValue of
            JO.Number n ->
              Scientific.toBoundedInteger n
                `onNothing` throwInvalidArgumentIdValueErr
            JO.String s ->
              intFromText s
                `onNothing` throwInvalidArgumentIdValueErr
            _ -> throwInvalidArgumentIdValueErr
        pure (argumentId, argumentResult)
      _ -> throwNoNestedObjectErr
    _ -> throwNoListOfObjectsErr
  where
    -- Parse a string consisting solely of decimal digits as an 'Int'.
    -- 'TR.decimal' does not detect overflow for fixed-width types, so the
    -- digits are read as an unbounded 'Integer' and then range-checked,
    -- mirroring the 'Scientific.toBoundedInteger' check on the numeric branch.
    -- ('TR.decimal' accepts no sign, so only the upper bound can be exceeded.)
    intFromText :: T.Text -> Maybe Int
    intFromText txt = case TR.decimal @Integer txt of
      Right (i, "")
        | i <= toInteger (maxBound :: Int) -> Just (fromInteger i)
      _ -> Nothing
    throwInvalidJsonErr errMsg =
      throw500 $
        "failed to decode JSON response from the source: " <> errMsg
    throwMissingRelationshipDataErr =
      throw500 $
        "cannot find relationship data (aliased as 'f') within the source \
        \response"
    throwMissingArgumentIdErr =
      throw500 $
        "cannot find '__argument_id__' within the source response"
    throwInvalidArgumentIdValueErr =
      throw500 $ "expected '__argument_id__' to get parsed as backend integer type"
    throwNoNestedObjectErr =
      throw500 $
        "expected an object one level deep in the source response, \
        \but found an array/scalar value instead"
    throwNoListOfObjectsErr =
      throw500 $
        "expected a list of objects in the source response, but found \
        \an object/scalar value instead"