1
// This file is part of Substrate.
2

            
3
// Copyright (C) Parity Technologies (UK) Ltd.
4
// SPDX-License-Identifier: Apache-2.0
5

            
6
// Licensed under the Apache License, Version 2.0 (the "License");
7
// you may not use this file except in compliance with the License.
8
// You may obtain a copy of the License at
9
//
10
// 	http://www.apache.org/licenses/LICENSE-2.0
11
//
12
// Unless required by applicable law or agreed to in writing, software
13
// distributed under the License is distributed on an "AS IS" BASIS,
14
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
// See the License for the specific language governing permissions and
16
// limitations under the License.
17

            
18
//! Minimal fixed point arithmetic primitives and types for runtime.
19

            
20
#![cfg_attr(not(feature = "std"), no_std)]
21

            
22
extern crate alloc;
23

            
24
/// Copied from `sp-runtime` and documented there.
///
/// Asserts that `$x` lies within the closed interval `[$y - $error, $y + $error]`. If it does not,
/// it panics with a message showing all three values.
///
/// Note that the argument expressions are expanded (and therefore evaluated) more than once. The
/// bounds `$y - $error` and `$y + $error` are computed with plain `-` / `+`, so they may overflow
/// for unsigned or bounded types near their limits.
#[macro_export]
macro_rules! assert_eq_error_rate {
	($x:expr, $y:expr, $error:expr $(,)?) => {
		assert!(
			// Inside the tolerance band: `y - error <= x <= y + error`.
			($x) >= (($y) - ($error)) && ($x) <= (($y) + ($error)),
			"{:?} != {:?} (with error rate {:?})",
			$x,
			$y,
			$error,
		);
	};
}
37

            
38
pub mod biguint;
39
pub mod fixed_point;
40
pub mod helpers_128bit;
41
pub mod per_things;
42
pub mod rational;
43
pub mod traits;
44

            
45
pub use fixed_point::{
46
	FixedI128, FixedI64, FixedPointNumber, FixedPointOperand, FixedU128, FixedU64,
47
};
48
pub use per_things::{
49
	InnerOf, MultiplyArg, PerThing, PerU16, Perbill, Percent, Permill, Perquintill, RationalArg,
50
	ReciprocalArg, Rounding, SignedRounding, UpperOf,
51
};
52
pub use rational::{MultiplyRational, Rational128, RationalInfinite};
53

            
54
use alloc::vec::Vec;
55
use core::{cmp::Ordering, fmt::Debug};
56
use traits::{BaseArithmetic, One, SaturatedConversion, Unsigned, Zero};
57

            
58
use codec::{Decode, Encode, MaxEncodedLen};
59
use scale_info::TypeInfo;
60

            
61
#[cfg(feature = "serde")]
62
use serde::{Deserialize, Serialize};
63

            
64
/// Arithmetic errors.
///
/// A static, human-readable description of each variant is available through the
/// `From<ArithmeticError> for &'static str` conversion.
///
/// NOTE(review): derived SCALE `Encode`/`Decode` presumably index variants by declaration order.
/// Confirm before reordering or inserting variants.
#[derive(Eq, PartialEq, Clone, Copy, Encode, Decode, Debug, TypeInfo, MaxEncodedLen)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArithmeticError {
	/// Underflow.
	Underflow,
	/// Overflow.
	Overflow,
	/// Division by zero.
	DivisionByZero,
}
75

            
76
impl From<ArithmeticError> for &'static str {
77
2
	fn from(e: ArithmeticError) -> &'static str {
78
2
		match e {
79
2
			ArithmeticError::Underflow => "An underflow would occur",
80
			ArithmeticError::Overflow => "An overflow would occur",
81
			ArithmeticError::DivisionByZero => "Division by zero",
82
		}
83
2
	}
84
}
85

            
86
/// Trait for comparing two numbers with a threshold.
///
/// Returns:
/// - `Ordering::Greater` if `self` is greater than `other + threshold`.
/// - `Ordering::Less` if `self` is less than `other - threshold`.
/// - `Ordering::Equal` otherwise, i.e. when `self` lies within `threshold` of `other`
///   (both ends inclusive).
pub trait ThresholdOrd<T> {
	/// Compare if `self` is `threshold` greater or less than `other`.
	fn tcmp(&self, other: &T, threshold: T) -> Ordering;
}
96

            
97
impl<T> ThresholdOrd<T> for T
98
where
99
	T: Ord + PartialOrd + Copy + Clone + traits::Zero + traits::Saturating,
