gmsol_model/
fixed.rs

1use std::ops::{Add, Mul};
2
3use num_traits::{CheckedAdd, CheckedMul, One, Zero};
4
5use crate::num::{MulDiv, Num};
6
7/// Number type with the required properties for implementing [`Fixed`].
8pub trait FixedPointOps<const DECIMALS: u8>: MulDiv + Num {
9    /// The unit value (i.e. the value "one") which is expected to be `pow(10, DECIMALS)`.
10    const UNIT: Self;
11
12    /// Fixed point power.
13    fn checked_pow_fixed(&self, exponent: &Self) -> Option<Self>;
14}
15
16impl<const DECIMALS: u8> FixedPointOps<DECIMALS> for u64 {
17    const UNIT: Self = 10u64.pow(DECIMALS as u32);
18
19    /// Fixed point power.
20    ///
21    /// # Notes
22    /// The code that calculates exponents behaves inconsistently depending on whether the exponent is a whole unit or not.
23    /// Therefore, to avoid issues, we should use only unit exponents until we implement better algorithms.
24    #[allow(clippy::arithmetic_side_effects)]
25    fn checked_pow_fixed(&self, exponent: &Self) -> Option<Self> {
26        use rust_decimal::{Decimal, MathematicalOps};
27
28        let unit = <Self as FixedPointOps<DECIMALS>>::UNIT;
29        if *exponent % unit == 0 {
30            let exp = exponent / unit;
31            // Note: there is a better algorithm.
32            let mut ans = Fixed::<Self, DECIMALS>::one();
33            let base = Fixed::<Self, DECIMALS>::from_inner(*self);
34            for _ in 0..exp {
35                ans = ans.checked_mul(&base)?;
36            }
37            return Some(ans.0);
38        }
39
40        // `scale > 28` is not supported by `rust_decimal`.
41        if DECIMALS > 28 {
42            return None;
43        }
44        let value = Decimal::new((*self).try_into().ok()?, DECIMALS as u32);
45        let exponent = Decimal::new((*exponent).try_into().ok()?, DECIMALS as u32);
46        let mut ans = value.checked_powd(exponent)?;
47        ans.rescale(DECIMALS as u32);
48        ans.mantissa().try_into().ok()
49    }
50}
51
#[cfg(feature = "u128")]
impl<const DECIMALS: u8> FixedPointOps<DECIMALS> for u128 {
    const UNIT: Self = 10u128.pow(DECIMALS as u32);

