1use std::ops::{Add, Mul};
2
3use num_traits::{CheckedAdd, CheckedMul, One, Zero};
4
5use crate::num::{MulDiv, Num};
6
7pub trait FixedPointOps<const DECIMALS: u8>: MulDiv + Num {
9 const UNIT: Self;
11
12 fn checked_pow_fixed(&self, exponent: &Self) -> Option<Self>;
14}
15
16impl<const DECIMALS: u8> FixedPointOps<DECIMALS> for u64 {
17 const UNIT: Self = 10u64.pow(DECIMALS as u32);
18
19 #[allow(clippy::arithmetic_side_effects)]
25 fn checked_pow_fixed(&self, exponent: &Self) -> Option<Self> {
26 use rust_decimal::{Decimal, MathematicalOps};
27
28 let unit = <Self as FixedPointOps<DECIMALS>>::UNIT;
29 if *exponent % unit == 0 {
30 let exp = exponent / unit;
31 let mut ans = Fixed::<Self, DECIMALS>::one();
33 let base = Fixed::<Self, DECIMALS>::from_inner(*self);
34 for _ in 0..exp {
35 ans = ans.checked_mul(&base)?;
36 }
37 return Some(ans.0);
38 }
39
40 if DECIMALS > 28 {
42 return None;
43 }
44 let value = Decimal::new((*self).try_into().ok()?, DECIMALS as u32);
45 let exponent = Decimal::new((*exponent).try_into().ok()?, DECIMALS as u32);
46 let mut ans = value.checked_powd(exponent)?;
47 ans.rescale(DECIMALS as u32);
48 ans.mantissa().try_into().ok()
49 }
50}
51
#[cfg(feature = "u128")]
impl<const DECIMALS: u8> FixedPointOps<DECIMALS> for u128 {
    const UNIT: Self = 10u128.pow(DECIMALS as u32);

    /// Raises `self` to the fixed-point power `exponent` (both scaled by `UNIT`).
    ///
    /// * Whole-number exponents are evaluated by repeated [`Fixed`] multiplication.
    /// * Fractional exponents are evaluated by converting both operands to [`U64D9`]
    ///   (a `u64` with 9 decimals), delegating to its implementation, and converting back.
    ///   When `DECIMALS > 9` this truncates the digits beyond the 9th decimal place. When
    ///   `DECIMALS < 9` the final conversion back truncates.
    ///
    /// Returns `None` if an intermediate multiplication fails, if an operand or the result
    /// does not fit the `u64` conversion, or if the `U64D9` computation fails.
    // `%` and `/` below divide by `unit` or by non-zero powers of ten.
    #[allow(clippy::arithmetic_side_effects)]
    fn checked_pow_fixed(&self, exponent: &Self) -> Option<Self> {
        use std::cmp::Ordering;

        let unit = <Self as FixedPointOps<DECIMALS>>::UNIT;
        if *exponent % unit == 0 {
            let exp = exponent / unit;
            let base = Fixed::<Self, DECIMALS>::from_inner(*self);
            let mut ans = Fixed::<Self, DECIMALS>::one();
            for _ in 0..exp {
                let next = ans.checked_mul(&base)?;
                // `ans -> ans * base` is deterministic, so once it maps `ans` to itself every
                // remaining iteration would too. Stop instead of spinning through up to
                // `u128::MAX / UNIT` iterations; the result is identical.
                if next == ans {
                    break;
                }
                ans = next;
            }
            return Some(ans.into_inner());
        }

        type Convert = U64D9;

        // Exactly one of `divisor` / `multiplier` is set when the scales differ.
        let (divisor, multiplier) = match DECIMALS.cmp(&Convert::DECIMALS) {
            Ordering::Greater => {
                let divisor = 10u128.pow((DECIMALS - Convert::DECIMALS) as u32);
                (Some(divisor), None)
            }
            Ordering::Less => {
                let multiplier = 10u128.pow((Convert::DECIMALS - DECIMALS) as u32);
                (None, Some(multiplier))
            }
            Ordering::Equal => (None, None),
        };
        // Rescale a raw `DECIMALS` value to the raw `Convert` representation.
        let convert_to = |value: Self| -> Option<u64> {
            match (&divisor, &multiplier) {
                (Some(divisor), _) => (value / *divisor).try_into().ok(),
                (_, Some(multiplier)) => value.checked_mul(*multiplier)?.try_into().ok(),
                _ => value.try_into().ok(),
            }
        };
        // Rescale a raw `Convert` value back to the raw `DECIMALS` representation.
        let convert_from = |value: u64| -> Option<Self> {
            let value: Self = value.into();
            match (&divisor, &multiplier) {
                (Some(divisor), _) => value.checked_mul(*divisor),
                (_, Some(multiplier)) => Some(value / *multiplier),
                _ => Some(value),
            }
        };
        let ans = FixedPointOps::<{ Convert::DECIMALS }>::checked_pow_fixed(
            &convert_to(*self)?,
            &convert_to(*exponent)?,
        )?;
        convert_from(ans)
    }
}
112
/// A fixed-point number stored as a raw integer `T` scaled by a `DECIMALS`-dependent unit.
///
/// Equality and ordering are derived, so they compare the raw integers directly. For two
/// values with the same `DECIMALS` that matches numeric order. `Default` is the raw
/// `T::default()`.
#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
pub struct Fixed<T, const DECIMALS: u8>(T);
116
117impl<T, const DECIMALS: u8> Fixed<T, DECIMALS> {
118 pub fn get(&self) -> &T {
120 &self.0
121 }
122
123 #[inline]
125 pub fn from_inner(inner: T) -> Self {
126 Self(inner)
127 }
128
129 #[inline]
131 pub fn into_inner(self) -> T {
132 self.0
133 }
134}
135
136impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Fixed<T, DECIMALS> {
137 pub const ONE: Fixed<T, DECIMALS> = Fixed(FixedPointOps::UNIT);
139 pub const DECIMALS: u8 = DECIMALS;
141
142 pub fn checked_pow(&self, exponent: &Self) -> Option<Self> {
144 let inner = self.0.checked_pow_fixed(&exponent.0)?;
145 Some(Self(inner))
146 }
147}
148
149impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Add for Fixed<T, DECIMALS> {
150 type Output = Self;
151
152 fn add(self, rhs: Self) -> Self::Output {
153 Self(self.0.add(rhs.0))
154 }
155}
156
157impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> CheckedAdd for Fixed<T, DECIMALS> {
158 fn checked_add(&self, v: &Self) -> Option<Self> {
159 Some(Self(self.0.checked_add(&v.0)?))
160 }
161}
162
163impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Mul for Fixed<T, DECIMALS> {
164 type Output = Self;
165
166 fn mul(self, rhs: Self) -> Self::Output {
167 self.checked_mul(&rhs).expect("invalid multiplication")
168 }
169}
170
171impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> CheckedMul for Fixed<T, DECIMALS> {
172 fn checked_mul(&self, v: &Self) -> Option<Self> {
173 Some(Self(self.0.checked_mul_div(&v.0, &Self::ONE.0)?))
174 }
175}
176
177impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> Zero for Fixed<T, DECIMALS> {
178 fn zero() -> Self {
179 Self(T::zero())
180 }
181
182 fn is_zero(&self) -> bool {
183 self.0.is_zero()
184 }
185}
186
187impl<T: FixedPointOps<DECIMALS>, const DECIMALS: u8> One for Fixed<T, DECIMALS> {
188 fn one() -> Self {
189 Self::ONE
190 }
191
192 fn is_one(&self) -> bool
193 where
194 Self: PartialEq,
195 {
196 self.0 == Self::ONE.0
197 }
198}
199
/// Unsigned 64-bit fixed-point number with 9 decimal places (raw `1_000_000_000` is `1`).
pub type U64D9 = Fixed<u64, 9>;