100
{
101
	fn tcmp(&self, other: &T, threshold: T) -> Ordering {
102
		// early exit.
103
		if threshold.is_zero() {
104
			return self.cmp(other);
105
		}
106

            
107
		let upper_bound = other.saturating_add(threshold);
108
		let lower_bound = other.saturating_sub(threshold);
109

            
110
		if upper_bound <= lower_bound {
111
			// defensive only. Can never happen.
112
			self.cmp(other)
113
		} else {
114
			// upper_bound is guaranteed now to be bigger than lower.
115
			match (self.cmp(&lower_bound), self.cmp(&upper_bound)) {
116
				(Ordering::Greater, Ordering::Greater) => Ordering::Greater,
117
				(Ordering::Less, Ordering::Less) => Ordering::Less,
118
				_ => Ordering::Equal,
119
			}
120
		}
121
	}
122
}
123

            
124
/// A collection-like object that is made of values of type `T` and can adjust its individual
/// values so that they sum to a targeted value.
///
/// Note that the order of items in the collection may affect the result.
pub trait Normalizable<T> {
	/// Normalize self around `targeted_sum`.
	///
	/// Only returns `Ok` if the new sum of results is guaranteed to be equal to `targeted_sum`.
	/// Else, returns an error explaining why it failed to do so.
	///
	/// NOTE(review): the `Vec` implementations in this file delegate to the free function
	/// `normalize`. It returns `Ok` with an empty vector for an empty collection even when
	/// `targeted_sum` is non-zero, which is an exception to the guarantee above.
	fn normalize(&self, targeted_sum: T) -> Result<Vec<T>, &'static str>;
}
135

            
136
macro_rules! impl_normalize_for_numeric {
137
	($($numeric:ty),*) => {
138
		$(
139
			impl Normalizable<$numeric> for Vec<$numeric> {
140
				fn normalize(&self, targeted_sum: $numeric) -> Result<Vec<$numeric>, &'static str> {
141
					normalize(self.as_ref(), targeted_sum)
142
				}
143
			}
144
		)*
145
	};
146
}
147

            
148
impl_normalize_for_numeric!(u8, u16, u32, u64, u128);
149

            
150
impl<P: PerThing> Normalizable<P> for Vec<P> {
151
	fn normalize(&self, targeted_sum: P) -> Result<Vec<P>, &'static str> {
152
		let uppers = self.iter().map(|p| <UpperOf<P>>::from(p.deconstruct())).collect::<Vec<_>>();
153

            
154
		let normalized =
155
			normalize(uppers.as_ref(), <UpperOf<P>>::from(targeted_sum.deconstruct()))?;
156

            
157
		Ok(normalized
158
			.into_iter()
159
			.map(|i: UpperOf<P>| P::from_parts(i.saturated_into::<P::Inner>()))
160
			.collect())
161
	}
162
}
163

            
164
/// Normalize `input` so that the sum of all elements reaches `targeted_sum`.
165
///
166
/// This implementation is currently in a balanced position between being performant and accurate.
167
///
168
/// 1. We prefer storing original indices, and sorting the `input` only once. This will save the
169
///    cost of sorting per round at the cost of a little bit of memory.
170
/// 2. The granularity of increment/decrements is determined by the number of elements in `input`
171
///    and their sum difference with `targeted_sum`, namely `diff = diff(sum(input), target_sum)`.
172
///    This value is then distributed into `per_round = diff / input.len()` and `leftover = diff %
173
///    round`. First, per_round is applied to all elements of input, and then we move to leftover,
174
///    in which case we add/subtract 1 by 1 until `leftover` is depleted.
175
///
176
/// When the sum is less than the target, the above approach always holds. In this case, then each
177
/// individual element is also less than target. Thus, by adding `per_round` to each item, neither
178
/// of them can overflow the numeric bound of `T`. In fact, neither of the can go beyond
179
/// `target_sum`*.
180
///
181
/// If sum is more than target, there is small twist. The subtraction of `per_round`
182
/// form each element might go below zero. In this case, we saturate and add the error to the
183
/// `leftover` value. This ensures that the result will always stay accurate, yet it might cause the
184
/// execution to become increasingly slow, since leftovers are applied one by one.
185
///
186
/// All in all, the complicated case above is rare to happen in most use cases within this repo ,
187
/// hence we opt for it due to its simplicity.
188
///
189
/// This function will return an error is if length of `input` cannot fit in `T`, or if `sum(input)`
190
/// cannot fit inside `T`.
191
///
192
/// * This proof is used in the implementation as well.
193
pub fn normalize<T>(input: &[T], targeted_sum: T) -> Result<Vec<T>, &'static str>
194
where
195
	T: Clone + Copy + Ord + BaseArithmetic + Unsigned + Debug,
