1
/*
2
Licensed to the Apache Software Foundation (ASF) under one
3
or more contributor license agreements.  See the NOTICE file
4
distributed with this work for additional information
5
regarding copyright ownership.  The ASF licenses this file
6
to you under the Apache License, Version 2.0 (the
7
"License"); you may not use this file except in compliance
8
with the License.  You may obtain a copy of the License at
9

            
10
  http://www.apache.org/licenses/LICENSE-2.0
11

            
12
Unless required by applicable law or agreed to in writing,
13
software distributed under the License is distributed on an
14
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
KIND, either express or implied.  See the License for the
16
specific language governing permissions and limitations
17
under the License.
18
*/
19

            
20
use super::big::Big;
21
use super::dbig::DBig;
22
use super::fp::FP;
23
use super::fp2::FP2;
24
use super::rom::{
25
    H2C_L, HASH_ALGORITHM, MODULUS, SSWU_A1, SSWU_A2_A, SSWU_A2_B, SSWU_B1, SSWU_B2_A, SSWU_B2_B,
26
    SSWU_Z1, SSWU_Z2_A, SSWU_Z2_B,
27
};
28
use crate::errors::AmclError;
29
use crate::hash256::{BLOCK_SIZE as SHA256_BLOCK_SIZE, HASH256, HASH_BYTES as SHA256_HASH_BYTES};
30
use crate::hash384::{BLOCK_SIZE as SHA384_BLOCK_SIZE, HASH384, HASH_BYTES as SHA384_HASH_BYTES};
31
use crate::hash512::{BLOCK_SIZE as SHA512_BLOCK_SIZE, HASH512, HASH_BYTES as SHA512_HASH_BYTES};
32
use crate::std::{vec, Vec};
33

            
34
/// Oversized DST padding
35
/// Prefix that `expand_message_xmd` prepends to a DST longer than 255 bytes
/// before hashing it down to a shortened DST.
pub const OVERSIZED_DST: &[u8] = b"H2C-OVERSIZE-DST-";
36

            
37
/// Hash functions selectable for the hash-to-field routines in this module.
///
/// `Debug`, `PartialEq` and `Eq` are derived so that the selected algorithm
/// can be printed in diagnostics and compared in tests.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum HashAlgorithm {
    /// SHA-256
    Sha256,
    /// SHA-384
    Sha384,
    /// SHA-512
    Sha512,
}
43

            
44
impl HashAlgorithm {
45
    pub fn length(&self) -> usize {
46
        match self {
47
            HashAlgorithm::Sha256 => SHA256_HASH_BYTES,
48
            HashAlgorithm::Sha384 => SHA384_HASH_BYTES,
49
            HashAlgorithm::Sha512 => SHA512_HASH_BYTES,
50
        }
51
    }
52

            
53
    pub fn block_size(&self) -> usize {
54
        match self {
55
            HashAlgorithm::Sha256 => SHA256_BLOCK_SIZE,
56
            HashAlgorithm::Sha384 => SHA384_BLOCK_SIZE,
57
            HashAlgorithm::Sha512 => SHA512_BLOCK_SIZE,
58
        }
59
    }
60
}
61

            
62
/// Hash a message
63
pub fn hash(msg: &[u8], hash_function: HashAlgorithm) -> Vec<u8> {
64
    match hash_function {
65
        HashAlgorithm::Sha256 => {
66
            let mut hash = HASH256::new();
67
            hash.init();
68
            hash.process_array(msg);
69
            hash.hash().to_vec()
70
        }
71
        HashAlgorithm::Sha384 => {
72
            let mut hash = HASH384::new();
73
            hash.init();
74
            hash.process_array(msg);
75
            hash.hash().to_vec()
76
        }
77
        HashAlgorithm::Sha512 => {
78
            let mut hash = HASH512::new();
79
            hash.init();
80
            hash.process_array(msg);
81
            hash.hash().to_vec()
82
        }
83
    }
84
}
85

            
86
// Hash To Field - Fp
87
//
88
// Take a message as bytes and convert it to a Field Point
89
// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-09#section-5.3
90
pub fn hash_to_field_fp(msg: &[u8], count: usize, dst: &[u8]) -> Result<Vec<FP>, AmclError> {
91
    let m = 1;
92
    let p = Big::new_ints(&MODULUS);
93

            
94
    let len_in_bytes = count * m * H2C_L;
95
    let pseudo_random_bytes = expand_message_xmd(msg, len_in_bytes, dst)?;
96

            
97
    let mut u: Vec<FP> = Vec::with_capacity(count as usize);
98
    for i in 0..count as usize {
99
        let elm_offset = H2C_L as usize * i * m as usize;
100
        let mut dbig =
101
            DBig::from_bytes(&pseudo_random_bytes[elm_offset..elm_offset + H2C_L as usize]);
102
        let e: Big = dbig.dmod(&p);
103
        u.push(FP::new_big(e));
104
    }
105
    Ok(u)
106
}
107

            
108
// Hash To Field - Fp2
109
//
110
// Take a message as bytes and convert it to a vector of Field Points with extension degree 2.
111
// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-09#section-5.3
112
pub fn hash_to_field_fp2(msg: &[u8], count: usize, dst: &[u8]) -> Result<Vec<FP2>, AmclError> {
113
    let m = 2;
114
    let p = Big::new_ints(&MODULUS);
115

            
116
    let len_in_bytes = count * m * H2C_L;
117

            
118
    let pseudo_random_bytes = expand_message_xmd(msg, len_in_bytes, dst)?;
119

            
120
    let mut u: Vec<FP2> = Vec::with_capacity(count as usize);
121
    for i in 0..count as usize {
122
        let mut e: Vec<Big> = Vec::with_capacity(m as usize);
123
        for j in 0..m as usize {
124
            let elm_offset = H2C_L as usize * (j + i * m as usize);
125
            let mut big =
126
                DBig::from_bytes(&pseudo_random_bytes[elm_offset..elm_offset + H2C_L as usize]);
127
            e.push(big.dmod(&p));
128
        }
129
        u.push(FP2::new_bigs(e[0].clone(), e[1].clone()));
130
    }
131
    Ok(u)
132
}
133

            
134
// Expand Message XMD
135
//
136
// Take a message and convert it to pseudo random bytes of specified length
137
// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-09#section-5.4
138
fn expand_message_xmd(msg: &[u8], len_in_bytes: usize, dst: &[u8]) -> Result<Vec<u8>, AmclError> {
139
    // ell = ceiling(len_in_bytes / b_in_bytes)
140
    let ell = (len_in_bytes + HASH_ALGORITHM.length() - 1) / HASH_ALGORITHM.length();
141

            
142
    // Error if length of output less than 255 bytes
143
    if ell > 255 {
144
        return Err(AmclError::HashToFieldError);
145
    }
146

            
147
    // Create DST prime as (dst.len() || dst)
148
    let dst_prime = if dst.len() > 255 {
149
        // DST too long, shorten to H("H2C-OVERSIZE-DST-" || dst)
150
        let mut tmp = OVERSIZED_DST.to_vec();
151
        tmp.extend_from_slice(dst);
152
        let mut tmp = hash(&tmp, HASH_ALGORITHM).to_vec();
153
        tmp.push(HASH_ALGORITHM.length() as u8);
154
        tmp
155
    } else {
156
        // DST correct size, append length as a single byte
157
        let mut prime = dst.to_vec();
158
        prime.push(dst.len() as u8);
159
        prime
160
    };
161

            
162
    let mut pseudo_random_bytes: Vec<u8> = vec![];
163
    let mut b: Vec<Vec<u8>> = vec![vec![]; 2];
164

            
165
    // Set b[0] to H(Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime)
166
    // l_i_b_str = I2OSP(len_in_bytes, 2)
167
    let mut tmp = vec![0; HASH_ALGORITHM.block_size()];
168
    tmp.extend_from_slice(msg);
169
    let l_i_b_str: [u8; 2] = (len_in_bytes as u16).to_be_bytes();
170
    tmp.extend_from_slice(&l_i_b_str);
171
    tmp.push(0u8);
172
    tmp.extend_from_slice(&dst_prime);
173
    b[0] = hash(&tmp, HASH_ALGORITHM);
174

            
175
    // Set b[1] to H(b_0 || I2OSP(1, 1) || DST_prime)
176
    tmp = b[0].clone();
177
    tmp.push(1u8);
178
    tmp.extend_from_slice(&dst_prime);
179
    b[1] = hash(&tmp, HASH_ALGORITHM);
180

            
181
    pseudo_random_bytes.extend_from_slice(&b[1]);
182

            
183
    for i in 2..=ell {
184
        // Set b[i] to H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime)
185
        tmp = b[0]
186
            .iter()
187
            .enumerate()
188
            .map(|(j, b_0)| {
189
                // Perform strxor(b[0], b[i-1])
190
                b_0 ^ b[i - 1][j] // b[i].len() will all be 32 bytes as they are SHA256 output.
191
            })
192
            .collect();
193
        tmp.push(i as u8); // i < 256
194
        tmp.extend_from_slice(&dst_prime);
195
        b.push(hash(&tmp, HASH_ALGORITHM));
196

            
197
        pseudo_random_bytes.extend_from_slice(&b[i]);
198
    }
199

            
200
    // Take required length
201
    Ok(pseudo_random_bytes[..len_in_bytes as usize].to_vec())
202
}
203

            
204
// Simplified Shallue-van de Woestijne-Ulas Method - Fp
205
//
206
// Returns projectives as (XZ, YZ, Z)
207
// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-09#section-6.6.2
208
pub fn simplified_swu_fp(u: FP) -> (FP, FP) {
209
    let sswu_a = FP::new_big(Big::new_ints(&SSWU_A1));
210
    let sswu_b = FP::new_big(Big::new_ints(&SSWU_B1));
211
    let sswu_z = FP::new_big(Big::new_ints(&SSWU_Z1));
212

            
213
    // tmp1 = Z * u^2
214
    // tv1 = 1 / (Z^2 * u^4 + Z * u^2)
215
    let mut tmp1 = u.clone();
216
    tmp1.sqr();
217
    tmp1.mul(&sswu_z);
218
    let mut tv1 = tmp1.clone();
219
    tv1.sqr();
220
    tv1.add(&tmp1);
221
    tv1.inverse();
222

            
223
    // x = (-B / A) * (1 + tv1)
224
    let mut x = tv1.clone();
225
    x.add(&FP::new_int(1));
226
    x.mul(&sswu_b); // b * (Z^2 * u^4 + Z * u^2 + 1)
227
    x.neg();
228
    let mut a_inverse = sswu_a.clone();
229
    a_inverse.inverse();
230
    x.mul(&a_inverse);
231

            
232
    // Deal with case where Z^2 * u^4 + Z * u^2 == 0
233
    if tv1.is_zilch() {
234
        // x = B / (Z * A)
235
        x = sswu_z.clone();
236
        x.inverse();
237
        x.mul(&sswu_b);
238
        x.mul(&a_inverse);
239
    }
240

            
241
    // gx = x^3 + A * x + B
242
    let mut gx = x.clone();
243
    gx.sqr();
244
    gx.add(&sswu_a);
245
    gx.mul(&x);
246
    gx.add(&sswu_b);
247

            
248
    // y = sqrt(gx)
249
    let mut y = gx.clone();
250
    let mut y = y.sqrt();
251

            
252
    // Check y is valid square root
253
    let mut y2 = y.clone();
254
    y2.sqr();
255
    if !gx.equals(&y2) {
256
        // x = x * Z^2 * u
257
        x.mul(&tmp1);
258

            
259
        // gx = x^3 + A * x + B
260
        let mut gx = x.clone();
261
        gx.sqr();
262
        gx.add(&sswu_a);
263
        gx.mul(&x);
264
        gx.add(&sswu_b);
265

            
266
        y = gx.sqrt();
267
        y2 = y.clone();
268
        y2.sqr();
269
        assert_eq!(gx, y2, "Hash to Curve SSWU failure - no square roots");
270
    }
271

            
272
    // Negate y if y and t are opposite in sign
273
    if u.sgn0() != y.sgn0() {
274
        y.neg();
275
    }
276

            
277
    (x, y)
278
}
279

            
280
// Simplified Shallue-van de Woestijne-Ulas Method - Fp2
281
//
282
// Returns projectives as (X, Y)
283
// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-09#section-6.6.2
284
pub fn simplified_swu_fp2(u: FP2) -> (FP2, FP2) {
285
    let sswu_a = FP2::new_bigs(Big::new_ints(&SSWU_A2_A), Big::new_ints(&SSWU_A2_B));
286
    let sswu_b = FP2::new_bigs(Big::new_ints(&SSWU_B2_A), Big::new_ints(&SSWU_B2_B));
287
    let sswu_z = FP2::new_bigs(Big::new_ints(&SSWU_Z2_A), Big::new_ints(&SSWU_Z2_B));
288

            
289
    // tmp1 = Z * u^2
290
    // tv1 = 1 / (Z^2 * u^4 + Z * u^2)
291
    let mut tmp1 = u.clone();
292
    tmp1.sqr();
293
    tmp1.mul(&sswu_z);
294
    let mut tv1 = tmp1.clone();
295
    tv1.sqr();
296
    tv1.add(&tmp1);
297
    tv1.inverse();
298

            
299
    // x = (-B / A) * (1 + tv1)
300
    let mut x = tv1.clone();
301
    x.add(&FP2::new_ints(1, 0));
302
    x.mul(&sswu_b); // b * (Z^2 * u^4 + Z * u^2 + 1)
303
    x.neg();
304
    let mut a_inverse = sswu_a.clone();
305
    a_inverse.inverse();
306
    x.mul(&a_inverse);
307

            
308
    // Deal with case where Z^2 * u^4 + Z * u^2 == 0
309
    if tv1.is_zilch() {
310
        // x = B / (Z * A)
311
        x = sswu_z.clone();
312
        x.inverse();
313
        x.mul(&sswu_b);
314
        x.mul(&a_inverse);
315
    }
316

            
317
    // gx = x^3 + A * x + B
318
    let mut gx = x.clone();
319
    gx.sqr();
320
    gx.add(&sswu_a);
321
    gx.mul(&x);
322
    gx.add(&sswu_b);
323

            
324
    // y = sqrt(gx)
325
    let mut y = gx.clone();
326
    if !y.sqrt() {
327
        // x = x * Z * u^2
328
        x.mul(&tmp1);
329

            
330
        // gx = x^3 + A * x + B
331
        let mut gx = x.clone();
332
        gx.sqr();
333
        gx.add(&sswu_a);
334
        gx.mul(&x);
335
        gx.add(&sswu_b);
336

            
337
        y = gx;
338
        assert!(y.sqrt(), "Hash to Curve SSWU failure - no square roots");
339
    }
340

            
341
    // Negate y if y and t are opposite in sign
342
    if u.sgn0() != y.sgn0() {
343
        y.neg();
344
    }
345

            
346
    (x, y)
347
}