// This file is part of Substrate.

// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: Apache-2.0

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// 	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::{biguint::BigUint, helpers_128bit, Rounding};
19
use core::cmp::Ordering;
20
use num_traits::{Bounded, One, Zero};
21

            
22
/// A wrapper for any rational number with infinitely large numerator and denominator.
23
///
24
/// This type exists to facilitate `cmp` operation
25
/// on values like `a/b < c/d` where `a, b, c, d` are all `BigUint`.
26
#[derive(Clone, Default, Eq)]
27
pub struct RationalInfinite(BigUint, BigUint);
28

            
29
impl RationalInfinite {
30
	/// Return the numerator reference.
31
	pub fn n(&self) -> &BigUint {
32
		&self.0
33
	}
34

            
35
	/// Return the denominator reference.
36
	pub fn d(&self) -> &BigUint {
37
		&self.1
38
	}
39

            
40
	/// Build from a raw `n/d`.
41
	pub fn from(n: BigUint, d: BigUint) -> Self {
42
		Self(n, d.max(BigUint::one()))
43
	}
44

            
45
	/// Zero.
46
	pub fn zero() -> Self {
47
		Self(BigUint::zero(), BigUint::one())
48
	}
49

            
50
	/// One.
51
	pub fn one() -> Self {
52
		Self(BigUint::one(), BigUint::one())
53
	}
54
}
55

            
56
impl PartialOrd for RationalInfinite {
57
	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
58
		Some(self.cmp(other))
59
	}
60
}
61

            
62
impl Ord for RationalInfinite {
63
	fn cmp(&self, other: &Self) -> Ordering {
64
		// handle some edge cases.
65
		if self.d() == other.d() {
66
			self.n().cmp(other.n())
67
		} else if self.d().is_zero() {
68
			Ordering::Greater
69
		} else if other.d().is_zero() {
70
			Ordering::Less
71
		} else {
72
			// (a/b) cmp (c/d) => (a*d) cmp (c*b)
73
			self.n().clone().mul(other.d()).cmp(&other.n().clone().mul(self.d()))
74
		}
75
	}
76
}
77

            
78
impl PartialEq for RationalInfinite {
79
	fn eq(&self, other: &Self) -> bool {
80
		self.cmp(other) == Ordering::Equal
81
	}
82
}
83

            
84
impl From<Rational128> for RationalInfinite {
85
	fn from(t: Rational128) -> Self {
86
		Self(t.0.into(), t.1.into())
87
	}
88
}
89

            
90
/// A wrapper for any rational number with a 128 bit numerator and denominator.
91
#[derive(Clone, Copy, Default, Eq)]
92
pub struct Rational128(u128, u128);
93

            
94
#[cfg(feature = "std")]
impl core::fmt::Debug for Rational128 {
	/// Shows the raw fraction together with an 8-decimal floating point approximation.
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
		let approx = self.0 as f64 / self.1 as f64;
		write!(f, "Rational128({} / {} ≈ {:.8})", self.0, self.1, approx)
	}
}

#[cfg(not(feature = "std"))]
102
impl core::fmt::Debug for Rational128 {
103
	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
104
		write!(f, "Rational128({} / {})", self.0, self.1)
105
	}