196
{
197
	// compute sum and return error if failed.
198
	let mut sum = T::zero();
199
	for t in input.iter() {
200
		sum = sum.checked_add(t).ok_or("sum of input cannot fit in `T`")?;
201
	}
202

            
203
	// convert count and return error if failed.
204
	let count = input.len();
205
	let count_t: T = count.try_into().map_err(|_| "length of `inputs` cannot fit in `T`")?;
206

            
207
	// Nothing to do here.
208
	if count.is_zero() {
209
		return Ok(Vec::<T>::new());
210
	}
211

            
212
	let diff = targeted_sum.max(sum) - targeted_sum.min(sum);
213
	if diff.is_zero() {
214
		return Ok(input.to_vec());
215
	}
216

            
217
	let needs_bump = targeted_sum > sum;
218
	let per_round = diff / count_t;
219
	let mut leftover = diff % count_t;
220

            
221
	// sort output once based on diff. This will require more data transfer and saving original
222
	// index, but we sort only twice instead: once now and once at the very end.
223
	let mut output_with_idx = input.iter().cloned().enumerate().collect::<Vec<(usize, T)>>();
224
	output_with_idx.sort_by_key(|x| x.1);
225

            
226
	if needs_bump {
227
		// must increase the values a bit. Bump from the min element. Index of minimum is now zero
228
		// because we did a sort. If at any point the min goes greater or equal the `max_threshold`,
229
		// we move to the next minimum.
230
		let mut min_index = 0;
231
		// at this threshold we move to next index.
232
		let threshold = targeted_sum / count_t;
233

            
234
		if !per_round.is_zero() {
235
			for _ in 0..count {
236
				output_with_idx[min_index].1 = output_with_idx[min_index]
237
					.1
238
					.checked_add(&per_round)
239
					.expect("Proof provided in the module doc; qed.");
240
				if output_with_idx[min_index].1 >= threshold {
241
					min_index += 1;
242
					min_index %= count;
243
				}
244
			}
245
		}
246

            
247
		// continue with the previous min_index
248
		while !leftover.is_zero() {
249
			output_with_idx[min_index].1 = output_with_idx[min_index]
250
				.1
251
				.checked_add(&T::one())
252
				.expect("Proof provided in the module doc; qed.");
253
			if output_with_idx[min_index].1 >= threshold {
254
				min_index += 1;
255
				min_index %= count;
256
			}
257
			leftover -= One::one();
258
		}
259
	} else {
260
		// must decrease the stakes a bit. decrement from the max element. index of maximum is now
261
		// last. if at any point the max goes less or equal the `min_threshold`, we move to the next
262
		// maximum.
263
		let mut max_index = count - 1;
264
		// at this threshold we move to next index.
265
		let threshold = output_with_idx
266
			.first()
267
			.expect("length of input is greater than zero; it must have a first; qed")
268
			.1;
269

            
270
		if !per_round.is_zero() {
271
			for _ in 0..count {
272
				output_with_idx[max_index].1 =
273
					output_with_idx[max_index].1.checked_sub(&per_round).unwrap_or_else(|| {
274
						let remainder = per_round - output_with_idx[max_index].1;
275
						leftover += remainder;
276
						output_with_idx[max_index].1.saturating_sub(per_round)
277
					});
278
				if output_with_idx[max_index].1 <= threshold {
279
					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
280
				}
281
			}
282
		}
283

            
284
		// continue with the previous max_index
285
		while !leftover.is_zero() {
286
			if let Some(next) = output_with_idx[max_index].1.checked_sub(&One::one()) {
287
				output_with_idx[max_index].1 = next;
288
				if output_with_idx[max_index].1 <= threshold {
289
					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
290
				}
291
				leftover -= One::one();
292
			} else {
293
				max_index = max_index.checked_sub(1).unwrap_or(count - 1);
294
			}
295
		}
296
	}
297

            
298
	debug_assert_eq!(
299
		output_with_idx.iter().fold(T::zero(), |acc, (_, x)| acc + *x),
300
		targeted_sum,
301
		"sum({:?}) != {:?}",
302
		output_with_idx,
303
		targeted_sum
304
	);
305

            
306
	// sort again based on the original index.
307
	output_with_idx.sort_by_key(|x| x.0);
308
	Ok(output_with_idx.into_iter().map(|(_, t)| t).collect())
309
}
310

            
311
#[cfg(test)]
mod normalize_tests {
	use super::*;

