1
use crate::ber::*;
2
use crate::der_constraint_fail_if;
3
use crate::error::*;
4
#[cfg(feature = "std")]
5
use crate::ToDer;
6
use crate::{BerParser, Class, DerParser, DynTagged, FromBer, FromDer, Length, Tag, ToStatic};
7
use alloc::borrow::Cow;
8
use core::convert::TryFrom;
9
use nom::bytes::streaming::take;
10

            
11
/// BER/DER object header (identifier and length)
12
#[derive(Clone, Debug)]
13
pub struct Header<'a> {
14
    /// Object class: universal, application, context-specific, or private
15
    pub(crate) class: Class,
16
    /// Constructed attribute: true if constructed, else false
17
    pub(crate) constructed: bool,
18
    /// Tag number
19
    pub(crate) tag: Tag,
20
    /// Object length: value if definite, or indefinite
21
    pub(crate) length: Length,
22

            
23
    /// Optionally, the raw encoding of the tag
24
    ///
25
    /// This is useful in some cases, where different representations of the same
26
    /// BER tags have different meanings (BER only)
27
    pub(crate) raw_tag: Option<Cow<'a, [u8]>>,
28
}
29

            
30
impl<'a> Header<'a> {
31
    /// Build a new BER/DER header from the provided values
32
    pub const fn new(class: Class, constructed: bool, tag: Tag, length: Length) -> Self {
33
        Header {
34
            tag,
35
            constructed,
36
            class,
37
            length,
38
            raw_tag: None,
39
        }
40
    }
41

            
42
    /// Build a new BER/DER header from the provided tag, with default values for other fields
43
    #[inline]
44
    pub const fn new_simple(tag: Tag) -> Self {
45
        let constructed = matches!(tag, Tag::Sequence | Tag::Set);
46
        Self::new(Class::Universal, constructed, tag, Length::Definite(0))
47
    }
48

            
49
    /// Set the class of this `Header`
50
    #[inline]
51
    pub fn with_class(self, class: Class) -> Self {
52
        Self { class, ..self }
53
    }
54

            
55
    /// Set the constructed flags of this `Header`
56
    #[inline]
57
    pub fn with_constructed(self, constructed: bool) -> Self {
58
        Self {
59
            constructed,
60
            ..self
61
        }
62
    }
63

            
64
    /// Set the tag of this `Header`
65
    #[inline]
66
    pub fn with_tag(self, tag: Tag) -> Self {
67
        Self { tag, ..self }
68
    }
69

            
70
    /// Set the length of this `Header`
71
    #[inline]
72
    pub fn with_length(self, length: Length) -> Self {
73
        Self { length, ..self }
74
    }
75

            
76
    /// Update header to add reference to raw tag
77
    #[inline]
78
    pub fn with_raw_tag(self, raw_tag: Option<Cow<'a, [u8]>>) -> Self {
79
        Header { raw_tag, ..self }
80
    }
81

            
82
    /// Return the class of this header.
83
    #[inline]
84
    pub const fn class(&self) -> Class {
85
        self.class
86
    }
87

            
88
    /// Return true if this header has the 'constructed' flag.
89
    #[inline]
90
    pub const fn constructed(&self) -> bool {
91
        self.constructed
92
    }
93

            
94
    /// Return the tag of this header.
95
    #[inline]
96
    pub const fn tag(&self) -> Tag {
97
        self.tag
98
    }
99

            
100
    /// Return the length of this header.
101
    #[inline]
102
    pub const fn length(&self) -> Length {
103
        self.length
104
    }
105

            
106
    /// Return the raw tag encoding, if it was stored in this object
107
    #[inline]
108
    pub fn raw_tag(&self) -> Option<&[u8]> {
109
        self.raw_tag.as_ref().map(|cow| cow.as_ref())
110
    }
111

            
112
    /// Test if object is primitive
113
    #[inline]
114
    pub const fn is_primitive(&self) -> bool {
115
        !self.constructed
116
    }
117

            
118
    /// Test if object is constructed
119
    #[inline]
120
    pub const fn is_constructed(&self) -> bool {
121
        self.constructed
122
    }
123

            
124
    /// Return error if class is not the expected class
125
    #[inline]
126
    pub const fn assert_class(&self, class: Class) -> Result<()> {
127
        self.class.assert_eq(class)
128
    }
129

            
130
    /// Return error if tag is not the expected tag
131
    #[inline]
132
    pub const fn assert_tag(&self, tag: Tag) -> Result<()> {
133
        self.tag.assert_eq(tag)
134
    }
135

            
136
    /// Return error if object is not primitive
137
    #[inline]
138
    pub const fn assert_primitive(&self) -> Result<()> {
139
        if self.is_primitive() {
140
            Ok(())
141
        } else {
142
            Err(Error::ConstructUnexpected)
143
        }
144
    }
145

            
146
    /// Return error if object is primitive
147
    #[inline]
148
    pub const fn assert_constructed(&self) -> Result<()> {
149
        if !self.is_primitive() {
150
            Ok(())
151
        } else {
152
            Err(Error::ConstructExpected)
153
        }
154
    }
155

            
156
    /// Test if object class is Universal
157
    #[inline]
158
    pub const fn is_universal(&self) -> bool {
159
        self.class as u8 == Class::Universal as u8
160
    }
161
    /// Test if object class is Application
162
    #[inline]
163
    pub const fn is_application(&self) -> bool {
164
        self.class as u8 == Class::Application as u8
165
    }
166
    /// Test if object class is Context-specific
167
    #[inline]
168
    pub const fn is_contextspecific(&self) -> bool {
169
        self.class as u8 == Class::ContextSpecific as u8
170
    }
171
    /// Test if object class is Private
172
    #[inline]
173
    pub const fn is_private(&self) -> bool {
174
        self.class as u8 == Class::Private as u8
175
    }
176

            
177
    /// Return error if object length is definite
178
    #[inline]
179
    pub const fn assert_definite(&self) -> Result<()> {
180
        if self.length.is_definite() {
181
            Ok(())
182
        } else {
183
            Err(Error::DerConstraintFailed(DerConstraint::IndefiniteLength))
184
        }
185
    }
186

            
187
    /// Get the content following a BER header
188
    #[inline]
189
    pub fn parse_ber_content<'i>(&'_ self, i: &'i [u8]) -> ParseResult<'i, &'i [u8]> {
190
        // defaults to maximum depth 8
191
        // depth is used only if BER, and length is indefinite
192
        BerParser::get_object_content(i, self, 8)
193
    }