106
}
107

            
108
impl Rational128 {
109
	/// Zero.
110
	pub fn zero() -> Self {
111
		Self(0, 1)
112
	}
113

            
114
	/// One
115
	pub fn one() -> Self {
116
		Self(1, 1)
117
	}
118

            
119
	/// If it is zero or not
120
	pub fn is_zero(&self) -> bool {
121
		self.0.is_zero()
122
	}
123

            
124
	/// Build from a raw `n/d`.
125
	pub fn from(n: u128, d: u128) -> Self {
126
		Self(n, d.max(1))
127
	}
128

            
129
	/// Build from a raw `n/d`. This could lead to / 0 if not properly handled.
130
	pub fn from_unchecked(n: u128, d: u128) -> Self {
131
		Self(n, d)
132
	}
133

            
134
	/// Return the numerator.
135
	pub fn n(&self) -> u128 {
136
		self.0
137
	}
138

            
139
	/// Return the denominator.
140
	pub fn d(&self) -> u128 {
141
		self.1
142
	}
143

            
144
	/// Convert `self` to a similar rational number where denominator is the given `den`.
145
	//
146
	/// This only returns if the result is accurate. `None` is returned if the result cannot be
147
	/// accurately calculated.
148
	pub fn to_den(self, den: u128) -> Option<Self> {
149
		if den == self.1 {
150
			Some(self)
151
		} else {
152
			helpers_128bit::multiply_by_rational_with_rounding(
153
				self.0,
154
				den,
155
				self.1,
156
				Rounding::NearestPrefDown,
157
			)
158
			.map(|n| Self(n, den))
159
		}
160
	}
161

            
162
	/// Get the least common divisor of `self` and `other`.
163
	///
164
	/// This only returns if the result is accurate. `None` is returned if the result cannot be
165
	/// accurately calculated.
166
	pub fn lcm(&self, other: &Self) -> Option<u128> {
167
		// this should be tested better: two large numbers that are almost the same.
168
		if self.1 == other.1 {
169
			return Some(self.1)
170
		}
171
		let g = helpers_128bit::gcd(self.1, other.1);
172
		helpers_128bit::multiply_by_rational_with_rounding(
173
			self.1,
174
			other.1,
175
			g,
176
			Rounding::NearestPrefDown,
177
		)
178
	}
179

            
180
	/// A saturating add that assumes `self` and `other` have the same denominator.
181
	pub fn lazy_saturating_add(self, other: Self) -> Self {
182
		if other.is_zero() {
183
			self
184
		} else {
185
			Self(self.0.saturating_add(other.0), self.1)
186
		}
187
	}
188

            
189
	/// A saturating subtraction that assumes `self` and `other` have the same denominator.
190
	pub fn lazy_saturating_sub(self, other: Self) -> Self {
191
		if other.is_zero() {
192
			self
193
		} else {
194
			Self(self.0.saturating_sub(other.0), self.1)
195
		}
196
	}
197

            
198
	/// Addition. Simply tries to unify the denominators and add the numerators.
199
	///
200
	/// Overflow might happen during any of the steps. Error is returned in such cases.
201
	pub fn checked_add(self, other: Self) -> Result<Self, &'static str> {
202
		let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
203
		let self_scaled =
204
			self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
205
		let other_scaled =
206
			other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
207
		let n = self_scaled
208
			.0
209
			.checked_add(other_scaled.0)
210
			.ok_or("overflow while adding numerators")?;
211
		Ok(Self(n, self_scaled.1))
212
	}
213

            
214
	/// Subtraction. Simply tries to unify the denominators and subtract the numerators.
215
	///
216
	/// Overflow might happen during any of the steps. None is returned in such cases.
217
	pub fn checked_sub(self, other: Self) -> Result<Self, &'static str> {
218
		let lcm = self.lcm(&other).ok_or(0).map_err(|_| "failed to scale to denominator")?;
219
		let self_scaled =
220
			self.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
221
		let other_scaled =
222
			other.to_den(lcm).ok_or(0).map_err(|_| "failed to scale to denominator")?;
223

            
224
		let n = self_scaled
225
			.0
226
			.checked_sub(other_scaled.0)
227
			.ok_or("overflow while subtracting numerators")?;
228
		Ok(Self(n, self_scaled.1))
229
	}
230
}
231

            
232
impl Bounded for Rational128 {
233
	fn min_value() -> Self {
234
		Self(0, 1)
235
	}
236

            
237
	fn max_value() -> Self {
238
		Self(Bounded::max_value(), 1)
239
	}
240
}
241

            
242
impl<T: Into<u128>> From<T> for Rational128 {
243
	fn from(t: T) -> Self {
244
		Self::from(t.into(), 1)
245
	}
246
}
247

            
248
impl PartialOrd for Rational128 {
249
	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
250
		Some(self.cmp(other))
251
	}