/// Unsigned 128-bit fixed-point number with 20 decimal places (raw `10^20` is `1`).
/// Only available with the `u128` feature.
#[cfg(feature = "u128")]
pub type U128D20 = Fixed<u128, 20>;
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic() {
        let x = U64D9::from_inner(12_800_000_000);
        let y = U64D9::from_inner(25_600_000_001);
        assert_eq!(x * y, U64D9::from_inner(327_680_000_012));
    }

    #[test]
    fn pow() {
        let x = U64D9::from_inner(123_456 * 100_000_000);
        let exp = U64D9::from_inner(11 * 100_000_000);
        let ans = x.checked_pow(&exp).unwrap();
        assert_eq!(ans, U64D9::from_inner(31670982733137));
    }

    /// `x^0` takes the whole-number path with zero iterations and yields `ONE`.
    #[test]
    fn pow_zero_exponent_is_one() {
        let x = U64D9::from_inner(123_456 * 100_000_000);
        assert_eq!(x.checked_pow(&U64D9::zero()), Some(U64D9::ONE));
    }

    /// Whole-number exponents multiply exactly when no rounding is involved.
    #[test]
    fn pow_whole_exponent() {
        let two = U64D9::from_inner(2 * U64D9::ONE.0);
        let ten = U64D9::from_inner(10 * U64D9::ONE.0);
        assert_eq!(
            two.checked_pow(&ten),
            Some(U64D9::from_inner(1024 * U64D9::ONE.0))
        );
    }

    /// `2^64` cannot be represented in a `u64` with 9 decimals.
    #[test]
    fn pow_whole_exponent_overflow() {
        let two = U64D9::from_inner(2 * U64D9::ONE.0);
        let sixty_four = U64D9::from_inner(64 * U64D9::ONE.0);
        assert_eq!(two.checked_pow(&sixty_four), None);
    }

    /// Bases whose powers stop changing (one and zero) keep their value.
    #[test]
    fn pow_of_one_and_zero() {
        let n = U64D9::from_inner(1_000 * U64D9::ONE.0);
        assert_eq!(U64D9::ONE.checked_pow(&n), Some(U64D9::ONE));
        assert_eq!(U64D9::zero().checked_pow(&n), Some(U64D9::zero()));
    }

    #[cfg(feature = "u128")]
    #[test]
    fn basic_u128() {
        let x = U128D20::from_inner(128 * U128D20::ONE.0);
        let y = U128D20::from_inner(256 * U128D20::ONE.0 + 1);
        assert_eq!(
            x * y,
            U128D20::from_inner(3_276_800_000_000_000_000_000_128)
        );
    }

    #[cfg(feature = "u128")]
    #[test]
    fn pow_u128() {
        let x = U128D20::from_inner(123_456 * U128D20::ONE.0 / 10);
        let exp = U128D20::from_inner(11 * U128D20::ONE.0 / 10);
        let ans = x.checked_pow(&exp).unwrap();
        assert_eq!(ans, U128D20::from_inner(3167098273313700000000000));
    }
}