194

            
195
    /// Get the content following a DER header
196
    #[inline]
197
    pub fn parse_der_content<'i>(&'_ self, i: &'i [u8]) -> ParseResult<'i, &'i [u8]> {
198
        self.assert_definite()?;
199
        DerParser::get_object_content(i, self, 8)
200
    }
201
}
202

            
203
impl From<Tag> for Header<'_> {
    /// Build a header from a tag only; equivalent to `Header::new_simple`.
    ///
    /// The header is universal, constructed only for `SEQUENCE` and `SET`, and has a
    /// `Definite(0)` length.
    #[inline]
    fn from(tag: Tag) -> Self {
        // delegate so the default values are defined in exactly one place
        Self::new_simple(tag)
    }
}
210

            
211
impl<'a> ToStatic for Header<'a> {
212
    type Owned = Header<'static>;
213

            
214
    fn to_static(&self) -> Self::Owned {
215
        let raw_tag: Option<Cow<'static, [u8]>> =
216
            self.raw_tag.as_ref().map(|b| Cow::Owned(b.to_vec()));
217
        Header {
218
            tag: self.tag,
219
            constructed: self.constructed,
220
            class: self.class,
221
            length: self.length,
222
            raw_tag,
223
        }
224
    }
225
}
226

            
227
impl<'a> FromBer<'a> for Header<'a> {
228
    fn from_ber(bytes: &'a [u8]) -> ParseResult<Self> {
229
        let (i1, el) = parse_identifier(bytes)?;
230
        let class = match Class::try_from(el.0) {
231
            Ok(c) => c,
232
            Err(_) => unreachable!(), // Cannot fail, we have read exactly 2 bits
233
        };
234
        let (i2, len) = parse_ber_length_byte(i1)?;
235
        let (i3, len) = match (len.0, len.1) {
236
            (0, l1) => {
237
                // Short form: MSB is 0, the rest encodes the length (which can be 0) (8.1.3.4)
238
                (i2, Length::Definite(usize::from(l1)))
239
            }
240
            (_, 0) => {
241
                // Indefinite form: MSB is 1, the rest is 0 (8.1.3.6)
242
                // If encoding is primitive, definite form shall be used (8.1.3.2)
243
                if el.1 == 0 {
244
                    return Err(nom::Err::Error(Error::ConstructExpected));
245
                }
246
                (i2, Length::Indefinite)
247
            }
248
            (_, l1) => {
249
                // if len is 0xff -> error (8.1.3.5)
250
                if l1 == 0b0111_1111 {
251
                    return Err(nom::Err::Error(Error::InvalidLength));
252
                }
253
                let (i3, llen) = take(l1)(i2)?;
254
                match bytes_to_u64(llen) {
255
                    Ok(l) => {
256
                        let l =
257
                            usize::try_from(l).or(Err(nom::Err::Error(Error::InvalidLength)))?;
258
                        (i3, Length::Definite(l))
259
                    }
260
                    Err(_) => {
261
                        return Err(nom::Err::Error(Error::InvalidLength));
262
                    }
263
                }
264
            }
265
        };
266
        let constructed = el.1 != 0;
267
        let hdr = Header::new(class, constructed, Tag(el.2), len).with_raw_tag(Some(el.3.into()));
268
        Ok((i3, hdr))
269
    }
