{-# LANGUAGE UndecidableInstances #-}

-- | Provides the 'ModN' type for modular arithmetic.
--
-- @since 0.1
module Numeric.Data.ModN.Internal
  ( -- * Type
    ModN (MkModN, UnsafeModN),

    -- * Creation
    mkModN,
    unsafeModN,
    reallyUnsafeModN,

    -- * Misc
    errMsg,
  )
where

import Control.DeepSeq (NFData)
import Data.Bounds
  ( LowerBounded,
    MaybeLowerBounded,
    MaybeUpperBounded,
    UpperBounded,
  )
import Data.Kind (Type)
import Data.Proxy (Proxy (Proxy))
import Data.Text.Display (Display (displayBuilder))
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import GHC.Records (HasField (getField))
import GHC.Stack (HasCallStack)
import GHC.TypeNats (KnownNat, Nat, natVal)
import Language.Haskell.TH.Syntax (Lift)
import Numeric.Algebra.Additive.AGroup (AGroup ((.-.)))
import Numeric.Algebra.Additive.AMonoid (AMonoid (zero))
import Numeric.Algebra.Additive.ASemigroup (ASemigroup ((.+.)))
import Numeric.Algebra.Multiplicative.MMonoid (MMonoid (one))
import Numeric.Algebra.Multiplicative.MSemigroup (MSemigroup ((.*.)))
import Numeric.Algebra.Ring (Ring)
import Numeric.Algebra.Semiring (Semiring)
import Numeric.Data.Internal.Utils qualified as Utils
import Numeric.Literal.Integer (FromInteger (afromInteger))
import Optics.Core (A_Getter, LabelOptic (labelOptic), to)

-- $setup
-- >>> import Data.Int (Int8)

-- NOTE: [Safe finite modular rounding]
--
-- When creating a new @ModN n a@, we need to ensure it is safe to do so.
-- That is, n needs to fit within type a, and we want to ensure any
-- mathematical operations (e.g. multiplication) do not wrap due to a being
-- finite. Thus we have two scenarios we need to check:
--
-- 1. When we are creating a brand new @ModN n a@ (i.e. the caller is asking
--    for a specific n but does not yet have their hands on one), we need to
--    check that n is within a. We can use Utils.checkModBound via mkModN
--    (and hence unsafeModN, which delegates to mkModN) for this.
--
-- 2. When we are combining two @ModN n a@s (e.g. addition), we have already
--    verified that the first check has passed. But we need to ensure the
--    intermediate result does not under/overflow before performing the mod.
--    We can use Utils's modSafe(Add/Mult/Sub) for this.

-- | Newtype wrapper that represents \( \mathbb{Z}/n\mathbb{Z} \).
-- 'ModN' is a 'Numeric.Algebra.Ring.Ring' i.e. supports addition, subtraction,
-- and multiplication.
--
-- When constructing a @'ModN' n a@ we must verify that the type @a@ is large
-- enough to accommodate @n@, hence the possible failure.
--
-- The raw 'UnsafeModN' constructor wraps its argument as-is: it performs
-- neither the bound check nor the modular reduction. Prefer 'mkModN' or
-- 'unsafeModN' for construction.
--
-- ==== __Examples__
--
-- >>> import Data.Text.Display (display)
-- >>> display $ unsafeModN @7 10
-- "3 (mod 7)"
--
-- @since 0.1
type ModN :: Nat -> Type -> Type
newtype ModN n a = UnsafeModN a
  deriving stock
    ( -- | @since 0.1
      Eq,
      -- | @since 0.1
      Generic,
      -- | @since 0.1
      Lift,
      -- | @since 0.1
      Ord
    )
  deriving anyclass
    ( -- | @since 0.1
      LowerBounded,
      -- | @since 0.1
      NFData,
      -- | @since 0.1
      UpperBounded
    )

-- | Renders the wrapped value together with its modulus, e.g.
-- @MkModN 2 (mod 5)@. The output is parenthesised when the surrounding
-- precedence is 11 or higher.
--
-- @since 0.1
instance (KnownNat n, Show a) => Show (ModN n a) where
  -- manual so we include the mod string
  showsPrec i (UnsafeModN x) =
    showParen
      (i >= 11)
      (showString "MkModN " . showsPrec 11 x . showString modStr)
    where
      -- Suffix carrying the type-level modulus, e.g. " (mod 5)".
      modStr = " (mod " <> show n' <> ")"
      n' = natVal @n Proxy
  {-# INLINEABLE showsPrec #-}

-- | Field-style access to the wrapped value under the name @unModN@.
--
-- @since 0.1
instance HasField "unModN" (ModN n a) a where
  getField (UnsafeModN x) = x

-- | Read-only optic (a getter) for the wrapped value, under the label
-- @unModN@.
--
-- @since 0.1
instance
  ( k ~ A_Getter,
    x ~ a,
    y ~ a
  ) =>
  LabelOptic "unModN" k (ModN n a) (ModN n a) x y
  where
  labelOptic = to (\(UnsafeModN x) -> x)
  {-# INLINE labelOptic #-}

-- | Unidirectional pattern synonym for 'ModN'. It can be used to match on
-- (and extract) the wrapped value, but it cannot be used to construct a
-- 'ModN'; use 'mkModN' or 'unsafeModN' for construction.
--
-- @since 0.1
pattern MkModN :: a -> ModN n a
pattern MkModN x <- UnsafeModN x

{-# COMPLETE MkModN #-}

-- | 'minBound' is @0@ and 'maxBound' is @n - 1@, both built with
-- 'unsafeModN'.
--
-- __WARNING: Partial__
--
-- Both bounds go through 'unsafeModN', which calls 'error' when 'mkModN'
-- rejects the type @a@ for the modulus @n@.
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  Bounded (ModN n a)
  where
  minBound = unsafeModN 0

  -- The subtraction is done on the value returned by natVal, before the
  -- conversion to a.
  maxBound = unsafeModN $ fromIntegral (natVal @n Proxy - 1)
  {-# INLINEABLE minBound #-}
  {-# INLINEABLE maxBound #-}

-- | Displays the wrapped value (via its 'Show' instance) followed by the
-- modulus, e.g. @3 (mod 7)@.
--
-- @since 0.1
instance (KnownNat n, Show a) => Display (ModN n a) where
  displayBuilder (UnsafeModN x) =
    mconcat
      [ displayBuilder $ show x,
        displayBuilder @String " (mod ",
        displayBuilder $ show n',
        displayBuilder @String ")"
      ]
    where
      n' = natVal @n Proxy

-- | Modular addition. The sum is delegated to 'Utils.modSafeAdd', which
-- receives both operands and the modulus @n@ converted to @a@.
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a
  ) =>
  ASemigroup (ModN n a)
  where
  UnsafeModN x .+. UnsafeModN y =
    UnsafeModN $ Utils.modSafeAdd x y (fromIntegral n')
    where
      n' = natVal @n Proxy
  {-# INLINEABLE (.+.) #-}

-- | Additive identity, built as @'unsafeModN' 0@.
--
-- __WARNING: Partial__
--
-- 'zero' goes through 'unsafeModN', which calls 'error' when 'mkModN'
-- rejects the type @a@ for the modulus @n@.
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  AMonoid (ModN n a)
  where
  zero = unsafeModN 0
  {-# INLINEABLE zero #-}

-- | Modular subtraction. The difference is delegated to
-- 'Utils.modSafeSub', which receives both operands and the modulus @n@
-- converted to @a@.
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeLowerBounded a,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  AGroup (ModN n a)
  where
  UnsafeModN x .-. UnsafeModN y =
    UnsafeModN $ Utils.modSafeSub x y (fromIntegral n')
    where
      n' = natVal @n Proxy
  {-# INLINEABLE (.-.) #-}

-- | Modular multiplication. The product is delegated to
-- 'Utils.modSafeMult', which receives both operands and the modulus @n@
-- converted to @a@.
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a
  ) =>
  MSemigroup (ModN n a)
  where
  UnsafeModN x .*. UnsafeModN y =
    UnsafeModN $ Utils.modSafeMult x y (fromIntegral n')
    where
      n' = natVal @n Proxy
  {-# INLINEABLE (.*.) #-}

-- | Multiplicative identity, built as @'unsafeModN' 1@.
--
-- __WARNING: Partial__
--
-- 'one' goes through 'unsafeModN', which calls 'error' when 'mkModN'
-- rejects the type @a@ for the modulus @n@.
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  MMonoid (ModN n a)
  where
  one = unsafeModN 1
  {-# INLINEABLE one #-}

-- | Declares 'ModN' a 'Semiring'. The instance body is empty: no methods
-- are defined here.
--
-- NOTE(review): presumably 'Semiring' only bundles the additive and
-- multiplicative monoid classes -- confirm against
-- "Numeric.Algebra.Semiring".
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  Semiring (ModN n a)

-- | Declares 'ModN' a 'Ring'. The instance body is empty: no methods are
-- defined here.
--
-- NOTE(review): presumably 'Ring' only bundles the additive group and
-- semiring classes -- confirm against "Numeric.Algebra.Ring".
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeLowerBounded a,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  Ring (ModN n a)

-- | Converts an 'Integer' literal by first converting it to @a@ with
-- 'fromInteger' and then applying 'unsafeModN'.
--
-- __WARNING: Partial__
--
-- 'afromInteger' goes through 'unsafeModN', which calls 'error' when
-- 'mkModN' rejects the type @a@ for the modulus @n@.
--
-- @since 0.1
instance
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  FromInteger (ModN n a)
  where
  afromInteger = unsafeModN . fromInteger
  {-# INLINEABLE afromInteger #-}

-- | Constructor for 'ModN'. Runs 'Utils.checkModBound' on the argument and
-- the modulus @n@: if it reports a problem, that message is returned in
-- 'Left'; otherwise the argument is reduced modulo @n@ and returned in
-- 'Right'.
--
-- ==== __Examples__
-- >>> mkModN @5 7
-- Right (MkModN 2 (mod 5))
--
-- >>> mkModN @10 7
-- Right (MkModN 7 (mod 10))
--
-- >>> mkModN @128 (9 :: Int8)
-- Left "Type 'Int8' has a maximum size of 127. This is not large enough to safely implement mod 128."
--
-- @since 0.1
mkModN ::
  forall n a.
  ( Integral a,
    KnownNat n,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  a ->
  Either String (ModN n a)
mkModN x = maybe modN Left (Utils.checkModBound x n')
  where
    -- Success case, used when checkModBound yields Nothing.
    modN = Right x'
    n' = toInteger $ natVal @n Proxy
    x' = UnsafeModN $ x `mod` fromIntegral n'
{-# INLINEABLE mkModN #-}

-- | Variant of 'mkModN' that throws an error when type @a@ is not
-- large enough to fit @n@. The error message is the one returned by
-- 'mkModN', prefixed via 'errMsg'.
--
-- __WARNING: Partial__
--
-- ==== __Examples__
-- >>> unsafeModN @7 12
-- MkModN 5 (mod 7)
--
-- @since 0.1
unsafeModN ::
  forall n a.
  ( HasCallStack,
    Integral a,
    KnownNat n,
    MaybeUpperBounded a,
    Typeable a
  ) =>
  a ->
  ModN n a
unsafeModN x = case mkModN x of
  Right mp -> mp
  Left err -> error $ errMsg "unsafeModN" err
{-# INLINEABLE unsafeModN #-}

-- | This function reduces the argument modulo @n@ but does __not__ check
-- that @n@ fits within @a@. Note that correct behavior requires this, so this
-- is dangerous. This is intended only for when we absolutely know @n@ fits in
-- @a@ and the check is undesirable for performance reasons. Exercise extreme
-- caution.
--
-- @since 0.1
reallyUnsafeModN :: forall n a. (Integral a, KnownNat n) => a -> ModN n a
reallyUnsafeModN = UnsafeModN . (`mod` n')
  where
    -- The modulus converted to a; this conversion is the unchecked step.
    n' = fromIntegral $ natVal @n Proxy
{-# INLINEABLE reallyUnsafeModN #-}

-- | Builds an error message of the form
-- @Numeric.Data.ModN.\<fn\>: \<msg\>@, where the first argument is the name
-- of the reporting function and the second is the message body.
--
-- @since 0.1
errMsg :: String -> String -> String
errMsg fn msg =
  mconcat
    [ "Numeric.Data.ModN.",
      fn,
      ": ",
      msg
    ]