	#[test]
	fn work_for_all_types() {
		macro_rules! test_for {
			($type:ty) => {
				assert_eq!(
					normalize(vec![8 as $type, 9, 7, 10].as_ref(), 40).unwrap(),
					vec![10, 10, 10, 10],
				);
			};
		}
		// it should work for all types as long as the length of vector can be converted to T.
		test_for!(u128);
		test_for!(u64);
		test_for!(u32);
		test_for!(u16);
		test_for!(u8);
	}

	#[test]
	fn fails_on_if_input_sum_large() {
		assert!(normalize(vec![1u8; 255].as_ref(), 10).is_ok());
		assert_eq!(normalize(vec![1u8; 256].as_ref(), 10), Err("sum of input cannot fit in `T`"));
	}

	#[test]
	fn fails_if_input_length_does_not_fit() {
		// The sum (zero) fits in `u8`, but the length (300) does not.
		assert_eq!(
			normalize(vec![0u8; 300].as_ref(), 0),
			Err("length of `inputs` cannot fit in `T`")
		);
	}

	#[test]
	fn empty_input_returns_empty() {
		// There is no element to adjust, so an empty vector comes back whatever the target is.
		assert_eq!(normalize::<u32>(&[], 0), Ok(Vec::new()));
		assert_eq!(normalize::<u32>(&[], 40), Ok(Vec::new()));
	}

	#[test]
	fn does_not_fail_on_subtraction_overflow() {
		assert_eq!(normalize(vec![1u8, 100, 100].as_ref(), 10).unwrap(), vec![1, 9, 0]);
		assert_eq!(normalize(vec![1u8, 8, 9].as_ref(), 1).unwrap(), vec![0, 1, 0]);
	}

	#[test]
	fn works_for_vec() {
		assert_eq!(vec![8u32, 9, 7, 10].normalize(40).unwrap(), vec![10u32, 10, 10, 10]);
	}

	#[test]
	fn works_for_per_thing() {
		assert_eq!(
			vec![Perbill::from_percent(33), Perbill::from_percent(33), Perbill::from_percent(33)]
				.normalize(Perbill::one())
				.unwrap(),
			vec![
				Perbill::from_parts(333333334),
				Perbill::from_parts(333333333),
				Perbill::from_parts(333333333)
			]
		);

		assert_eq!(
			vec![Perbill::from_percent(20), Perbill::from_percent(15), Perbill::from_percent(30)]
				.normalize(Perbill::one())
				.unwrap(),
			vec![
				Perbill::from_parts(316666668),
				Perbill::from_parts(383333332),
				Perbill::from_parts(300000000)
			]
		);
	}

	#[test]
	fn can_work_for_peru16() {
		// Peru16 is a rather special case; since inner type is exactly the same as capacity, we
		// could have a situation where the sum cannot be calculated in the inner type. Calculating
		// using the upper type of the per_thing should assure this to be okay.
		assert_eq!(
			vec![PerU16::from_percent(40), PerU16::from_percent(40), PerU16::from_percent(40)]
				.normalize(PerU16::one())
				.unwrap(),
			vec![
				PerU16::from_parts(21845), // 33%
				PerU16::from_parts(21845), // 33%
				PerU16::from_parts(21845)  // 33%
			]
		);
	}