270
}
271

            
272
impl<'a> FromDer<'a> for Header<'a> {
273
    fn from_der(bytes: &'a [u8]) -> ParseResult<Self> {
274
        let (i1, el) = parse_identifier(bytes)?;
275
        let class = match Class::try_from(el.0) {
276
            Ok(c) => c,
277
            Err(_) => unreachable!(), // Cannot fail, we have read exactly 2 bits
278
        };
279
        let (i2, len) = parse_ber_length_byte(i1)?;
280
        let (i3, len) = match (len.0, len.1) {
281
            (0, l1) => {
282
                // Short form: MSB is 0, the rest encodes the length (which can be 0) (8.1.3.4)
283
                (i2, Length::Definite(usize::from(l1)))
284
            }
285
            (_, 0) => {
286
                // Indefinite form is not allowed in DER (10.1)
287
                return Err(nom::Err::Error(Error::DerConstraintFailed(
288
                    DerConstraint::IndefiniteLength,
289
                )));
290
            }
291
            (_, l1) => {
292
                // if len is 0xff -> error (8.1.3.5)
293
                if l1 == 0b0111_1111 {
294
                    return Err(nom::Err::Error(Error::InvalidLength));
295
                }
296
                // DER(9.1) if len is 0 (indefinite form), obj must be constructed
297
                der_constraint_fail_if!(
298
                    &i[1..],
299
                    len.1 == 0 && el.1 != 1,
300
                    DerConstraint::NotConstructed
301
                );
302
                let (i3, llen) = take(l1)(i2)?;
303
                match bytes_to_u64(llen) {
304
                    Ok(l) => {
305
                        // DER: should have been encoded in short form (< 127)
306
                        // XXX der_constraint_fail_if!(i, l < 127);
307
                        let l =
308
                            usize::try_from(l).or(Err(nom::Err::Error(Error::InvalidLength)))?;
309
                        (i3, Length::Definite(l))
310
                    }
311
                    Err(_) => {
312
                        return Err(nom::Err::Error(Error::InvalidLength));
313
                    }
314
                }
315
            }
316
        };
317
        let constructed = el.1 != 0;
318
        let hdr = Header::new(class, constructed, Tag(el.2), len).with_raw_tag(Some(el.3.into()));
319
        Ok((i3, hdr))
320
    }
