-- | Internal utils.
module Numeric.Data.Internal.Utils
  ( -- * Safe modular arithmetic
    checkModBound,
    modSafeAdd,
    modSafeMult,
    modSafeSub,

    -- * Optics
    rmatching,
  )
where

import Data.Bounds
  ( AnyLowerBounded (someLowerBound),
    AnyUpperBounded (someUpperBound),
  )
import Data.Typeable (Typeable)
import Data.Typeable qualified as Typeable
import Optics.Core
  ( An_AffineTraversal,
    Is,
    NoIx,
    Optic,
    ReversibleOptic (ReversedOptic),
    matching,
    re,
  )

-- | Checks whether the type /a/ is large enough to hold the modulus /n/.
--
-- Yields 'Nothing' when the check passes, which includes the case where /a/
-- reports no upper bound. Otherwise yields @'Just' msg@ with a
-- human-readable explanation of the failure.
checkModBound ::
  forall a.
  ( AnyUpperBounded a,
    Integral a,
    Typeable a
  ) =>
  -- | A value of the type /a/ whose upper bound must be large enough to
  -- accommodate modular arithmetic within modulus /n/. A value is taken
  -- instead of a proxy so that the error message can report its type.
  a ->
  -- | The modulus /n/, which should satisfy @n <= max(a)@.
  Integer ->
  Maybe String
checkModBound aTerm modulus = someUpperBound @a >>= verify . toInteger
  where
    -- The largest element of Z/nZ is (n - 1), but actually performing
    -- "mod n" in /a/ requires n itself to be representable. Hence we compare
    -- against the modulus rather than (modulus - 1).
    verify :: Integer -> Maybe String
    verify upper
      | modulus <= upper = Nothing
      | otherwise =
          Just . concat $
            [ "Type '",
              show (Typeable.typeOf aTerm),
              "' has a maximum size of ",
              show upper,
              ". This is not large enough to safely implement mod ",
              show modulus,
              "."
            ]

-- | Modular addition: computes /x/ + /y/ reduced modulo /n/.
--
-- This is the shared safe-increment routine instantiated with addition. It
-- guards against the sum overflowing a bounded type /a/.
modSafeAdd ::
  forall a.
  ( AnyUpperBounded a,
    Integral a
  ) =>
  -- | Left operand /x/.
  a ->
  -- | Right operand /y/.
  a ->
  -- | Modulus /n/.
  a ->
  a
modSafeAdd x y n = modSafeInc (+) x y n

-- | Modular multiplication: computes /x/ * /y/ reduced modulo /n/.
--
-- This is the shared safe-increment routine instantiated with
-- multiplication. It guards against the product overflowing a bounded
-- type /a/.
modSafeMult ::
  forall a.
  ( AnyUpperBounded a,
    Integral a
  ) =>
  -- | Left operand /x/.
  a ->
  -- | Right operand /y/.
  a ->
  -- | Modulus /n/.
  a ->
  a
modSafeMult x y n = modSafeInc (*) x y n

-- | Shared implementation of 'modSafeAdd' and 'modSafeMult'.
--
-- Computes @op x y@ reduced modulo /n/ without letting an intermediate value
-- that lies outside the range of a bounded /a/ corrupt the result.
--
-- Assumes the reduced result itself fits in /a/ (cf. 'checkModBound').
modSafeInc ::
  forall a.
  ( AnyUpperBounded a,
    Integral a
  ) =>
  -- | The operation to apply (addition or multiplication).
  (forall x. (Integral x) => x -> x -> x) ->
  -- | Left operand /x/.
  a ->
  -- | Right operand /y/.
  a ->
  -- | Modulus /n/.
  a ->
  a
modSafeInc op x y modulus = case someUpperBound @a of
  -- 1. A is unbounded: there is no upper bound to overflow, so compute
  -- directly in A.
  Nothing -> (x `op` y) `mod` modulus
  -- 2. A is bounded: op x y may leave A's range.
  --
  -- Checking only the upper bound is not enough. For signed types
  -- (e.g. Int8) the result can also drop below the lower bound:
  -- (-100) * 2 = -200. Converting such a value with fromInteger wraps it
  -- for fixed-width types, and reducing the wrapped value gives the wrong
  -- residue.
  --
  -- So both the operation and the reduction are done in Integer. Only the
  -- reduced value (in [0, modulus) for a positive modulus) is converted
  -- back to A.
  Just _ ->
    integerToA ((aToInteger x `op` aToInteger y) `mod` aToInteger modulus)
  where
    aToInteger :: a -> Integer
    aToInteger = toInteger

    integerToA :: Integer -> a
    integerToA = fromInteger

-- | Modular subtraction: computes /x/ - /y/ reduced modulo /n/.
--
-- The result stays correct when the raw difference falls outside the range
-- of a type /a/ that has a lower bound. Examples are @3 - 5@ for an unsigned
-- type, or @100 - (-100)@ for 'Int8'.
--
-- Assumes the reduced result itself fits in /a/ (cf. 'checkModBound').
modSafeSub ::
  forall a.
  ( AnyLowerBounded a,
    Integral a
  ) =>
  -- | Left operand /x/.
  a ->
  -- | Right operand /y/.
  a ->
  -- | Modulus /n/.
  a ->
  a
modSafeSub x y modulus = case someLowerBound @a of
  -- 1. A has no lower bound: the subtraction cannot underflow, so compute
  -- directly in A.
  Nothing -> (x - y) `mod` modulus
  -- 2. A is bounded below: x - y may leave A's range.
  --
  -- The difference can go below the lower bound (3 - 5 for Word8). For
  -- signed types it can also go above the upper bound (100 - (-100) for
  -- Int8), which a lower-bound check alone would miss. Converting such a
  -- value with fromInteger wraps it for fixed-width types, giving the wrong
  -- residue.
  --
  -- So the subtraction and the reduction are done in Integer. Only the
  -- reduced value (in [0, modulus) for a positive modulus) is converted
  -- back to A.
  Just _ ->
    integerToA ((aToInteger x - aToInteger y) `mod` aToInteger modulus)
  where
    aToInteger :: a -> Integer
    aToInteger = toInteger

    integerToA :: Integer -> a
    integerToA = fromInteger

-- | Reversed 'matching'. The optic is first flipped with 're', and the
-- result is then run through 'matching'. Useful with smart-constructor
-- optics.
--
-- @since 0.1
rmatching ::
  (Is (ReversedOptic k) An_AffineTraversal, ReversibleOptic k) =>
  Optic k NoIx b a t s ->
  s ->
  Either t a
rmatching optic = matching (re optic)
{-# INLINEABLE rmatching #-}