    /// Fixed point power.
    ///
    /// Whole-unit exponents are computed natively by repeated fixed-point multiplication.
    /// Any other exponent, and the base, are rescaled to [`U64D9`] (9 decimals), the power
    /// is computed there, and the result is rescaled back. Both directions use integer
    /// division where the scale shrinks, so digits beyond the smaller scale are truncated.
    ///
    /// Returns `None` if any multiplication fails, or if a value does not survive the
    /// conversion to or from `u64`.
    ///
    /// # Notes
    /// The code that calculates exponents behaves inconsistently depending on whether the exponent is a whole unit or not.
    /// Therefore, to avoid issues, we should use only unit exponents until we implement better algorithms.
    // `unit`, `divisor` and `multiplier` are non-zero powers of ten, and each decimals
    // subtraction is guarded by its `cmp` arm, so the plain arithmetic below cannot panic.
    #[allow(clippy::arithmetic_side_effects)]
    fn checked_pow_fixed(&self, exponent: &Self) -> Option<Self> {
        use std::cmp::Ordering;

        let unit = <Self as FixedPointOps<DECIMALS>>::UNIT;
        if *exponent % unit == 0 {
            let exp = exponent / unit;
            // Note: there is a better algorithm (exponentiation by squaring), but it would
            // round differently from this step-by-step multiplication.
            let mut ans = Fixed::<Self, DECIMALS>::one();
            let base = Fixed::<Self, DECIMALS>::from_inner(*self);
            for _ in 0..exp {
                let next = ans.checked_mul(&base)?;
                // Each step depends only on `ans` (the base is fixed). A step that leaves
                // `ans` unchanged would therefore leave it unchanged forever. Stopping here
                // yields the same result while skipping up to `u128::MAX / UNIT` pointless
                // iterations, e.g. for bases `0` or `1`.
                if next == ans {
                    break;
                }
                ans = next;
            }
            return Some(ans.0);
        }

        type Convert = U64D9;

        let (divisor, multiplier) = match DECIMALS.cmp(&Convert::DECIMALS) {
            Ordering::Greater => {
                let divisor = 10u128.pow((DECIMALS - Convert::DECIMALS) as u32);
                (Some(divisor), None)
            }
            Ordering::Less => {
                let multiplier = 10u128.pow((Convert::DECIMALS - DECIMALS) as u32);
                (None, Some(multiplier))
            }
            Ordering::Equal => (None, None),
        };
        // Rescale from `DECIMALS` to `Convert::DECIMALS` (truncating when shrinking).
        let convert_to = |value: Self| -> Option<u64> {
            match (&divisor, &multiplier) {
                (Some(divisor), _) => (value / *divisor).try_into().ok(),
                (_, Some(multiplier)) => value.checked_mul(*multiplier)?.try_into().ok(),
                _ => value.try_into().ok(),
            }
        };
        // Rescale from `Convert::DECIMALS` back to `DECIMALS` (truncating when shrinking).
        let convert_from = |value: u64| -> Option<Self> {
            let value: Self = value.into();
            match (&divisor, &multiplier) {
                (Some(divisor), _) => value.checked_mul(*divisor),
                (_, Some(multiplier)) => Some(value / *multiplier),
                _ => Some(value),
            }
        };
        let ans = FixedPointOps::<{ Convert::DECIMALS }>::checked_pow_fixed(
            &convert_to(*self)?,
            &convert_to(*exponent)?,
        )?;
        convert_from(ans)
    }
}
112
/// A fixed-point decimal number carrying `DECIMALS` decimal places.
///
/// The wrapped integer is the raw representation. When `T: FixedPointOps<DECIMALS>`,
/// the represented number is the raw integer divided by `T::UNIT`. Equality and ordering
/// compare the raw integers directly, and the default value wraps `T::default()`.
#[derive(Debug, Default, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Fixed<T, const DECIMALS: u8>(T);
116
117impl<T, const DECIMALS: u8> Fixed<T, DECIMALS> {
118    /// Get the internal integer representation.
119    pub fn get(&self) -> &T {
120        &self.0
121    }
122
123    /// Create a new decimal from the inner representation.
124    #[inline]
125    pub fn from_inner(inner: T) -> Self {
126        Self(inner)
127    }
128
129    /// Get the inner value.
130    #[inline]
131    pub fn into_inner(self) -> T {
132        self.0
133    }
134}
135
136impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Fixed<T, DECIMALS> {
137    /// The unit value.
138    pub const ONE: Fixed<T, DECIMALS> = Fixed(FixedPointOps::UNIT);
139    /// The decimals.
140    pub const DECIMALS: u8 = DECIMALS;
141
142    /// Checked pow.
143    pub fn checked_pow(&self, exponent: &Self) -> Option<Self> {
144        let inner = self.0.checked_pow_fixed(&exponent.0)?;
145        Some(Self(inner))
146    }
147}
148
149impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Add for Fixed<T, DECIMALS> {
150    type Output = Self;
151
152    fn add(self, rhs: Self) -> Self::Output {
153        Self(self.0.add(rhs.0))
154    }
155}
156
157impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> CheckedAdd for Fixed<T, DECIMALS> {
158    fn checked_add(&self, v: &Self) -> Option<Self> {
159        Some(Self(self.0.checked_add(&v.0)?))
160    }
161}
162
163impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Mul for Fixed<T, DECIMALS> {
164    type Output = Self;
165
166    fn mul(self, rhs: Self) -> Self::Output {
167        self.checked_mul(&rhs).expect("invalid multiplication")
168    }
169}
170
171impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> CheckedMul for Fixed<T, DECIMALS> {
172    fn checked_mul(&self, v: &Self) -> Option<Self> {
173        Some(Self(self.0.checked_mul_div(&v.0, &Self::ONE.0)?))
174    }
175}
176
177impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Zero for Fixed<T, DECIMALS> {
178    fn zero() -> Self {
179        Self(T::zero())
180    }
181
182    fn is_zero(&self) -> bool {
183        self.0.is_zero()
184    }
185}
186
187impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> One for Fixed<T, DECIMALS> {
188    fn one() -> Self {
189        Self::ONE
190    }
191
192    fn is_one(&self) -> bool
193    where
194        Self: PartialEq,
195    {
196        self.0 == Self::ONE.0
197    }
198}
199
200/// Decimal type with `9` decimals and backed by [`u64`]
201pub type U64D9 = Fixed<u64, 9>;
202
203#[cfg(feature = "u128")]
204/// Decimal type with `20` decimals and backed by [`u128`]
205pub type U128D20 = Fixed<u128, 20>;
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let x = U64D9::from_inner(12_800_000_000);
        let y = U64D9::from_inner(25_600_000_001);
        assert_eq!(x * y, U64D9::from_inner(327_680_000_012));
    }

    #[test]
    fn pow() {
        let x = U64D9::from_inner(123_456 * 100_000_000);
        let exp = U64D9::from_inner(11 * 100_000_000);
        let ans = x.checked_pow(&exp).unwrap();
        assert_eq!(ans, U64D9::from_inner(31670982733137));
    }

    /// Whole-unit exponents take the repeated-multiplication path, which the
    /// fractional-exponent `pow` test does not exercise.
    #[test]
    fn pow_integer_exponent() {
        let unit = U64D9::ONE.into_inner();
        let two = U64D9::from_inner(2 * unit);
        let half = U64D9::from_inner(unit / 2);

        assert_eq!(two.checked_pow(&U64D9::from_inner(0)), Some(U64D9::ONE));
        assert_eq!(
            two.checked_pow(&U64D9::from_inner(10 * unit)),
            Some(U64D9::from_inner(1_024 * unit))
        );
        assert_eq!(
            half.checked_pow(&U64D9::from_inner(3 * unit)),
            Some(U64D9::from_inner(125_000_000))
        );
        // 2^40 does not fit into a `u64` with 9 decimals.
        assert_eq!(two.checked_pow(&U64D9::from_inner(40 * unit)), None);
    }

    /// Bases whose powers stop changing must still give the exact result for
    /// large whole-unit exponents.
    #[test]
    fn pow_integer_exponent_degenerate_bases() {
        let unit = U64D9::ONE.into_inner();
        let exp = U64D9::from_inner(100_000 * unit);
        assert_eq!(U64D9::ONE.checked_pow(&exp), Some(U64D9::ONE));
        assert_eq!(
            U64D9::from_inner(0).checked_pow(&exp),
            Some(U64D9::from_inner(0))
        );
    }

    #[cfg(feature = "u128")]
    #[test]
    fn basic_u128() {
        let x = U128D20::from_inner(128 * U128D20::ONE.0);
        let y = U128D20::from_inner(256 * U128D20::ONE.0 + 1);
        assert_eq!(
            x * y,
            U128D20::from_inner(3_276_800_000_000_000_000_000_128)
        );
    }

    #[cfg(feature = "u128")]
    #[test]
    fn pow_u128() {
        let x = U128D20::from_inner(123_456 * U128D20::ONE.0 / 10);
        let exp = U128D20::from_inner(11 * U128D20::ONE.0 / 10);
        let ans = x.checked_pow(&exp).unwrap();
        assert_eq!(ans, U128D20::from_inner(3167098273313700000000000));
    }

    #[cfg(feature = "u128")]
    #[test]
    fn pow_integer_exponent_u128() {
        let unit = U128D20::ONE.into_inner();
        let two = U128D20::from_inner(2 * unit);
        assert_eq!(
            two.checked_pow(&U128D20::from_inner(10 * unit)),
            Some(U128D20::from_inner(1_024 * unit))
        );
        assert_eq!(
            U128D20::ONE.checked_pow(&U128D20::from_inner(100_000 * unit)),
            Some(U128D20::ONE)
        );
    }
}