1
// This file is part of Substrate.
2

            
3
// Copyright (C) Parity Technologies (UK) Ltd.
4
// SPDX-License-Identifier: Apache-2.0
5

            
6
// Licensed under the Apache License, Version 2.0 (the "License");
7
// you may not use this file except in compliance with the License.
8
// You may obtain a copy of the License at
9
//
10
// 	http://www.apache.org/licenses/LICENSE-2.0
11
//
12
// Unless required by applicable law or agreed to in writing, software
13
// distributed under the License is distributed on an "AS IS" BASIS,
14
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
// See the License for the specific language governing permissions and
16
// limitations under the License.
17

            
18
//! Provides some utilities to define a piecewise linear function.
19

            
20
use crate::{
21
	traits::{AtLeast32BitUnsigned, SaturatedConversion},
22
	Perbill,
23
};
24
use core::ops::Sub;
25
use scale_info::TypeInfo;
26

            
27
/// Piecewise Linear function in [0, 1] -> [0, 1].
28
#[derive(PartialEq, Eq, sp_core::RuntimeDebug, TypeInfo)]
29
pub struct PiecewiseLinear<'a> {
30
	/// Array of points. Must be in order from the lowest abscissas to the highest.
31
	pub points: &'a [(Perbill, Perbill)],
32
	/// The maximum value that can be returned.
33
	pub maximum: Perbill,
34
}
35

            
36
/// Absolute difference `|a - b|` of two ordered values.
///
/// The smaller operand is always subtracted from the larger one, so this never underflows
/// for unsigned types and needs no cloning of its arguments.
fn abs_sub<N: Ord + Sub<Output = N>>(a: N, b: N) -> N {
	if a >= b {
		a - b
	} else {
		b - a
	}
}
39

            
40
impl<'a> PiecewiseLinear<'a> {
41
	/// Compute `f(n/d)*d` with `n <= d`. This is useful to avoid loss of precision.
42
	pub fn calculate_for_fraction_times_denominator<N>(&self, n: N, d: N) -> N
43
	where
44
		N: AtLeast32BitUnsigned + Clone,
45
	{
46
		let n = n.min(d.clone());
47

            
48
		if self.points.is_empty() {
49
			return N::zero()
50
		}
51

            
52
		let next_point_index = self.points.iter().position(|p| n < p.0 * d.clone());
53

            
54
		let (prev, next) = if let Some(next_point_index) = next_point_index {
55
			if let Some(previous_point_index) = next_point_index.checked_sub(1) {
56
				(self.points[previous_point_index], self.points[next_point_index])
57
			} else {
58
				// There is no previous points, take first point ordinate
59
				return self.points.first().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
60
			}
61
		} else {
62
			// There is no next points, take last point ordinate
63
			return self.points.last().map(|p| p.1).unwrap_or_else(Perbill::zero) * d
64
		};
65

            
66
		let delta_y = multiply_by_rational_saturating(
67
			abs_sub(n.clone(), prev.0 * d.clone()),
68
			abs_sub(next.1.deconstruct(), prev.1.deconstruct()),
69
			// Must not saturate as prev abscissa > next abscissa
70
			next.0.deconstruct().saturating_sub(prev.0.deconstruct()),
71
		);
72

            
73
		// If both subtractions are same sign then result is positive
74
		if (n > prev.0 * d.clone()) == (next.1.deconstruct() > prev.1.deconstruct()) {
75
			(prev.1 * d).saturating_add(delta_y)
76
		// Otherwise result is negative
77
		} else {
78
			(prev.1 * d).saturating_sub(delta_y)
79
		}
80
	}
81
}
82

            
83
// Compute value * p / q.
84
// This is guaranteed not to overflow on whatever values nor lose precision.
85
// `q` must be superior to zero.
86
fn multiply_by_rational_saturating<N>(value: N, p: u32, q: u32) -> N
87
where
88
	N: AtLeast32BitUnsigned + Clone,
89
{
90
	let q = q.max(1);
91

            
92
	// Mul can saturate if p > q
93
	let result_divisor_part = (value.clone() / q.into()).saturating_mul(p.into());
94

            
95
	let result_remainder_part = {
96
		let rem = value % q.into();
97

            
98
		// Fits into u32 because q is u32 and remainder < q
99
		let rem_u32 = rem.saturated_into::<u32>();
100

            
101
		// Multiplication fits into u64 as both term are u32
102
		let rem_part = rem_u32 as u64 * p as u64 / q as u64;
103

            
104
		// Can saturate if p > q
105
		rem_part.saturated_into::<N>()
106
	};
107

            
108
	// Can saturate if p > q
109
	result_divisor_part.saturating_add(result_remainder_part)
110
}
111

            
112
#[test]
fn test_multiply_by_rational_saturating() {
	let div = 100u32;
	// Linearly map `step` in `0..=div` onto `0..=max`.
	let scale = |step: u32, max: u128| -> u128 { step as u128 * max / div as u128 };

	for value_step in 0..=div {
		let value: u64 = scale(value_step, u64::MAX as u128).try_into().unwrap();
		for p_step in 0..=div {
			let p: u32 = scale(p_step, u32::MAX as u128).try_into().unwrap();
			for q_step in 1..=div {
				let q: u32 = scale(q_step, u32::MAX as u128).try_into().unwrap();

				// Reference result computed at full width, saturated to the target type.
				let expected: u64 =
					(value as u128 * p as u128 / q as u128).try_into().unwrap_or(u64::MAX);
				assert_eq!(multiply_by_rational_saturating(value, p, q), expected);
			}
		}
	}
}
131

            
132
#[test]
fn test_calculate_for_fraction_times_denominator() {
	// Rises linearly from 0.5 at x = 0 to 1.0 at x = 0.5, then falls linearly to 0.0 at x = 1.
	let curve = PiecewiseLinear {
		points: &[
			(Perbill::from_parts(0_000_000_000), Perbill::from_parts(0_500_000_000)),
			(Perbill::from_parts(0_500_000_000), Perbill::from_parts(1_000_000_000)),
			(Perbill::from_parts(1_000_000_000), Perbill::from_parts(0_000_000_000)),
		],
		maximum: Perbill::from_parts(1_000_000_000),
	};

	// Closed-form reference for `curve`, evaluated as `f(n / d) * d`.
	fn expected_value(n: u64, d: u64) -> u64 {
		if n > Perbill::from_parts(0_500_000_000) * d {
			(d as u128 * 2 - n as u128 * 2).try_into().unwrap()
		} else {
			n + d / 2
		}
	}

	let div = 100u32;
	// Linearly map `step` in `0..=div` onto `0..=u64::MAX`.
	let scale = |step: u32| -> u64 {
		(step as u128 * u64::MAX as u128 / div as u128).try_into().unwrap()
	};

	for d_step in 0..=div {
		let d = scale(d_step);
		for n_step in 0..=d_step {
			let n = scale(n_step);

			let actual = curve.calculate_for_fraction_times_denominator(n, d);
			assert!(abs_sub(actual, expected_value(n, d)) <= 1);
		}
	}
}