321
}
322

            
323
impl DynTagged for (Class, bool, Tag) {
    /// Return the tag, i.e. the last element of the `(class, constructed, tag)` tuple.
    fn tag(&self) -> Tag {
        let &(_, _, tag) = self;
        tag
    }
}
328

            
329
#[cfg(feature = "std")]
impl ToDer for (Class, bool, Tag) {
    /// Return the number of identifier octets needed to encode this tag.
    ///
    /// Tag numbers up to 30 fit in a single octet. Larger numbers need one leading
    /// octet, followed by the tag number in base 128 (7 bits per octet).
    fn to_der_len(&self) -> Result<usize> {
        let (_, _, tag) = self;
        match tag.0 {
            0..=30 => Ok(1),
            t => {
                let mut sz = 1;
                let mut val = t;
                loop {
                    if val <= 127 {
                        return Ok(sz + 1);
                    } else {
                        val >>= 7;
                        sz += 1;
                    }
                }
            }
        }
    }

    /// Write the identifier octets and return how many bytes were written.
    ///
    /// `write_all` is used because `Write::write` may perform a partial write, which
    /// would silently truncate the encoding.
    ///
    /// # Errors
    ///
    /// I/O errors from `writer` are propagated. `SerializeError::InvalidLength` is
    /// returned if the tag number does not fit in the internal 8-byte buffer.
    fn write_der_header(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let (class, constructed, tag) = self;
        let b0 = (*class as u8) << 6;
        let b0 = b0 | if *constructed { 0b10_0000 } else { 0 };
        if tag.0 > 30 {
            let mut val = tag.0;

            const BUF_SZ: usize = 8;
            let mut buffer = [0u8; BUF_SZ];
            let mut current_index = BUF_SZ - 1;

            // first byte: class+constructed+0x1f
            let b0 = b0 | 0b1_1111;
            writer.write_all(&[b0])?;

            // now fill the buffer from right (last) to left

            // last encoded byte: bit 8 clear marks the end of the tag number
            buffer[current_index] = (val & 0x7f) as u8;
            val >>= 7;

            // preceding bytes: bit 8 set means "more octets follow"
            while val > 0 {
                current_index -= 1;
                if current_index == 0 {
                    return Err(SerializeError::InvalidLength);
                }
                buffer[current_index] = (val & 0x7f) as u8 | 0x80;
                val >>= 7;
            }

            let tag_bytes = &buffer[current_index..];
            writer.write_all(tag_bytes)?;
            Ok(1 + tag_bytes.len())
        } else {
            let b0 = b0 | (tag.0 as u8);
            writer.write_all(&[b0])?;
            Ok(1)
        }
    }

    /// An identifier has no content: nothing is written.
    fn write_der_content(&self, _writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        Ok(0)
    }
}
393

            
394
impl DynTagged for Header<'_> {
395
    fn tag(&self) -> Tag {
396
        self.tag
397
    }
