module Data.HashMap.Strict.Extended
  ( module M,
    fromListOn,
    groupOn,
    groupOnNE,
    differenceOn,
    insertWithM,
    isInverseOf,
    unionWithM,
    unionsAll,
    homogenise,
  )
where

import Control.Monad (foldM)
import Data.Foldable qualified as F
import Data.Function (on)
import Data.HashMap.Strict as M
import Data.HashSet (HashSet)
import Data.HashSet qualified as S
import Data.Hashable (Hashable)
import Data.List qualified as L
import Data.List.NonEmpty (NonEmpty (..))
import Prelude

-- | Builds a 'HashMap' from a list of values, keying each value by the result
-- of applying the given function to it.
--
-- When several values produce the same key, the one appearing last in the
-- list is kept (this is the duplicate-key behaviour of 'M.fromList').
fromListOn :: (Eq k, Hashable k) => (v -> k) -> [v] -> HashMap k v
fromListOn keyOf values = M.fromList [(keyOf value, value) | value <- values]

-- | Given a 'Foldable' sequence of values and a function that extracts a key from each value,
-- returns a 'HashMap' that maps each key to a list of all values in the sequence for which the
-- given function produced it.
--
-- Within each list, values keep the order in which they appear in the input, and every list
-- in the result is non-empty ('groupOnNE' returns the same groups typed as 'NonEmpty').
-- The key order of the printed map below is not significant.
--
-- >>> groupOn (take 1) ["foo", "bar", "baz"]
-- fromList [("f", ["foo"]), ("b", ["bar", "baz"])]
groupOn :: (Eq k, Hashable k, Foldable t) => (v -> k) -> t v -> HashMap k [v]
groupOn f xs = F.toList <$> groupOnNE f xs

-- | Like 'groupOn', but each group is returned as a 'NonEmpty', since a key
-- only appears in the result when at least one value produced it.
--
-- Values within a group keep their input order: the input is folded from the
-- right, so each value is prepended in front of the values that follow it.
groupOnNE ::
  (Eq k, Hashable k, Foldable t) => (v -> k) -> t v -> HashMap k (NonEmpty v)
groupOnNE f = Prelude.foldr step M.empty
  where
    -- Insert @v@ at the front of the group for its key, creating the group if absent.
    step v = M.alter (Just . prepend v) (f v)
    prepend v Nothing = v :| []
    prepend v (Just (x :| xs)) = v :| (x : xs)

-- | Keys both collections with the given function and returns the values of
-- the first collection whose key is not produced by any value of the second.
--
-- The result is keyed by that function. If several values of the first
-- collection share a key, only the last one (in 'F.toList' order) survives.
differenceOn ::
  (Eq k, Hashable k, Foldable t) => (v -> k) -> t v -> t v -> HashMap k v
differenceOn f xs ys = M.difference (byKey xs) (byKey ys)
  where
    byKey zs = fromListOn f (F.toList zs)

-- | Monadic version of 'M.insertWith'
-- (https://hackage.haskell.org/package/unordered-containers-0.2.18.0/docs/Data-HashMap-Internal.html#v:insertWith).
--
-- If @k@ is not present in the map, @v@ is inserted as-is and @f@ is not called.
-- If @k@ is already present with value @old@, the stored value is replaced by the
-- result of @f v old@: the new value is passed first, matching 'M.insertWith'.
insertWithM :: (Monad m, Hashable k, Eq k) => (v -> v -> m v) -> k -> v -> HashMap k v -> m (HashMap k v)
insertWithM f k v m =
  -- Only the entry for @k@ is touched. Lifting every value of the map into the
  -- monad and re-sequencing the whole map would cost O(n) per insertion and
  -- leave unforced values in the strict map.
  case M.lookup k m of
    Nothing -> pure (M.insert k v m)
    Just old -> do
      combined <- f v old
      pure (M.insert k combined m)

-- | Determines whether the left-hand-side and the right-hand-side are inverses of each other.
--
-- More specifically, for two maps @A@ and @B@, 'isInverseOf' is satisfied when both of the
-- following are true:
--
-- 1. @∀ key ∈ A. A[key] ∈ B ∧ B[A[key]] == key@
-- 2. @∀ key ∈ B. B[key] ∈ A ∧ A[B[key]] == key@
--
-- where @x ∈ M@ means that @x@ is a key of @M@. Two empty maps are inverses of each other.
isInverseOf ::
  (Eq k, Hashable k, Eq v, Hashable v) => HashMap k v -> HashMap v k -> Bool
isInverseOf lhs rhs = invertedBy lhs rhs && invertedBy rhs lhs
  where
    -- Every entry @key -> val@ of the first map has a matching entry
    -- @val -> key@ in the second.
    invertedBy ::
      (Eq s, Eq t, Hashable t) =>
      HashMap s t ->
      HashMap t s ->
      Bool
    invertedBy a b = all (\(key, val) -> M.lookup val b == Just key) (M.toList a)

-- | The union of two maps, combining clashing values in an arbitrary monad.
--
-- Every entry of the second map is merged into the first. A key present in only
-- one map keeps that map's value. For a key present in both maps, the provided
-- function (first argument) computes the result and is called as
-- @f key valueFromSecond valueFromFirst@. Note that the value from the /second/ map
-- is passed first, which is the opposite of the argument order 'M.unionWith' uses.
--
-- Unlike 'M.unionWith', 'unionWithM' performs the combination in an arbitrary
-- monad. The calls to @f@ are sequenced in the order in which 'M.toList'
-- yields the keys of the second map.
unionWithM ::
  (Monad m, Eq k, Hashable k) =>
  (k -> v -> v -> m v) ->
  HashMap k v ->
  HashMap k v ->
  m (HashMap k v)
unionWithM f m1 m2 = foldM step m1 (M.toList m2)
  where
    -- Keys produced by 'M.toList' are distinct, so any value found in the
    -- accumulator for @k@ is still the one that came from @m1@.
    step acc (k, new) = case M.lookup k acc of
      Nothing -> pure (M.insert k new acc)
      Just old -> do
        combined <- f k new old
        pure (M.insert k combined acc)

-- | Like 'M.unions', but keeping all elements in the result.
--
-- Each key maps to every value it had across the input maps. The values are
-- ordered by the position of their map in the input (left to right).
unionsAll ::
  (Eq k, Hashable k, Foldable t) => t (HashMap k v) -> HashMap k (NonEmpty v)
unionsAll = F.foldl' step M.empty
  where
    -- Append this map's values (as singleton groups) after the groups accumulated so far.
    step acc m = M.unionWith (<>) acc (fmap (:| []) m)

-- | Homogenise maps, such that all maps range over the full set of
-- keys, inserting a default value as needed.
--
-- Returns the union of the keys of all input maps, together with the input
-- maps in their original order. Each output map keeps its own values and maps
-- every key it was missing to the given default value.
homogenise :: (Hashable a, Eq a) => b -> [HashMap a b] -> (HashSet a, [HashMap a b])
homogenise defaultValue maps = (allKeys, fmap fillMissing maps)
  where
    allKeys = S.unions (fmap M.keysSet maps)
    -- Every key mapped to the default. '<>' on 'HashMap' is a left-biased
    -- union, so a map's existing entries take precedence over these.
    defaults = M.fromList [(key, defaultValue) | key <- S.toList allKeys]
    fillMissing m = m <> defaults