252
}
253

            
254
impl Ord for Rational128 {
255
	fn cmp(&self, other: &Self) -> Ordering {
256
		// handle some edge cases.
257
		if self.1 == other.1 {
258
			self.0.cmp(&other.0)
259
		} else if self.1.is_zero() {
260
			Ordering::Greater
261
		} else if other.1.is_zero() {
262
			Ordering::Less
263
		} else {
264
			// Don't even compute gcd.
265
			let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
266
			let other_n =
267
				helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
268
			self_n.cmp(&other_n)
269
		}
270
	}
271
}
272

            
273
impl PartialEq for Rational128 {
274
	fn eq(&self, other: &Self) -> bool {
275
		// handle some edge cases.
276
		if self.1 == other.1 {
277
			self.0.eq(&other.0)
278
		} else {
279
			let self_n = helpers_128bit::to_big_uint(self.0) * helpers_128bit::to_big_uint(other.1);
280
			let other_n =
281
				helpers_128bit::to_big_uint(other.0) * helpers_128bit::to_big_uint(self.1);
282
			self_n.eq(&other_n)
283
		}
284
	}
285
}
286

            
287
pub trait MultiplyRational: Sized {
288
	fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self>;
289
}
290

            
291
/// Implements `MultiplyRational` for `$ulow` by doing the arithmetic in `$uhi`, which is
/// assumed to be twice as wide as `$ulow` so that `self * n` cannot overflow.
macro_rules! impl_rrm {
	($ulow:ty, $uhi:ty) => {
		impl MultiplyRational for $ulow {
			fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
				if d.is_zero() {
					return None
				}

				let wide_d = d as $uhi;
				let product = (self as $uhi) * (n as $uhi);
				let quotient = product / wide_d;
				// `product % wide_d < d`, so narrowing back to `$ulow` is lossless.
				let remainder = (product % wide_d) as $ulow;

				let round_up = match r {
					Rounding::Up => remainder > 0,
					// cannot be `(d + 1) / 2` since `d` might be `max_value` and overflow.
					Rounding::NearestPrefUp => remainder >= d / 2 + d % 2,
					Rounding::NearestPrefDown => remainder > d / 2,
					Rounding::Down => false,
				};
				let result = if round_up { quotient.checked_add(1)? } else { quotient };

				if result <= (<$ulow>::max_value() as $uhi) {
					Some(result as $ulow)
				} else {
					None
				}
			}
		}
	};
}

impl_rrm!(u8, u16);
325
impl_rrm!(u16, u32);
326
impl_rrm!(u32, u64);
327
impl_rrm!(u64, u128);
328

            
329
impl MultiplyRational for u128 {
330
20
	fn multiply_rational(self, n: Self, d: Self, r: Rounding) -> Option<Self> {
331
20
		crate::helpers_128bit::multiply_by_rational_with_rounding(self, n, d, r)
332
20
	}
333
}
334

            
335
#[cfg(test)]
mod tests {
	use super::{helpers_128bit::*, *};
	use static_assertions::const_assert;

	const MAX128: u128 = u128::MAX;
	const MAX64: u128 = u64::MAX as u128;
	const MAX64_2: u128 = 2 * u64::MAX as u128;

	fn r(p: u128, q: u128) -> Rational128 {
		Rational128(p, q)
	}

	fn mul_div(a: u128, b: u128, c: u128) -> u128 {
		use primitive_types::U256;
		if a.is_zero() {
			return Zero::zero()
		}
		let c = c.max(1);

		// e for extended
		let ae: U256 = a.into();
		let be: U256 = b.into();
		let ce: U256 = c.into();

		let r = ae * be / ce;
		if r > u128::max_value().into() {
			a
		} else {
			r.as_u128()
		}
	}

