-- | Internal utils.
module Numeric.Data.Internal.Utils
  ( -- * Safe modular arithmetic

    -- ** Algebra-simple
    checkModBoundAlgebra,
    modSafeAddAlgebra,
    modSafeMultAlgebra,
    modSafeSubAlgebra,

    -- * Optics
    rmatching,
  )
where

import Data.Bounds (MaybeUpperBounded (maybeUpperBound))
import Data.Typeable (Typeable)
import Data.Typeable qualified as Typeable
import Numeric.Algebra
  ( ASemigroup ((.+.)),
    MEuclidean,
    MSemigroup ((.*.)),
    mmod,
  )
import Numeric.Convert.Integer (FromInteger (fromZ), ToInteger (toZ))
import Optics.Core
  ( An_AffineTraversal,
    Is,
    NoIx,
    Optic,
    ReversibleOptic (ReversedOptic),
    matching,
    re,
  )

-- | Verifies that the type /a/ is large enough to fit the modulus.
-- Returns 'Nothing' if the check succeeds or a 'String' error message if
-- the check fails. A type without an upper bound (i.e. 'maybeUpperBound' is
-- 'Nothing') always passes the check.
checkModBoundAlgebra ::
  forall a.
  ( ToInteger a,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  -- | A term of the type /a/ whose upper bound must be large enough to
  -- accommodate modular arithmetic within modulus /n/. It is used only to
  -- render the type's name in the error message. This would be a Proxy
  -- except we get a better error message for the value itself.
  a ->
  -- | The modulus n that should satisfy @n <= max(a)@.
  Integer ->
  Maybe String
checkModBoundAlgebra aTerm modulus =
  -- An unbounded type short-circuits to Nothing here, i.e. success.
  maybeUpperBound @a >>= \maxA ->
    let maxAℤ = toZ maxA
     in if maxS <= maxAℤ
          then Nothing
          else
            Just $
              mconcat
                [ "Type '",
                  show typeA,
                  "' has a maximum size of ",
                  show maxAℤ,
                  ". This is not large enough to safely implement mod ",
                  show modulus,
                  "."
                ]
  where
    -- This should ostensibly be modulus - 1 since the highest value in
    -- Z/nZ is (n - 1). But we need to actually perform mod n, hence,
    -- type A must be >= modulus itself.
    maxS = modulus
    typeA = Typeable.typeOf aTerm

-- | Performs modular addition, accounting for rounding in the type
-- itself. That is, @(x + y) mod n@ is computed without letting the
-- intermediate sum exceed the upper bound of /a/ (when /a/ has one).
modSafeAddAlgebra ::
  forall a.
  ( ASemigroup a,
    FromInteger a,
    MEuclidean a,
    MaybeUpperBounded a,
    ToInteger a
  ) =>
  -- | x
  a ->
  -- | y
  a ->
  -- | n (modulus)
  a ->
  a
modSafeAddAlgebra x y modulus = case maybeUpperBound @a of
  -- 1. A is unbounded: Easy
  Nothing -> (x .+. y) `mmod` modulus
  Just maxA ->
    let maxAℤ = aToℤ maxA
        -- The sum is taken in Integer so it cannot overflow.
        resultℤ = aToℤ x .+. aToℤ y
     in if resultℤ <= maxAℤ
          then -- 2. A is bounded but the result fits within the bound:
          -- No problem, just convert and reduce.
            aFromℤ resultℤ `mmod` modulus
          else -- 3. Result does not fit within A. Do the modular arithmetic
          -- in Integer instead, converting the result. Note that this assumes
          -- that the final result fits within A.
            let modulusℤ = aToℤ modulus
             in aFromℤ (resultℤ `mmod` modulusℤ)
  where
    aToℤ :: a -> Integer
    aToℤ = toZ

    aFromℤ :: Integer -> a
    aFromℤ = fromZ

-- | Performs modular multiplication, accounting for rounding in the type
-- itself. That is, @(x * y) mod n@ is computed without letting the
-- intermediate product exceed the upper bound of /a/ (when /a/ has one).
modSafeMultAlgebra ::
  forall a.
  ( FromInteger a,
    MEuclidean a,
    MaybeUpperBounded a,
    ToInteger a
  ) =>
  -- | x
  a ->
  -- | y
  a ->
  -- | n (modulus)
  a ->
  a
modSafeMultAlgebra x y modulus = case maybeUpperBound @a of
  -- 1. A is unbounded: Easy
  Nothing -> (x .*. y) `mmod` modulus
  Just maxA ->
    let maxAℤ = aToℤ maxA
        -- The product is taken in Integer so it cannot overflow.
        resultℤ = aToℤ x .*. aToℤ y
     in if resultℤ <= maxAℤ
          then -- 2. A is bounded but the result fits within the bound:
          -- No problem, just convert and reduce.
            aFromℤ resultℤ `mmod` modulus
          else -- 3. Result does not fit within A. Do the modular arithmetic
          -- in Integer instead, converting the result. Note that this assumes
          -- that the final result fits within A.
            let modulusℤ = aToℤ modulus
             in aFromℤ (resultℤ `mmod` modulusℤ)
  where
    aToℤ :: a -> Integer
    aToℤ = toZ

    aFromℤ :: Integer -> a
    aFromℤ = fromZ

-- | Performs modular subtraction, accounting for rounding in the type
-- itself. Subtraction is implemented as addition of the modular negation,
-- i.e. @(x + (n - y)) mod n@, with @n - y@ computed in 'Integer'.
--
-- NOTE(review): @n - y@ is negative when @y > n@; this presumably relies on
-- callers passing an already-reduced @y@ -- confirm against call sites.
modSafeSubAlgebra ::
  forall a.
  ( ASemigroup a,
    FromInteger a,
    MEuclidean a,
    ToInteger a,
    MaybeUpperBounded a
  ) =>
  -- | x
  a ->
  -- | y
  a ->
  -- | n (modulus)
  a ->
  a
modSafeSubAlgebra x y modulus = case maybeUpperBound @a of
  -- 1. A is unbounded: Easy
  Nothing -> (x .+. aFromℤ negYℤ) `mmod` modulus
  Just maxA ->
    let maxAℤ = aToℤ maxA
        -- The sum is taken in Integer so it cannot overflow.
        resultℤ = xℤ .+. negYℤ
     in if resultℤ <= maxAℤ
          then -- 2. A is bounded but the result fits within the bound:
          -- No problem, just convert and reduce.
            aFromℤ resultℤ `mmod` modulus
          else -- 3. Result does not fit within A. Do the modular arithmetic
          -- in Integer instead, converting the result. Note that this assumes
          -- that the final result fits within A.
            aFromℤ (resultℤ `mmod` modulusℤ)
  where
    xℤ = aToℤ x
    yℤ = aToℤ y
    modulusℤ = aToℤ modulus

    -- The modular negation of y, i.e. the value that is added in place of
    -- subtracting y.
    negYℤ = modulusℤ - yℤ

    aToℤ :: a -> Integer
    aToℤ = toZ

    aFromℤ :: Integer -> a
    aFromℤ = fromZ

-- | Reversed 'matching'. Useful with smart-constructor optics.
--
-- The supplied optic is flipped with 're' and the result is handed to
-- 'matching', so the function runs the optic \"backwards\" on an @s@ and
-- yields either a @t@ or an @a@.
--
-- @since 0.1
rmatching ::
  (Is (ReversedOptic k) An_AffineTraversal, ReversibleOptic k) =>
  Optic k NoIx b a t s ->
  s ->
  Either t a
rmatching = matching . re
{-# INLINEABLE rmatching #-}