module Hasura.Server.Auth.JWT.Internal
  ( parseEdDSAKey,
    parseHmacKey,
    parseRsaKey,
  )
where

import Control.Lens
import Crypto.JOSE.Types (Base64Integer (..))
import Crypto.JWT
import Crypto.PubKey.RSA (PublicKey (..))
import Data.ASN1.BinaryEncoding (DER (..))
import Data.ASN1.Encoding (decodeASN1')
import Data.ASN1.Types
  ( ASN1 (End, IntVal, Start),
    ASN1ConstructionType (Sequence),
    fromASN1,
  )
import Data.ByteString.Lazy qualified as BL
import Data.Int (Int64)
import Data.PEM qualified as PEM
import Data.Text qualified as T
import Data.Text.Conversions
import Data.X509 qualified as X509
import Hasura.Prelude
import Hasura.Server.Utils (fmapL)

-- | Helper functions to decode 'Text' to a 'JWK'.
--
-- 'parseHmacKey' turns a shared secret into a symmetric (octet) 'JWK'.
-- The secret is UTF-8 encoded, and its byte length must be at least
-- @size `div` 8@. Otherwise a 'Left' is returned that states that minimum.
parseHmacKey ::
  -- | the shared secret
  Text ->
  -- | required key size; it is divided by 8 to get the minimum byte count
  -- (presumably given in bits)
  Int64 ->
  Either Text JWK
parseHmacKey key size
  | BL.length secretBytes < minBytes =
      Left $ T.pack $ "Key size too small; should be atleast " <> show minBytes <> " characters"
  | otherwise = Right $ fromOctets secretBytes
  where
    secretBytes :: BL.ByteString
    secretBytes = unUTF8 (fromText key)
    minBytes = size `div` 8

-- | Decode a PEM-encoded RSA public key into a verification-only 'JWK'.
--
-- The text is UTF-8 encoded and passed to 'fromRawPem', which accepts a
-- PKCS#1 / PKCS#8 public key or an X.509 certificate. Any failure is
-- returned with the prefix @"Could not decode PEM: "@.
--
-- NOTE(review): the key type is not checked here, so an Ed25519 key would
-- also be accepted.
parseRsaKey :: Text -> Either Text JWK
parseRsaKey =
  fmapL ("Could not decode PEM: " <>) . fromRawPem . unUTF8 . fromText

-- | Decode a PEM-encoded EdDSA (Ed25519) public key into a
-- verification-only 'JWK'.
--
-- This uses the same PEM pipeline as RSA keys ('fromRawPem'). Any failure
-- is returned with the prefix @"Could not decode PEM: "@.
parseEdDSAKey :: Text -> Either Text JWK
parseEdDSAKey key =
  either (Left . ("Could not decode PEM: " <>)) Right (fromRawPem pemBytes)
  where
    pemBytes = unUTF8 (fromText key)

-- Helper functions to decode a PEM bytestring into an RSA or EdDSA public key.

-- | Decode a PEM bytestring into a verification-only 'JWK'.
--
-- The PKCS decoder ('fromPkcsPem') is tried first, then the X.509
-- certificate decoder ('fromX509Pem'). If both fail, their messages are
-- joined with a single space. If PKCS decoding succeeds but the key cannot
-- be converted to a 'JWK', that error is returned without trying X.509.
fromRawPem :: BL.ByteString -> Either Text JWK
fromRawPem bs = do
  pubKey <- case fromPkcsPem bs of
    Right pk -> Right pk
    Left pkcsErr -> fmapL (\x509Err -> pkcsErr <> " " <> x509Err) (fromX509Pem bs)
  pubKeyToJwk pubKey

-- | Decode a PKCS#1 or PKCS#8 PEM to obtain the public key.
--
-- Only the first PEM block in the input is used. Its DER payload is
-- decoded as follows:
--
-- * A bare @SEQUENCE { INTEGER n, INTEGER e }@ is read as a PKCS#1 RSA key.
-- * Anything else goes to the 'X509.PubKey' ASN.1 decoder (the PKCS#8
--   case). Unconsumed trailing ASN.1 in that case is an error.
fromPkcsPem :: BL.ByteString -> Either Text X509.PubKey
fromPkcsPem bs = do
  pem <- getAtleastOne "No pem found" =<< fmapL T.pack (PEM.pemParseLBS bs)
  asn1Stream <- fmapL tshow (decodeASN1' DER (PEM.pemContent pem))
  decodePubKey asn1Stream
  where
    -- PKCS#1: the stream is exactly the modulus and the public exponent
    decodePubKey [Start Sequence, IntVal modulus, IntVal pubExp, End Sequence] =
      Right (X509.PubKeyRSA (PublicKey (calculateSize modulus) modulus pubExp))
    -- otherwise, try the PKCS#8 encoding
    decodePubKey other = do
      (pub, leftovers) <- fmapL T.pack (fromASN1 other)
      if null leftovers
        then Right pub
        else Left "Could not decode public key"

-- | Decode an X.509 certificate PEM and extract its RSA or EdDSA (Ed25519)
-- public key.
--
-- Only the first PEM block is used. A certificate whose public key is of
-- any other type is rejected.
fromX509Pem :: BL.ByteString -> Either Text X509.PubKey
fromX509Pem s = do
  -- the input must parse as PEM and contain at least one block
  pems <- fmapL T.pack (PEM.pemParseLBS s)
  pem <- getAtleastOne "No pem found" pems
  -- the block's DER payload must be a signed certificate
  signedCert <- fmapL T.pack (X509.decodeSignedCertificate (PEM.pemContent pem))
  supportedKey (X509.certPubKey (X509.signedObject (X509.getSigned signedCert)))
  where
    -- only RSA and Ed25519 keys are passed through unchanged
    supportedKey key = case key of
      X509.PubKeyRSA _ -> pure key
      X509.PubKeyEd25519 _ -> pure key
      _ -> Left "Could not decode RSA or EdDSA public key from x509 cert"

-- | Convert an RSA or Ed25519 public key into a 'JWK' whose key operations
-- are restricted to @[Verify]@.
--
-- Any other key type yields @Left "This key type is not supported"@.
pubKeyToJwk :: X509.PubKey -> Either Text JWK
pubKeyToJwk pubKey =
  (jwkKeyOps .~ Just [Verify]) . fromKeyMaterial <$> keyMaterial
  where
    keyMaterial = case pubKey of
      -- the size field of the RSA key is not needed for the JWK
      X509.PubKeyRSA (PublicKey _ modulus expo) ->
        Right $
          RSAKeyMaterial $
            RSAKeyParameters (Base64Integer modulus) (Base64Integer expo) Nothing
      X509.PubKeyEd25519 edKey ->
        Right $ OKPKeyMaterial $ Ed25519Key edKey Nothing
      _ -> Left "This key type is not supported"

-- | Take the head of a list, failing with the given message when the list
-- is empty.
getAtleastOne :: Text -> [a] -> Either Text a
getAtleastOne err xs = case xs of
  x : _ -> Right x
  [] -> Left err

-- | Number of bytes needed to represent an integer: the smallest @i >= 1@
-- such that @256 ^ i > n@.
--
-- Any value below 256, including 0 and negative numbers, gives 1. Used
-- as the size field of a PKCS#1 RSA public key.
calculateSize :: Integer -> Int
calculateSize n = go 1 256
  where
    -- invariant: limit == 256 ^ bytes
    go :: Int -> Integer -> Int
    go bytes limit
      | limit > n = bytes
      | otherwise = go (bytes + 1) (limit * 256)