	#[test]
	fn truth_value_function_works() {
		assert_eq!(mul_div(2u128.pow(100), 8, 4), 2u128.pow(101));
		assert_eq!(mul_div(2u128.pow(100), 4, 8), 2u128.pow(99));

		// and it returns a if result cannot fit
		assert_eq!(mul_div(MAX128 - 10, 2, 1), MAX128 - 10);
	}

	#[test]
	fn to_denom_works() {
		// simple up and down
		assert_eq!(r(1, 5).to_den(10), Some(r(2, 10)));
		assert_eq!(r(4, 10).to_den(5), Some(r(2, 5)));

		// up and down with large numbers
		assert_eq!(r(MAX128 - 10, MAX128).to_den(10), Some(r(10, 10)));
		assert_eq!(r(MAX128 / 2, MAX128).to_den(10), Some(r(5, 10)));

		// large to perbill. This is very well needed for npos-elections.
		assert_eq!(r(MAX128 / 2, MAX128).to_den(1000_000_000), Some(r(500_000_000, 1000_000_000)));

		// large to large
		assert_eq!(r(MAX128 / 2, MAX128).to_den(MAX128 / 2), Some(r(MAX128 / 4, MAX128 / 2)));
	}

	#[test]
	fn gcd_works() {
		assert_eq!(gcd(10, 5), 5);
		assert_eq!(gcd(7, 22), 1);
	}

	#[test]
	fn lcm_works() {
		// simple stuff
		assert_eq!(r(3, 10).lcm(&r(4, 15)).unwrap(), 30);
		assert_eq!(r(5, 30).lcm(&r(1, 7)).unwrap(), 210);
		assert_eq!(r(5, 30).lcm(&r(1, 10)).unwrap(), 30);

		// large numbers
		assert_eq!(r(1_000_000_000, MAX128).lcm(&r(7_000_000_000, MAX128 - 1)), None,);
		assert_eq!(
			r(1_000_000_000, MAX64).lcm(&r(7_000_000_000, MAX64 - 1)),
			Some(340282366920938463408034375210639556610),
		);
		const_assert!(340282366920938463408034375210639556610 < MAX128);
		const_assert!(340282366920938463408034375210639556610 == MAX64 * (MAX64 - 1));
	}

	#[test]
	fn add_works() {
		// works
		assert_eq!(r(3, 10).checked_add(r(1, 10)).unwrap(), r(2, 5));
		assert_eq!(r(3, 10).checked_add(r(3, 7)).unwrap(), r(51, 70));

		// errors
		assert_eq!(
			r(1, MAX128).checked_add(r(1, MAX128 - 1)),
			Err("failed to scale to denominator"),
		);
		assert_eq!(
			r(7, MAX128).checked_add(r(MAX128, MAX128)),
			Err("overflow while adding numerators"),
		);
		assert_eq!(
			r(MAX128, MAX128).checked_add(r(MAX128, MAX128)),
			Err("overflow while adding numerators"),
		);
	}

	#[test]
	fn sub_works() {
		// works
		assert_eq!(r(3, 10).checked_sub(r(1, 10)).unwrap(), r(1, 5));
		assert_eq!(r(6, 10).checked_sub(r(3, 7)).unwrap(), r(12, 70));

		// errors
		assert_eq!(
			r(2, MAX128).checked_sub(r(1, MAX128 - 1)),
			Err("failed to scale to denominator"),
		);
		assert_eq!(
			r(7, MAX128).checked_sub(r(MAX128, MAX128)),
			Err("overflow while subtracting numerators"),
		);
		assert_eq!(r(1, 10).checked_sub(r(2, 10)), Err("overflow while subtracting numerators"));
	}

	#[test]
	fn ordering_and_eq_works() {
		assert!(r(1, 2) > r(1, 3));
		assert!(r(1, 2) > r(2, 6));

		assert!(r(1, 2) < r(6, 6));
		assert!(r(2, 1) > r(2, 6));

		assert!(r(5, 10) == r(1, 2));
		assert!(r(1, 2) == r(1, 2));

		assert!(r(1, 1490000000000200000) > r(1, 1490000000000200001));
	}