398
}
399

            
400
#[cfg(feature = "std")]
impl ToDer for Header<'_> {
    /// Length of the encoded header: identifier octets plus length octets.
    fn to_der_len(&self) -> Result<usize> {
        let tag_len = (self.class, self.constructed, self.tag).to_der_len()?;
        let len_len = self.length.to_der_len()?;
        Ok(tag_len + len_len)
    }

    /// Write the identifier octets (rebuilt from class, constructed flag and tag),
    /// then the length octets. Returns the number of bytes written.
    fn write_der_header(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        let sz = (self.class, self.constructed, self.tag).write_der_header(writer)?;
        let sz = sz + self.length.write_der_header(writer)?;
        Ok(sz)
    }

    /// A header has no content of its own: nothing is written.
    fn write_der_content(&self, _writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        Ok(0)
    }

    /// Like `write_der_header`, but writes the stored raw identifier octets verbatim
    /// when they are present, instead of re-encoding class, constructed flag and tag.
    fn write_der_raw(&self, writer: &mut dyn std::io::Write) -> SerializeResult<usize> {
        // use raw_tag if present
        let sz = match &self.raw_tag {
            Some(t) => {
                // `write` may perform a partial write; `write_all` guarantees the whole tag
                writer.write_all(t)?;
                t.len()
            }
            None => (self.class, self.constructed, self.tag).write_der_header(writer)?,
        };
        let sz = sz + self.length.write_der_header(writer)?;
        Ok(sz)
    }
}
428

            
429
/// Compare two BER headers. `len` fields are compared only if both objects have it set (same for `raw_tag`)
430
impl<'a> PartialEq<Header<'a>> for Header<'a> {
431
    fn eq(&self, other: &Header) -> bool {
432
        self.class == other.class
433
            && self.tag == other.tag
434
            && self.constructed == other.constructed
435
            && {
436
                if self.length.is_null() && other.length.is_null() {
437
                    self.length == other.length
438
                } else {
439
                    true
440
                }
441
            }
442
            && {
443
                // it tag is present for both, compare it
444
                if self.raw_tag.as_ref().xor(other.raw_tag.as_ref()).is_none() {
445
                    self.raw_tag == other.raw_tag
446
                } else {
447
                    true
448
                }
449
            }
450
    }
451
}
452

            
453
impl Eq for Header<'_> {}
454

            
455
#[cfg(test)]
mod tests {
    use crate::*;
    use hex_literal::hex;

    /// Generic tests on methods, and coverage tests
    #[test]
    fn methods_header() {
        // INTEGER, length 1, value 0
        let data = &hex! {"02 01 00"};

        // getters on a parsed header
        let (rest, parsed) = Header::from_ber(data).expect("parsing header failed");
        assert_eq!(parsed.class(), Class::Universal);
        assert_eq!(parsed.tag(), Tag::Integer);
        assert!(parsed.assert_primitive().is_ok());
        assert!(parsed.assert_constructed().is_err());
        assert!(parsed.is_universal());
        assert!(!parsed.is_application());
        assert!(!parsed.is_private());
        assert_eq!(rest, &data[2..]);

        // PartialEq: a simple header (placeholder length) matches the parsed one
        let simple = Header::new_simple(Tag::Integer);
        assert_eq!(parsed, simple);

        // builder methods
        let built = simple
            .with_class(Class::ContextSpecific)
            .with_constructed(true)
            .with_length(Length::Definite(1));
        assert!(built.constructed());
        assert!(built.is_constructed());
        assert!(built.assert_constructed().is_ok());
        assert!(built.is_contextspecific());
        let encoded = built.to_der_vec().expect("serialize failed");
        assert_eq!(&encoded, &[0xa2, 0x01]);

        // indefinite length
        let indefinite = built.with_length(Length::Indefinite);
        assert!(indefinite.assert_definite().is_err());
        let encoded = indefinite.to_der_vec().expect("serialize failed");
        assert_eq!(&encoded, &[0xa2, 0x80]);

        // parse_*_content: the content is everything after the two header bytes
        let content_hdr = Header::new_simple(Tag(2)).with_length(Length::Definite(1));
        let content = &data[2..];
        let (_, ber_content) = content_hdr.parse_ber_content(content).unwrap();
        assert_eq!(ber_content, content);
        let (_, der_content) = content_hdr.parse_der_content(content).unwrap();
        assert_eq!(der_content, content);
    }
}