{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -Wno-identities #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- See the note on Modulus for why this warning is disabled

-- | Internal tools for modular arithmetic and primality testing. The main
-- functions are 'isPrime' and 'findInverse', though others are exported for
-- testing.
--
-- @since 0.1
module Numeric.Data.ModP.Internal.Primality
  ( -- * Primality Testing
    MaybePrime (..),
    isPrime,

    -- ** Arithmoi vs. Default
    -- $arithmoi
    isPrimeArithmoi,
    isPrimeDefault,
    isPrimeTrials,

    -- ** Helper Types
    -- $primality-helper
    Modulus (..),
    Pow (..),
    Mult (..),
    Rand (..),

    -- ** Helper Functions
    millerRabin,
    trial,
    isWitness,
    sqProgression,
    factor2,

    -- * Multiplicative Inverses
    invert,

    -- ** Arithmoi vs. Default
    invertArithmoi,
    invertDefault,

    -- ** Types / Low-level
    findInverse,
    findBezout,
    Bezout (..),
    R (..),
    S (..),
    T (..),

    -- * Misc
    errMsg,
  )
where

import Control.DeepSeq (NFData)
import Data.Data (Proxy (Proxy))
import Data.Kind (Type)
import GHC.Generics (Generic)
import GHC.TypeNats (KnownNat, natVal)
import Numeric.Natural (Natural)
import System.Random (UniformRange)
import System.Random qualified as Rand
import System.Random.Stateful qualified as RandState

import Data.Mod (Mod)
import Data.Mod qualified as Mod
import Math.NumberTheory.Primes.Testing qualified as AM.Primes.Testing

-- | Result of running the Miller-Rabin algorithm. At best we can determine
-- whether some @n@ is definitely composite or \"probably prime\".
--
-- @since 0.1
type MaybePrime :: Type
data MaybePrime
  = -- | Definitely not prime.
    Composite
  | -- | Passed every test that was run; may still be composite.
    ProbablyPrime
  deriving (Eq, Generic, Show)

-- | Uses the 'Generic'-based default 'Control.DeepSeq.rnf'.
--
-- @since 0.1
instance NFData MaybePrime
-- | Combines trial results: 'Composite' absorbs everything, and
-- 'ProbablyPrime' defers to the right-hand result. A sequence of trials is
-- therefore 'ProbablyPrime' only if every trial is.
--
-- @since 0.1
instance Semigroup MaybePrime where
  Composite <> _ = Composite
  ProbablyPrime <> r = r

-- | 'ProbablyPrime' is the identity: combining it with any result leaves
-- that result unchanged.
--
-- @since 0.1
instance Monoid MaybePrime where
  mempty = ProbablyPrime

-- TODO: Turns out, isPrime is slow. For example, isPrime 1_000_003 takes
-- quite a long time. Basic profiling (:set +s in ghci) shows that memory
-- scales with the prime. This is probably much worse than it should be.
-- The optional flag arithmoi enables the arithmoi package, which is much
-- faster. But it would be nice if we could improve the default as well.

-- | Tests primality via the Miller-Rabin algorithm with 100 trials
-- ('isPrimeDefault'). Returns 'Composite' if the number is definitely
-- composite, otherwise 'ProbablyPrime'.
--
-- NOTE(review): per the notes on the @arithmoi@ flag, enabling that flag is
-- meant to switch this to 'isPrimeArithmoi'. This module has no CPP guard
-- for it, so the default implementation is used unconditionally. Confirm
-- this is intended.
--
-- ==== __Examples__
-- >>> isPrime 7
-- ProbablyPrime
--
-- >>> isPrime 22
-- Composite
--
-- >>> isPrime 373
-- ProbablyPrime
--
-- @since 0.1
isPrime :: Integer -> MaybePrime
isPrime = isPrimeDefault
{-# INLINEABLE isPrime #-}

-- $arithmoi
-- By default, our 'isPrime' function implements Miller-Rabin directly.
-- Unfortunately, its performance scales poorly (this is an issue with our
-- implementation, not with Miller-Rabin). Thus we provide the optional flag
-- @arithmoi@ that instead uses the arithmoi package. Arithmoi is much faster,
-- though it is not a light dependency, hence the option.
--
-- In other words, the flag controls the tradeoff between 'isPrime' speed and
-- dependency footprint. So why do we also provide 'isPrimeDefault' and
-- 'isPrimeArithmoi'? Benchmarking: we want to measure the difference, so we
-- need both available when the flag is on.

-- | Tests primality with arithmoi's 'AM.Primes.Testing.isPrime', mapping
-- 'True' to 'ProbablyPrime' and 'False' to 'Composite'.
--
-- NOTE(review): there is no flag-disabled fallback here, because this module
-- imports arithmoi unconditionally. If the @arithmoi@ flag must stay
-- optional, guard this definition and the arithmoi imports with CPP.
isPrimeArithmoi :: Integer -> MaybePrime
isPrimeArithmoi n
  | AM.Primes.Testing.isPrime n = ProbablyPrime
  | otherwise = Composite

-- | The built-in Miller-Rabin test: 'isPrimeTrials' with 100 trials.
isPrimeDefault :: Integer -> MaybePrime
isPrimeDefault = isPrimeTrials 100

-- | 'isPrime' that takes an additional 'Int' parameter for the number of
-- trials to run. The more trials, the more confident we can be in
-- 'ProbablyPrime'. Inputs below 2 (including 0 and negatives) are not prime
-- and yield 'Composite'.
--
-- ==== __Examples__
-- >>> isPrimeTrials 1 91
-- ProbablyPrime
--
-- >>> isPrimeTrials 2 91
-- Composite
--
-- Note: False positives can be found via:
--
-- @
-- -- search for \"ProbablyPrime\" after 1 trial in the composite sequence
-- -- for a given prime p
-- counter p = filter ((== ProbablyPrime) . snd) $
--   fmap (\x -> (x, isPrimeTrials 1 x)) [p + p, p + p + p ..]
-- @
--
-- @since 0.1
isPrimeTrials :: Int -> Integer -> MaybePrime
isPrimeTrials numTrials n
  -- Primes are greater than 1. Without this guard, odd negatives reach
  -- millerRabin, where the negative exponent d makes (^) throw.
  | n < 2 = Composite
  | n == 2 = ProbablyPrime
  | even n = Composite
  | otherwise = millerRabin (MkModulus n) numTrials
{-# INLINEABLE isPrimeTrials #-}

-- $primality-helper
-- For the following functions/types, a core concept is rewriting our \(n\) as
-- \[
--   n = 2^r d + 1,
-- \]
-- where \(d\) is odd i.e. we have factored out 2 as much as possible.
-- We use newtypes to track these numbers.

-- | Represents a modulus. When testing for primality, this is the \(n\) in
-- \(n = 2^{r} d + 1\).
--
-- @since 0.1
type Modulus :: Type
newtype Modulus = MkModulus Integer
  deriving (Eq, Generic, Ord, Show)
  -- NOTE(review): 'deriving newtype' needs DerivingStrategies (and
  -- GeneralizedNewtypeDeriving). No pragma in this file enables them, so
  -- they are presumably enabled project-wide; confirm.
  deriving newtype (Enum, Integral, Num, Real)

-- GHC 9+ warns that "Call of toInteger :: Integer -> Integer can probably be
-- omitted" when deriving Integral for the newtypes in this module. The
-- derived toInteger is presumably the identity on the wrapped Integer.
-- Until we investigate further, disabling -Widentities is the easiest
-- workaround.

-- | Uses the 'Generic'-based default 'Control.DeepSeq.rnf'.
--
-- @since 0.1
instance NFData Modulus

-- | The \(r\) in \(n = 2^{r} d + 1\), i.e. the number of factors of 2.
--
-- @since 0.1
type Pow :: Type
newtype Pow = MkPow Integer
  deriving (Eq, Generic, Ord, Show)
  deriving newtype (Enum, Integral, Num, Real)

-- | @since 0.1
instance NFData Pow
-- | The odd \(d\) in \(n = 2^{r} d + 1\).
--
-- @since 0.1
type Mult :: Type
newtype Mult = MkMult Integer
  deriving (Eq, Generic, Ord, Show)
  deriving newtype (Enum, Integral, Num, Real)

-- | @since 0.1
instance NFData Mult
-- | Randomly generated base \(a \in [2, n - 2] \) used when testing
-- \(n\)'s primality.
--
-- @since 0.1
type Rand :: Type
newtype Rand = MkRand Integer
  deriving (Eq, Generic, Ord, Show)
  deriving newtype (Enum, Integral, Num, Real)

-- | @since 0.1
instance NFData Rand
-- | Samples via the underlying 'Integer' instance and wraps the result in
-- 'MkRand'.
--
-- @since 0.1
instance UniformRange Rand where
  uniformRM (MkRand l, MkRand u) = fmap MkRand . RandState.uniformRM (l, u)
  {-# INLINEABLE uniformRM #-}

-- | Miller-Rabin algorithm. Takes the \(n\) to be tested and the number of
-- trials to perform. The more trials, the higher our confidence in
-- 'ProbablyPrime'. A trial count of zero or less performs no trials and
-- returns 'ProbablyPrime'.
--
-- Intended for odd \(n \geq 3\); \(n = 2\) is special-cased.
millerRabin :: Modulus -> Int -> MaybePrime
millerRabin 2 = const ProbablyPrime
millerRabin modulus@(MkModulus n) = go gen
  where
    -- Fixed seed, so results are deterministic. Seed-dependent doctests
    -- such as @isPrimeTrials 1 91@ rely on this.
    -- NOTE(review): the seed value is not recoverable from this copy of the
    -- file; 373 is assumed. Confirm it matches the value the doctests were
    -- written against.
    gen = Rand.mkStdGen 373
    powMult = factor2 (modulus - 1)
    range = RandState.uniformRM (2, MkRand (n - 2))

    go :: Rand.StdGen -> Int -> MaybePrime
    go g !k
      -- '<= 0' rather than '== 0': a negative count would otherwise never
      -- reach the base case, and would recurse forever while trials pass.
      | k <= 0 = ProbablyPrime
      | otherwise =
          let (randomVal, g') = RandState.runStateGen g range
           in case trial modulus powMult randomVal of
                Composite -> Composite
                ProbablyPrime -> go g' (k - 1)
{-# INLINEABLE millerRabin #-}

-- | For \(n, r, d, a\) with \(n = 2^{r} d + 1\) and \(a \in [2, n - 2] \),
-- returns 'Composite' if \(n\) is definitely composite, 'ProbablyPrime'
-- otherwise.
--
-- \(a^d \bmod n\) is computed by modular exponentiation, so intermediate
-- values stay below \(n^2\) rather than growing with \(d\).
--
-- ==== __Examples__
-- >>> trial 12 (factor2 (12 - 1)) 3
-- Composite
--
-- >>> trial 7 (factor2 (7 - 1)) 3
-- ProbablyPrime
--
-- @since 0.1
trial :: Modulus -> (Pow, Mult) -> Rand -> MaybePrime
trial modulus@(MkModulus n) (r, MkMult d) (MkRand a)
  -- x = 1 or n - 1 -> skip
  | x == 1 || x == n - 1 = ProbablyPrime
  -- if we found a witness then n is definitely composite
  | otherwise = isWitness modulus r (MkRand x)
  where
    x = powModN a d

    -- base ^ e `mod` n by right-to-left binary exponentiation. This gives
    -- the same residue as computing base ^ e in full first, which builds a
    -- number with roughly e * log2 base bits. A negative exponent raises
    -- the same error (^) does.
    powModN :: Integer -> Integer -> Integer
    powModN base e
      | e < 0 = errorWithoutStackTrace "Negative exponent"
      | otherwise = go (base `mod` n) e (1 `mod` n)

    -- b: current square of the base; e: remaining exponent; acc: product
    -- so far. All are reduced mod n.
    go :: Integer -> Integer -> Integer -> Integer
    go _ 0 !acc = acc
    go !b !e !acc =
      go (b * b `mod` n) (e `quot` 2) (if odd e then acc * b `mod` n else acc)
{-# INLINEABLE trial #-}

-- | For \(n, r, x\) with \(n = 2^{r} d + 1\) and some
-- \(x \equiv a^d \pmod n \), returns 'Composite' if \(x\) is a witness to
-- \(n\) being composite. Otherwise returns 'ProbablyPrime'.
--
-- \(x\) is not a witness exactly when \(n - 1\) appears among
-- \(x, x^2, x^4, \ldots, x^{2^{r-1}} \pmod n\).
--
-- ==== __Examples__
-- >>> let (pow, mult) = factor2 (12 - 1)
-- >>> let testVal = 3 ^ mult `mod` 12
-- >>> isWitness 12 pow testVal
-- Composite
--
-- >>> let (pow, mult) = factor2 (7 - 1)
-- >>> let testVal = 3 ^ mult `mod` 7
-- >>> isWitness 7 pow testVal
-- ProbablyPrime
--
-- @since 0.1
isWitness :: Modulus -> Pow -> Rand -> MaybePrime
isWitness modulus@(MkModulus n) (MkPow r) (MkRand x)
  | (n - 1) `elem` squares = ProbablyPrime
  | otherwise = Composite
  where
    -- The first r terms of the squaring sequence starting at x.
    squares = take (fromIntegral r) (sqProgression modulus x)
{-# INLINEABLE isWitness #-}

-- | For \(n, x\), returns the infinite progression
--
-- \[
-- x, x^2, x^4, x^8, \ldots \pmod n.
-- \]
--
-- The first element is \(x\) itself, unreduced; every later element is
-- reduced mod \(n\).
--
-- ==== __Examples__
-- >>> take 5 $ sqProgression 7 3
-- [3,2,4,2,4]
--
-- @since 0.1
sqProgression :: Modulus -> Integer -> [Integer]
sqProgression (MkModulus n) = go
  where
    -- The bang forces each element as its cons cell is produced, so no
    -- chain of squaring thunks builds up.
    go !y = y : go (y * y `mod` n)
{-# INLINEABLE sqProgression #-}

-- | Given \(n\), returns \((r, d)\) such that \(n = 2^r d\) with \(d\) odd,
-- i.e. \(2\) has been factored out as much as possible. \(0\) has no such
-- decomposition; it returns \((0, 0)\) instead of dividing forever.
--
-- ==== __Examples__
-- >>> factor2 7
-- (MkPow 0,MkMult 7)
--
-- >>> factor2 8
-- (MkPow 3,MkMult 1)
--
-- >>> factor2 20
-- (MkPow 2,MkMult 5)
--
-- @since 0.1
factor2 :: Modulus -> (Pow, Mult)
factor2 (MkModulus n) = go 0 n
  where
    -- r counts the factors of 2 removed so far; d is what remains.
    go :: Integer -> Integer -> (Pow, Mult)
    go !r !d
      | d /= 0, even d = go (r + 1) (d `div` 2)
      | otherwise = (MkPow r, MkMult d)
{-# INLINEABLE factor2 #-}


-- | Finds the multiplicative inverse of the argument modulo @p@, using the
-- built-in algorithm ('invertDefault').
--
-- NOTE(review): the @arithmoi@ flag is meant to switch this to
-- 'invertArithmoi'. This module has no CPP guard for that, so the default
-- is used unconditionally. Confirm this is intended.
invert :: forall p. (KnownNat p) => Natural -> Natural
invert = invertDefault @p
{-# INLINEABLE invert #-}

-- | Finds the multiplicative inverse of @d@ modulo @p@ with arithmoi's
-- 'Mod.invertMod'. Calls 'error' when 'Mod.invertMod' finds no inverse.
--
-- NOTE(review): there is no flag-disabled fallback here, because this module
-- imports "Data.Mod" unconditionally. If the @arithmoi@ flag must stay
-- optional, guard this definition and the imports with CPP.
invertArithmoi :: forall p. (KnownNat p) => Natural -> Natural
invertArithmoi d =
  case Mod.invertMod (fromIntegral d :: Mod p) of
    Nothing ->
      error $
        errMsg
          "invertArithmoi"
          ("Could not find inverse of " ++ show d ++ " (mod " ++ show p' ++ ")")
    Just n -> Mod.unMod n
  where
    p' = natVal @p Proxy

-- | Finds the multiplicative inverse of @d@ modulo @p@ using the built-in
-- extended Euclidean algorithm ('findInverse').
--
-- As with 'findInverse', the result is only an inverse when @d@ and @p@ are
-- coprime, and the function is partial when @p@ is 0.
invertDefault :: forall p. (KnownNat p) => Natural -> Natural
invertDefault d = fromIntegral $ findInverse (toInteger d) modulus
  where
    modulus = MkModulus $ toInteger $ natVal @p Proxy


-- | For \(a, p\), finds the multiplicative inverse of \(a\) in
-- \(\mathbb{Z}/p\mathbb{Z}\). That is, finds /e/ such that
--
-- \[
-- ae \equiv 1 \pmod p.
-- \]
--
-- Note: The returned \(e\) is only an inverse when \(a\) and \(p\) are
-- coprime, i.e. \((a,p) = 1\). Of course this is guaranteed when \(p\) is
-- prime and \(0 < a < p \), but it is not true in general. Also, since this
-- function uses division, it is partial when the modulus is 0.
--
-- @since 0.1
findInverse :: Integer -> Modulus -> Integer
findInverse a (MkModulus p) = aInv `mod` p
  where
    -- eec p a yields s, t with s * p + t * a = gcd; when the gcd is 1,
    -- t is a's inverse mod p.
    MkBezout _ _ (T' aInv) = eec p a
{-# INLINEABLE findInverse #-}

-- | For \(a, p\), returns the Bezout data of \(p\) and \(a\): the gcd \(g\)
-- together with \(s, t\) such that \(s p + t a = g\).
--
-- @since 0.1
findBezout :: Integer -> Modulus -> Bezout
findBezout a (MkModulus p) = eec p a
{-# INLINEABLE findBezout #-}

-- | Result of the extended Euclidean algorithm on inputs \(a, b\).
-- 'bzGcd' is \(\gcd(a, b)\) (up to sign for negative inputs), and 'bzS' and
-- 'bzT' satisfy \(s a + t b = \gcd(a, b)\).
--
-- @since 0.1
type Bezout :: Type
data Bezout = MkBezout
  { bzGcd :: !R,
    bzS :: !S,
    bzT :: !T
  }
  deriving (Eq, Generic, Show)

-- | @since 0.1
instance NFData Bezout
-- | A remainder in the extended Euclidean algorithm. The last non-zero
-- remainder is the gcd.
--
-- @since 0.1
type R :: Type
newtype R = R' Integer
  deriving (Eq, Generic, Ord, Show)
  deriving newtype (Enum, Integral, Num, Real)

-- | @since 0.1
instance NFData R
-- | The Bezout coefficient of the first input in the extended Euclidean
-- algorithm.
--
-- @since 0.1
type S :: Type
newtype S = S' Integer
  deriving (Eq, Generic, Ord, Show)
  deriving newtype (Enum, Integral, Num, Real)

-- | @since 0.1
instance NFData S
-- | The Bezout coefficient of the second input in the extended Euclidean
-- algorithm.
--
-- @since 0.1
type T :: Type
newtype T = T' Integer
  deriving (Eq, Generic, Ord, Show)
  deriving newtype (Enum, Integral, Num, Real)

-- | @since 0.1
instance NFData T
-- | Solves Bezout's identity using the extended Euclidean algorithm:
-- https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm#Pseudocode
--
-- For inputs \(a, b\), returns 'MkBezout' \(g\) \(s\) \(t\) with
-- \(s a + t b = g\), where \(g\) is the last non-zero remainder.
eec :: Integer -> Integer -> Bezout
eec a b = go (R' a) (R' b) (S' 1) (S' 0) (T' 0) (T' 1)
  where
    -- Invariant: oldS * a + oldT * b = oldR and s * a + t * b = r.
    -- The initial values satisfy this: 1 * a + 0 * b = a and 0 * a + 1 * b = b.
    go :: R -> R -> S -> S -> T -> T -> Bezout
    go oldR 0 oldS _ oldT _ = MkBezout oldR oldS oldT
    go !oldR !r !oldS !s !oldT !t =
      let (R' q, r') = oldR `quotRem` r
       in go r r' s (oldS - S' q * s) t (oldT - T' q * t)
{-# INLINEABLE eec #-}

-- | @since 0.1
errMsg :: String -> String -> String
errMsg :: String -> ShowS
errMsg String
fn String
msg =
  [String] -> String
forall a. Monoid a => [a] -> a
    [ String
": ",