	#[test]
	fn multiply_by_rational_with_rounding_works() {
		assert_eq!(multiply_by_rational_with_rounding(7, 2, 3, Rounding::Down).unwrap(), 7 * 2 / 3);
		assert_eq!(
			multiply_by_rational_with_rounding(7, 20, 30, Rounding::Down).unwrap(),
			7 * 2 / 3
		);
		assert_eq!(
			multiply_by_rational_with_rounding(20, 7, 30, Rounding::Down).unwrap(),
			7 * 2 / 3
		);

		assert_eq!(
			// MAX128 % 3 == 0
			multiply_by_rational_with_rounding(MAX128, 2, 3, Rounding::Down).unwrap(),
			MAX128 / 3 * 2,
		);
		assert_eq!(
			// MAX128 % 7 == 3
			multiply_by_rational_with_rounding(MAX128, 5, 7, Rounding::Down).unwrap(),
			(MAX128 / 7 * 5) + (3 * 5 / 7),
		);
		assert_eq!(
			// MAX128 % 13 == 8
			multiply_by_rational_with_rounding(MAX128, 11, 13, Rounding::Down).unwrap(),
			(MAX128 / 13 * 11) + (8 * 11 / 13),
		);
		assert_eq!(
			// MAX128 % 1000 == 455
			multiply_by_rational_with_rounding(MAX128, 555, 1000, Rounding::Down).unwrap(),
			(MAX128 / 1000 * 555) + (455 * 555 / 1000),
		);

		assert_eq!(
			multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64, MAX64, Rounding::Down)
				.unwrap(),
			2 * MAX64 - 1
		);
		assert_eq!(
			multiply_by_rational_with_rounding(2 * MAX64 - 1, MAX64 - 1, MAX64, Rounding::Down)
				.unwrap(),
			2 * MAX64 - 3
		);

		assert_eq!(
			multiply_by_rational_with_rounding(MAX64 + 100, MAX64_2, MAX64_2 / 2, Rounding::Down)
				.unwrap(),
			(MAX64 + 100) * 2,
		);
		assert_eq!(
			multiply_by_rational_with_rounding(
				MAX64 + 100,
				MAX64_2 / 100,
				MAX64_2 / 200,
				Rounding::Down
			)
			.unwrap(),
			(MAX64 + 100) * 2,
		);

		assert_eq!(
			multiply_by_rational_with_rounding(
				2u128.pow(66) - 1,
				2u128.pow(65) - 1,
				2u128.pow(65),
				Rounding::Down
			)
			.unwrap(),
			73786976294838206461,
		);
		assert_eq!(
			multiply_by_rational_with_rounding(1_000_000_000, MAX128 / 8, MAX128 / 2, Rounding::Up)
				.unwrap(),
			250000000
		);

		assert_eq!(
			multiply_by_rational_with_rounding(
				29459999999999999988000u128,
				1000000000000000000u128,
				10000000000000000000u128,
				Rounding::Down
			)
			.unwrap(),
			2945999999999999998800u128
		);
	}

	#[test]
	fn multiply_by_rational_with_rounding_a_b_are_interchangeable() {
		assert_eq!(
			multiply_by_rational_with_rounding(10, MAX128, MAX128 / 2, Rounding::NearestPrefDown),
			Some(20)
		);
		assert_eq!(
			multiply_by_rational_with_rounding(MAX128, 10, MAX128 / 2, Rounding::NearestPrefDown),
			Some(20)
		);
	}

	#[test]
	#[ignore]
	fn multiply_by_rational_with_rounding_fuzzed_equation() {
		assert_eq!(
			multiply_by_rational_with_rounding(
				154742576605164960401588224,
				9223376310179529214,
				549756068598,
				Rounding::NearestPrefDown
			),
			Some(2596149632101417846585204209223679)
		);
	}
}