	#[test]
	fn normalize_works_all_le() {
		assert_eq!(normalize(vec![8u32, 9, 7, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![7u32, 7, 7, 7].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![7u32, 7, 7, 10].as_ref(), 40).unwrap(), vec![11, 11, 8, 10]);

		assert_eq!(normalize(vec![7u32, 8, 7, 10].as_ref(), 40).unwrap(), vec![11, 8, 11, 10]);

		assert_eq!(normalize(vec![7u32, 7, 8, 10].as_ref(), 40).unwrap(), vec![11, 11, 8, 10]);
	}

	#[test]
	fn normalize_works_some_ge() {
		assert_eq!(normalize(vec![8u32, 11, 9, 10].as_ref(), 40).unwrap(), vec![10, 11, 9, 10]);
	}

	#[test]
	fn always_inc_min() {
		assert_eq!(normalize(vec![10u32, 7, 10, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);
		assert_eq!(normalize(vec![10u32, 10, 7, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);
		assert_eq!(normalize(vec![10u32, 10, 10, 7].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);
	}

	#[test]
	fn normalize_works_all_ge() {
		assert_eq!(normalize(vec![12u32, 11, 13, 10].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![13u32, 13, 13, 13].as_ref(), 40).unwrap(), vec![10, 10, 10, 10]);

		assert_eq!(normalize(vec![13u32, 13, 13, 10].as_ref(), 40).unwrap(), vec![12, 9, 9, 10]);

		assert_eq!(normalize(vec![13u32, 12, 13, 10].as_ref(), 40).unwrap(), vec![9, 12, 9, 10]);

		assert_eq!(normalize(vec![13u32, 13, 12, 10].as_ref(), 40).unwrap(), vec![9, 9, 12, 10]);
	}
}
430

            
431
#[cfg(test)]
mod per_and_fixed_examples {
	use super::*;

	#[docify::export]
	#[test]
	fn percent_mult() {
		let percent = Percent::from_rational(5u32, 100u32); // aka, 5%
		let five_percent_of_100 = percent * 100u32; // 5% of 100 is 5.
		assert_eq!(five_percent_of_100, 5)
	}

	#[docify::export]
	#[test]
	fn perbill_example() {
		let p = Perbill::from_percent(80);
		// 800000000 parts per billion, i.e. a representation of 0.800000000.
		// Precision is in the billionths place.
		assert_eq!(p.deconstruct(), 800000000);
	}

	#[docify::export]
	#[test]
	fn percent_example() {
		// 190 / 400 is 47.5%; the resulting `Percent` holds 47 parts, dropping the fraction.
		let percent = Percent::from_rational(190u32, 400u32);
		assert_eq!(percent.deconstruct(), 47);
	}

	#[docify::export]
	#[test]
	fn fixed_u64_block_computation_example() {
		// Calculate a very rudimentary on-chain price from supply / demand
		// Supply: Cores available per block
		// Demand: Cores being ordered per block
		let price = FixedU64::from_rational(5u128, 10u128);

		// 0.5 DOT per core
		assert_eq!(price, FixedU64::from_float(0.5));

		// Now, the story has changed - lots of demand means we buy as many cores as there are
		// available. This also means that price goes up! For the sake of simplicity, we don't care
		// about who gets a core - just about our very simple price model.

		// Calculate a very rudimentary on-chain price from supply / demand
		// Supply: Cores available per block
		// Demand: Cores being ordered per block
		let price = FixedU64::from_rational(19u128, 10u128);

		// 1.9 DOT per core
		assert_eq!(price, FixedU64::from_float(1.9));
	}

	#[docify::export]
	#[test]
	fn fixed_u64() {
		// The difference between this and perthings is that perthings operate within the realm of
		// [0, 1]. In cases where we need > 1, we can use fixed types such as FixedU64.

		let rational_1 = FixedU64::from_rational(10, 5); // "200%" aka 2.
		let rational_2 = FixedU64::from_rational_with_rounding(5, 10, Rounding::Down); // "50%" aka 0.50...

		assert_eq!(rational_1, (2u64).into());
		assert_eq!(rational_2.into_perbill(), Perbill::from_float(0.5));
	}

	#[docify::export]
	#[test]
	fn fixed_u64_operation_example() {
		let rational_1 = FixedU64::from_rational(10, 5); // "200%" aka 2.
		let rational_2 = FixedU64::from_rational(8, 5); // "160%" aka 1.6.

		let addition = rational_1 + rational_2;
		let multiplication = rational_1 * rational_2;
		let division = rational_1 / rational_2;
		let subtraction = rational_1 - rational_2;

		assert_eq!(addition, FixedU64::from_float(3.6));
		assert_eq!(multiplication, FixedU64::from_float(3.2));
		assert_eq!(division, FixedU64::from_float(1.25));
		assert_eq!(subtraction, FixedU64::from_float(0.4));
	}
}
512

            
513
#[cfg(test)]
mod threshold_compare_tests {
	use super::*;
	use crate::traits::Saturating;
	use core::cmp::Ordering;

	#[test]
	fn epsilon_ord_works() {
		let b = 115u32;
		let e = Perbill::from_percent(10).mul_ceil(b);

		// `e` is 10% of 115 (11.5) rounded up, i.e. 12, so everything in
		// [115 - 12 (103), 115 + 12 (127)] compares as equal.
		assert_eq!((103u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((104u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((115u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((120u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((126u32).tcmp(&b, e), Ordering::Equal);
		assert_eq!((127u32).tcmp(&b, e), Ordering::Equal);

		// Just outside the band on either side.
		assert_eq!((128u32).tcmp(&b, e), Ordering::Greater);
		assert_eq!((102u32).tcmp(&b, e), Ordering::Less);
	}

	#[test]
	fn epsilon_ord_works_with_small_epc() {
		let b = 115u32;
		// way less than 1 percent. threshold will be zero. Result should be same as normal ord.
		let e = Perbill::from_parts(100) * b;

		// With a zero threshold, `tcmp` must agree with plain `cmp` for every value.
		assert_eq!((103u32).tcmp(&b, e), (103u32).cmp(&b));
		assert_eq!((104u32).tcmp(&b, e), (104u32).cmp(&b));
		assert_eq!((115u32).tcmp(&b, e), (115u32).cmp(&b));
		assert_eq!((120u32).tcmp(&b, e), (120u32).cmp(&b));
		assert_eq!((126u32).tcmp(&b, e), (126u32).cmp(&b));
		assert_eq!((127u32).tcmp(&b, e), (127u32).cmp(&b));

		assert_eq!((128u32).tcmp(&b, e), (128u32).cmp(&b));
		assert_eq!((102u32).tcmp(&b, e), (102u32).cmp(&b));
	}

	#[test]
	fn peru16_rational_does_not_overflow() {
		// A historical example that will panic only for per_thing type that are created with
		// maximum capacity of their type, e.g. PerU16.
		let _ = PerU16::from_rational(17424870u32, 17424870);
	}

	// NOTE(review): the two tests below exercise `Saturating` rather than `ThresholdOrd`; they may
	// belong in a different test module.
	#[test]
	fn saturating_mul_works() {
		// Results outside the range of `i32` clamp to its bounds.
		assert_eq!(Saturating::saturating_mul(2, i32::MIN), i32::MIN);
		assert_eq!(Saturating::saturating_mul(2, i32::MAX), i32::MAX);
	}

	#[test]
	fn saturating_pow_works() {
		// Anything to the power of zero is one.
		assert_eq!(Saturating::saturating_pow(i32::MIN, 0), 1);
		assert_eq!(Saturating::saturating_pow(i32::MAX, 0), 1);
		// An odd power of a negative base saturates to the minimum, an even power to the maximum.
		assert_eq!(Saturating::saturating_pow(i32::MIN, 3), i32::MIN);
		assert_eq!(Saturating::saturating_pow(i32::MIN, 2), i32::MAX);
		assert_eq!(Saturating::saturating_pow(i32::MAX, 2), i32::MAX);
	}
}