Skip to content

Commit 9846c69

Browse files
committed
tls_codec: add variable-length integer type TlsVarInt
As defined in #[rfc9000]. Also use this type (with an internal thin wrapper `ContentLength`) when encoding/deconding the content length of vectors. [rfc9000]: https://www.rfc-editor.org/rfc/rfc9000#name-variable-length-integer-enc
1 parent c2bbe33 commit 9846c69

File tree

4 files changed

+352
-174
lines changed

4 files changed

+352
-174
lines changed

tls_codec/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ derived.
1919
The crate also provides the following data structures that implement TLS
2020
serialization/deserialization
2121

22-
- `u8`, `u16`, `u32`, `u64`
22+
- `u8`, `u16`, `u32`, `u64`, `TlsVarInt`
2323
- `TlsVecU8`, `TlsVecU16`, `TlsVecU32`
2424
- `SecretTlsVecU8`, `SecretTlsVecU16`, `SecretTlsVecU32`
2525
The same as the `TlsVec*` versions but it implements zeroize, requiring

tls_codec/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ mod arrays;
3939
mod primitives;
4040
mod quic_vec;
4141
mod tls_vec;
42+
mod varint;
4243

4344
pub use tls_vec::{
4445
SecretTlsVecU16, SecretTlsVecU24, SecretTlsVecU32, SecretTlsVecU8, TlsByteSliceU16,
@@ -59,6 +60,8 @@ pub use tls_codec_derive::{
5960
#[cfg(feature = "conditional_deserialization")]
6061
pub use tls_codec_derive::conditionally_deserializable;
6162

63+
pub use varint::TlsVarInt;
64+
6265
/// Errors that are thrown by this crate.
6366
#[derive(Debug, Eq, PartialEq, Clone)]
6467
pub enum Error {

tls_codec/src/quic_vec.rs

+74-173
Original file line numberDiff line numberDiff line change
@@ -24,115 +24,47 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
2424

2525
use crate::{DeserializeBytes, Error, SerializeBytes, Size};
2626

27-
#[cfg(not(feature = "mls"))]
28-
const MAX_LEN: u64 = (1 << 62) - 1;
29-
#[cfg(not(feature = "mls"))]
30-
const MAX_LEN_LEN_LOG: usize = 3;
3127
#[cfg(feature = "mls")]
32-
const MAX_LEN: u64 = (1 << 30) - 1;
33-
#[cfg(feature = "mls")]
34-
const MAX_LEN_LEN_LOG: usize = 2;
35-
36-
#[inline(always)]
37-
fn check_min_length(length: usize, len_len: usize) -> Result<(), Error> {
38-
if cfg!(feature = "mls") {
39-
// ensure that len_len is minimal for the given length
40-
let min_len_len = length_encoding_bytes(length as u64)?;
41-
if min_len_len != len_len {
42-
return Err(Error::InvalidVectorLength);
43-
}
44-
};
45-
Ok(())
46-
}
28+
const MAX_MLS_LEN: u64 = (1 << 30) - 1;
4729

48-
#[inline(always)]
49-
fn calculate_length(len_len_byte: u8) -> Result<(usize, usize), Error> {
50-
let length: usize = (len_len_byte & 0x3F).into();
51-
let len_len_log = (len_len_byte >> 6).into();
52-
if !cfg!(fuzzing) {
53-
debug_assert!(len_len_log <= MAX_LEN_LEN_LOG);
54-
}
55-
if len_len_log > MAX_LEN_LEN_LOG {
56-
return Err(Error::InvalidVectorLength);
57-
}
58-
let len_len = match len_len_log {
59-
0 => 1,
60-
1 => 2,
61-
2 => 4,
62-
3 => 8,
63-
_ => unreachable!(),
64-
};
65-
Ok((length, len_len))
66-
}
67-
68-
#[inline(always)]
69-
fn read_variable_length_bytes(bytes: &[u8]) -> Result<((usize, usize), &[u8]), Error> {
70-
// The length is encoded in the first two bits of the first byte.
30+
/// Thin wrapper around [`TlsVarInt`] representing the length of encoded vector content in bytes.
31+
///
32+
/// When `mls` feature is enabled, the maximum length is limited to 30-bit. Otherwise, this type is
33+
/// no-op.
34+
struct ContentLength(super::TlsVarInt);
7135

72-
let (len_len_byte, mut remainder) = u8::tls_deserialize_bytes(bytes)?;
36+
impl ContentLength {
37+
#[cfg(not(feature = "mls"))]
38+
#[allow(dead_code)] // used in arbitrary
39+
const MAX: u64 = crate::TlsVarInt::MAX;
7340

74-
let (mut length, len_len) = calculate_length(len_len_byte)?;
41+
#[cfg(feature = "mls")]
42+
const MAX: u64 = MAX_MLS_LEN;
7543

76-
for _ in 1..len_len {
77-
let (next, next_remainder) = u8::tls_deserialize_bytes(remainder)?;
78-
remainder = next_remainder;
79-
length = (length << 8) + usize::from(next);
44+
fn new(value: super::TlsVarInt) -> Result<Self, Error> {
45+
#[cfg(feature = "mls")]
46+
if Self::MAX < value.value() {
47+
return Err(Error::InvalidVectorLength);
48+
}
49+
Ok(Self(value))
8050
}
8151

82-
check_min_length(length, len_len)?;
83-
84-
Ok(((length, len_len), remainder))
52+
fn from_usize(value: usize) -> Result<Self, Error> {
53+
Self::new(super::TlsVarInt::try_new(value.try_into()?)?)
54+
}
8555
}
8656

87-
#[inline(always)]
88-
fn length_encoding_bytes(length: u64) -> Result<usize, Error> {
89-
if !cfg!(fuzzing) {
90-
debug_assert!(length <= MAX_LEN);
91-
}
92-
if length > MAX_LEN {
93-
return Err(Error::InvalidVectorLength);
57+
impl Size for ContentLength {
58+
fn tls_serialized_len(&self) -> usize {
59+
self.0.tls_serialized_len()
9460
}
95-
96-
Ok(if length <= 0x3f {
97-
1
98-
} else if length <= 0x3fff {
99-
2
100-
} else if length <= 0x3fff_ffff {
101-
4
102-
} else {
103-
8
104-
})
10561
}
10662

107-
#[inline(always)]
108-
pub fn write_variable_length(content_length: usize) -> Result<Vec<u8>, Error> {
109-
let len_len = length_encoding_bytes(content_length.try_into()?)?;
110-
if !cfg!(fuzzing) {
111-
debug_assert!(len_len <= 8, "Invalid vector len_len {len_len}");
112-
}
113-
if len_len > 8 {
114-
return Err(Error::LibraryError);
115-
}
116-
let mut length_bytes = vec![0u8; len_len];
117-
match len_len {
118-
1 => length_bytes[0] = 0x00,
119-
2 => length_bytes[0] = 0x40,
120-
4 => length_bytes[0] = 0x80,
121-
8 => length_bytes[0] = 0xc0,
122-
_ => {
123-
if !cfg!(fuzzing) {
124-
debug_assert!(false, "Invalid vector len_len {len_len}");
125-
}
126-
return Err(Error::InvalidVectorLength);
127-
}
128-
}
129-
let mut len = content_length;
130-
for b in length_bytes.iter_mut().rev() {
131-
*b |= (len & 0xFF) as u8;
132-
len >>= 8;
63+
impl DeserializeBytes for ContentLength {
64+
fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> {
65+
let (value, remainder) = super::TlsVarInt::tls_deserialize_bytes(bytes)?;
66+
Ok((Self(value), remainder))
13367
}
134-
135-
Ok(length_bytes)
13668
}
13769

13870
impl<T: Size> Size for Vec<T> {
@@ -152,7 +84,9 @@ impl<T: Size> Size for &Vec<T> {
15284
impl<T: DeserializeBytes> DeserializeBytes for Vec<T> {
15385
#[inline(always)]
15486
fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> {
155-
let ((length, len_len), mut remainder) = read_variable_length_bytes(bytes)?;
87+
let (length, mut remainder) = ContentLength::tls_deserialize_bytes(bytes)?;
88+
let len_len = length.0.bytes_len();
89+
let length: usize = length.0.value().try_into()?;
15690

15791
if length == 0 {
15892
// An empty vector.
@@ -178,11 +112,12 @@ impl<T: SerializeBytes> SerializeBytes for &[T] {
178112
// This requires more computations but the other option would be to buffer
179113
// the entire content, which can end up requiring a lot of memory.
180114
let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len());
181-
let mut length = write_variable_length(content_length)?;
182-
let len_len = length.len();
115+
let length = ContentLength::from_usize(content_length)?;
116+
let len_len = length.0.bytes_len();
183117

184118
let mut out = Vec::with_capacity(content_length + len_len);
185-
out.append(&mut length);
119+
out.resize(len_len, 0);
120+
length.0.write_bytes(&mut out)?;
186121

187122
// Serialize the elements
188123
for e in self.iter() {
@@ -214,11 +149,13 @@ impl<T: Size> Size for &[T] {
214149
#[inline(always)]
215150
fn tls_serialized_len(&self) -> usize {
216151
let content_length = self.iter().fold(0, |acc, e| acc + e.tls_serialized_len());
217-
let len_len = length_encoding_bytes(content_length as u64).unwrap_or({
218-
// We can't do anything about the error unless we change the trait.
219-
// Let's say there's no content for now.
220-
0
221-
});
152+
let len_len = ContentLength::from_usize(content_length)
153+
.map(|content_length| content_length.0.bytes_len())
154+
.unwrap_or({
155+
// We can't do anything about the error unless we change the trait.
156+
// Let's say there's no content for now.
157+
0
158+
});
222159
content_length + len_len
223160
}
224161
}
@@ -327,10 +264,12 @@ impl From<VLBytes> for Vec<u8> {
327264
#[inline(always)]
328265
fn tls_serialize_bytes_len(bytes: &[u8]) -> usize {
329266
let content_length = bytes.len();
330-
let len_len = length_encoding_bytes(content_length as u64).unwrap_or({
331-
// We can't do anything about the error. Let's say there's no content.
332-
0
333-
});
267+
let len_len = ContentLength::from_usize(content_length)
268+
.map(|content_length| content_length.0.bytes_len())
269+
.unwrap_or({
270+
// We can't do anything about the error. Let's say there's no content.
271+
0
272+
});
334273
content_length + len_len
335274
}
336275

@@ -344,22 +283,13 @@ impl Size for VLBytes {
344283
impl DeserializeBytes for VLBytes {
345284
#[inline(always)]
346285
fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> {
347-
let ((length, _), remainder) = read_variable_length_bytes(bytes)?;
286+
let (length, remainder) = ContentLength::tls_deserialize_bytes(bytes)?;
287+
let length: usize = length.0.value().try_into()?;
288+
348289
if length == 0 {
349290
return Ok((Self::new(vec![]), remainder));
350291
}
351292

352-
if !cfg!(fuzzing) {
353-
debug_assert!(
354-
length <= MAX_LEN as usize,
355-
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
356-
);
357-
}
358-
if length > MAX_LEN as usize {
359-
return Err(Error::DecodingError(format!(
360-
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
361-
)));
362-
}
363293
match remainder.get(..length).ok_or(Error::EndOfStream) {
364294
Ok(vec) => Ok((Self { vec: vec.to_vec() }, &remainder[length..])),
365295
Err(_e) => {
@@ -422,6 +352,19 @@ pub mod rw {
422352
use super::*;
423353
use crate::{Deserialize, Serialize};
424354

355+
impl Deserialize for ContentLength {
356+
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
357+
ContentLength::new(crate::TlsVarInt::tls_deserialize(bytes)?)
358+
}
359+
}
360+
361+
impl Serialize for ContentLength {
362+
#[inline(always)]
363+
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, Error> {
364+
self.0.tls_serialize(writer)
365+
}
366+
}
367+
425368
/// Read the length of a variable-length vector.
426369
///
427370
/// This function assumes that the reader is at the start of a variable length
@@ -430,26 +373,9 @@ pub mod rw {
430373
/// The length and number of bytes read are returned.
431374
#[inline]
432375
pub fn read_length<R: std::io::Read>(bytes: &mut R) -> Result<(usize, usize), Error> {
433-
// The length is encoded in the first two bits of the first byte.
434-
let mut len_len_byte = [0u8; 1];
435-
if bytes.read(&mut len_len_byte)? == 0 {
436-
// There must be at least one byte for the length.
437-
// If we don't even have a length byte, this is not a valid
438-
// variable-length encoded vector.
439-
return Err(Error::InvalidVectorLength);
440-
}
441-
let len_len_byte = len_len_byte[0];
442-
443-
let (mut length, len_len) = calculate_length(len_len_byte)?;
444-
445-
for _ in 1..len_len {
446-
let mut next = [0u8; 1];
447-
bytes.read_exact(&mut next)?;
448-
length = (length << 8) + usize::from(next[0]);
449-
}
450-
451-
check_min_length(length, len_len)?;
452-
376+
let length = ContentLength::tls_deserialize(bytes)?;
377+
let len_len = length.0.bytes_len();
378+
let length: usize = length.0.value().try_into()?;
453379
Ok((length, len_len))
454380
}
455381

@@ -479,10 +405,7 @@ pub mod rw {
479405
writer: &mut W,
480406
content_length: usize,
481407
) -> Result<usize, Error> {
482-
let buf = super::write_variable_length(content_length)?;
483-
let buf_len = buf.len();
484-
writer.write_all(&buf)?;
485-
Ok(buf_len)
408+
ContentLength::from_usize(content_length)?.tls_serialize(writer)
486409
}
487410

488411
impl<T: Serialize + std::fmt::Debug> Serialize for Vec<T> {
@@ -538,19 +461,7 @@ mod rw_bytes {
538461
// large and write it out.
539462
let content_length = bytes.len();
540463

541-
if !cfg!(fuzzing) {
542-
debug_assert!(
543-
content_length as u64 <= MAX_LEN,
544-
"Vector can't be encoded. It's too large. {content_length} >= {MAX_LEN}",
545-
);
546-
}
547-
if content_length as u64 > MAX_LEN {
548-
return Err(Error::InvalidVectorLength);
549-
}
550-
551-
let length_bytes = write_variable_length(content_length)?;
552-
let len_len = length_bytes.len();
553-
writer.write_all(&length_bytes)?;
464+
let len_len = ContentLength::from_usize(content_length)?.tls_serialize(writer)?;
554465

555466
// Now serialize the elements
556467
writer.write_all(bytes)?;
@@ -574,24 +485,14 @@ mod rw_bytes {
574485

575486
impl Deserialize for VLBytes {
576487
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, Error> {
577-
let (length, _) = rw::read_length(bytes)?;
578-
if length == 0 {
488+
let length = ContentLength::tls_deserialize(bytes)?;
489+
490+
if length.0.value() == 0 {
579491
return Ok(Self::new(vec![]));
580492
}
581493

582-
if !cfg!(fuzzing) {
583-
debug_assert!(
584-
length <= MAX_LEN as usize,
585-
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
586-
);
587-
}
588-
if length > MAX_LEN as usize {
589-
return Err(Error::DecodingError(format!(
590-
"Trying to allocate {length} bytes. Only {MAX_LEN} allowed.",
591-
)));
592-
}
593494
let mut result = Self {
594-
vec: vec![0u8; length],
495+
vec: vec![0u8; length.0.value().try_into()?],
595496
};
596497
bytes.read_exact(result.vec.as_mut_slice())?;
597498
Ok(result)
@@ -682,7 +583,7 @@ impl<'a> Arbitrary<'a> for VLBytes {
682583
// We generate an arbitrary `Vec<u8>` ...
683584
let mut vec = Vec::arbitrary(u)?;
684585
// ... and truncate it to `MAX_LEN`.
685-
vec.truncate(MAX_LEN as usize);
586+
vec.truncate(ContentLength::MAX as usize);
686587
// We probably won't exceed `MAX_LEN` in practice, e.g., during fuzzing,
687588
// but better make sure that we generate valid instances.
688589

0 commit comments

Comments
 (0)