{-# LANGUAGE UndecidableInstances #-}

-- | Provides the 'NonNegative' type for enforcing a nonnegative invariant.
--
-- @since 0.1
module Numeric.Data.NonNegative.Internal
  ( -- * Type
    NonNegative (MkNonNegative, UnsafeNonNegative),

    -- * Creation
    unsafeNonNegative,

    -- * Misc
    errMsg,
  )
where

import Control.DeepSeq (NFData)
import Data.Bifunctor (Bifunctor (bimap))
import Data.Bounds
  ( LowerBounded (lowerBound),
    UpperBounded (upperBound),
    UpperBoundless,
  )
import Data.Kind (Type)
import Data.Text.Display (Display, ShowInstance (ShowInstance))
import GHC.Generics (Generic)
import GHC.Records (HasField (getField))
import GHC.Stack (HasCallStack)
import Language.Haskell.TH.Syntax (Lift)
import Numeric.Algebra.Additive.AMonoid (AMonoid (zero))
import Numeric.Algebra.Additive.ASemigroup (ASemigroup ((.+.)))
import Numeric.Algebra.Multiplicative.MEuclidean (MEuclidean (mdivMod))
import Numeric.Algebra.Multiplicative.MGroup (MGroup ((.%.)))
import Numeric.Algebra.Multiplicative.MMonoid (MMonoid (one))
import Numeric.Algebra.Multiplicative.MSemigroup (MSemigroup ((.*.)))
import Numeric.Algebra.Normed (Normed (norm))
import Numeric.Algebra.Semifield (Semifield)
import Numeric.Algebra.Semiring (Semiring)
import Numeric.Class.Division (Division (divide))
import Numeric.Literal.Integer (FromInteger (afromInteger))
import Numeric.Literal.Rational (FromRational (afromRational))
import Optics.Core (A_Getter, LabelOptic (labelOptic), to)

-- $setup
-- >>> :set -XTemplateHaskell
-- >>> :set -XPostfixOperators

-- | Newtype wrapper that attaches a 'NonNegative' invariant to some @a@.
-- 'NonNegative' is a 'Numeric.Algebra.Semifield.Semifield' i.e. supports
-- addition, multiplication, and division.
--
-- The raw constructor 'UnsafeNonNegative' performs no check. Build values with
-- 'unsafeNonNegative' (which validates) and match on them with the
-- 'MkNonNegative' pattern.
--
-- @since 0.1
type NonNegative :: Type -> Type
newtype NonNegative a = UnsafeNonNegative a
  deriving stock
    ( -- | @since 0.1
      Eq,
      -- | @since 0.1
      Generic,
      -- | @since 0.1
      Lift,
      -- | @since 0.1
      Ord,
      -- | @since 0.1
      Show
    )
  deriving anyclass
    ( -- | @since 0.1
      NFData,
      -- NOTE(review): UpperBoundless is derived for every @a@, while an
      -- 'UpperBounded' instance also exists when @a@ is 'UpperBounded';
      -- confirm both are meant to hold at once for bounded @a@.
      -- | @since 0.1
      UpperBoundless
    )
  deriving
    ( -- | @since 0.1
      Display
    )
    -- Display output reuses the derived 'Show' rendering.
    via (ShowInstance a)

-- | Read-only access to the wrapped value as a virtual field named
-- @unNonNegative@ (e.g. @n.unNonNegative@ with @OverloadedRecordDot@).
-- There is no corresponding setter, so the invariant cannot be bypassed here.
--
-- @since 0.1
instance HasField "unNonNegative" (NonNegative a) a where
  getField (UnsafeNonNegative inner) = inner

-- | Read-only optic (an 'A_Getter') for the wrapped value, exposed under the
-- label @unNonNegative@. Only a getter is provided, so optics cannot be used
-- to write a negative value into a 'NonNegative'.
--
-- @since 0.1
instance
  -- The equalities live in the context rather than the instance head so the
  -- instance is chosen before k/x/y are known, then fixes them.
  ( k ~ A_Getter,
    x ~ a,
    y ~ a
  ) =>
  LabelOptic "unNonNegative" k (NonNegative a) (NonNegative a) x y
  where
  labelOptic = to (\(UnsafeNonNegative inner) -> inner)
  {-# INLINE labelOptic #-}

-- | Unidirectional pattern synonym for 'NonNegative'. This allows us to pattern
-- match on a non-negative term without exposing the unsafe internal details.
--
-- Because it is declared with @<-@, it can only be used for matching, never
-- to construct a (possibly invalid) 'NonNegative'.
--
-- @since 0.1
pattern MkNonNegative :: a -> NonNegative a
pattern MkNonNegative x <- UnsafeNonNegative x

-- Tells the pattern-match checker that 'MkNonNegative' alone covers every value.
{-# COMPLETE MkNonNegative #-}

-- | 'minBound' is @0@ rather than the underlying 'minBound' (which may be
-- negative). 'maxBound' is the underlying 'maxBound', presumed non-negative
-- for any sensible @a@.
--
-- @since 0.1
instance (Bounded a, Num a) => Bounded (NonNegative a) where
  minBound = UnsafeNonNegative 0
  maxBound = UnsafeNonNegative maxBound
  {-# INLINEABLE minBound #-}
  {-# INLINEABLE maxBound #-}

-- | The lower bound is always @0@, independent of any bound on @a@.
--
-- @since 0.1
instance (Num a) => LowerBounded (NonNegative a) where
  lowerBound = UnsafeNonNegative 0
  {-# INLINEABLE lowerBound #-}

-- | The upper bound is inherited unchanged from @a@. It is wrapped without a
-- sign check, on the assumption that @a@'s 'upperBound' is non-negative.
--
-- @since 0.1
instance (UpperBounded a) => UpperBounded (NonNegative a) where
  upperBound = UnsafeNonNegative upperBound
  {-# INLINEABLE upperBound #-}

-- | Addition of the wrapped values. The sum is rewrapped without re-checking
-- the sign.
--
-- NOTE(review): for fixed-width types such as 'Int', '(+)' can wrap on
-- overflow and produce a negative result, which would break the invariant;
-- confirm callers accept this.
--
-- @since 0.1
instance (Num a) => ASemigroup (NonNegative a) where
  UnsafeNonNegative x .+. UnsafeNonNegative y = UnsafeNonNegative (x + y)
  {-# INLINEABLE (.+.) #-}

-- | The additive identity is @0@, which satisfies the invariant.
--
-- @since 0.1
instance (Num a) => AMonoid (NonNegative a) where
  zero = UnsafeNonNegative 0
  {-# INLINEABLE zero #-}

-- | Multiplication of the wrapped values. The product is rewrapped without
-- re-checking the sign.
--
-- NOTE(review): as with addition, '(*)' on fixed-width types can wrap on
-- overflow to a negative value.
--
-- @since 0.1
instance (Num a) => MSemigroup (NonNegative a) where
  UnsafeNonNegative x .*. UnsafeNonNegative y = UnsafeNonNegative (x * y)
  {-# INLINEABLE (.*.) #-}

-- | The multiplicative identity is @1@, which satisfies the invariant.
--
-- @since 0.1
instance (Num a) => MMonoid (NonNegative a) where
  one = UnsafeNonNegative 1
  {-# INLINEABLE one #-}

-- | Division of the wrapped values via 'divide'. The quotient is rewrapped
-- without re-checking.
--
-- NOTE(review): a zero divisor is allowed by the type. Its outcome (an
-- exception, an infinity, or NaN for @0/0@) is whatever 'divide' does for
-- @a@; confirm the invariant holds for the types in use.
--
-- @since 0.1
instance (Division a, Num a) => MGroup (NonNegative a) where
  UnsafeNonNegative x .%. UnsafeNonNegative d = UnsafeNonNegative (x `divide` d)
  {-# INLINEABLE (.%.) #-}

-- | Euclidean division via 'divMod'. Both the quotient and the remainder of
-- two non-negative operands are non-negative, so both are rewrapped directly.
-- A zero divisor behaves as 'divMod' does for @a@.
--
-- @since 0.1
instance (Division a, Integral a) => MEuclidean (NonNegative a) where
  UnsafeNonNegative x `mdivMod` UnsafeNonNegative d =
    bimap UnsafeNonNegative UnsafeNonNegative (x `divMod` d)
  {-# INLINEABLE mdivMod #-}

-- | A non-negative value is its own norm, so 'norm' is the identity.
--
-- @since 0.1
instance Normed (NonNegative a) where
  norm = id
  {-# INLINEABLE norm #-}

-- | Converts with 'fromInteger', then validates with 'unsafeNonNegative'.
--
-- __WARNING: Partial__: calls 'error' when the converted value is negative.
-- The conversion to @a@ happens before the check, so for a fixed-width @a@ an
-- out-of-range 'Integer' may wrap before its sign is inspected.
--
-- @since 0.1
instance (Num a, Ord a, Show a) => FromInteger (NonNegative a) where
  afromInteger = unsafeNonNegative . fromInteger
  {-# INLINEABLE afromInteger #-}

-- | Converts with 'fromRational', then validates with 'unsafeNonNegative'.
--
-- __WARNING: Partial__: calls 'error' when the converted value is negative.
--
-- @since 0.1
instance (Fractional a, Ord a, Show a) => FromRational (NonNegative a) where
  afromRational = unsafeNonNegative . fromRational
  {-# INLINEABLE afromRational #-}

-- | Empty instance body: the structure comes from the additive and
-- multiplicative instances defined in this module.
--
-- @since 0.1
instance (Num a) => Semiring (NonNegative a)

-- | Empty instance body: relies on the 'Semiring' and 'MGroup' instances
-- defined in this module (hence the extra 'Division' constraint).
--
-- @since 0.1
instance (Division a, Num a) => Semifield (NonNegative a)

-- | Wraps a value after checking that it is @>= 0@; otherwise throws an
-- error. Because the guard is @x >= 0@, values incomparable with 0 (such as
-- a floating-point NaN) are also rejected.
--
-- __WARNING: Partial__
--
-- ==== __Examples__
-- >>> unsafeNonNegative 7
-- UnsafeNonNegative 7
--
-- @since 0.1
unsafeNonNegative :: (HasCallStack, Num a, Ord a, Show a) => a -> NonNegative a
unsafeNonNegative x
  | x >= 0 = UnsafeNonNegative x
  -- HasCallStack lets 'error' report the caller's location.
  | otherwise = error (errMsg "unsafeNonNegative" x)
{-# INLINEABLE unsafeNonNegative #-}

-- | Builds the error message used when a value fails the non-negative check.
-- The first argument is the name of the reporting function (appended to the
-- module prefix) and the second is the offending value, rendered with 'show'.
--
-- >>> errMsg "unsafeNonNegative" (-3 :: Int)
-- "Numeric.Data.NonNegative.unsafeNonNegative: Received value < zero: -3"
--
-- @since 0.1
errMsg :: (Show a) => String -> a -> String
errMsg fn x =
  mconcat
    [ "Numeric.Data.NonNegative.",
      fn,
      ": Received value < zero: ",
      show x
    ]