@@ -24,115 +24,47 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
24
24
25
25
use crate :: { DeserializeBytes , Error , SerializeBytes , Size } ;
26
26
27
- #[ cfg( not( feature = "mls" ) ) ]
28
- const MAX_LEN : u64 = ( 1 << 62 ) - 1 ;
29
- #[ cfg( not( feature = "mls" ) ) ]
30
- const MAX_LEN_LEN_LOG : usize = 3 ;
31
27
#[ cfg( feature = "mls" ) ]
32
- const MAX_LEN : u64 = ( 1 << 30 ) - 1 ;
33
- #[ cfg( feature = "mls" ) ]
34
- const MAX_LEN_LEN_LOG : usize = 2 ;
35
-
36
- #[ inline( always) ]
37
- fn check_min_length ( length : usize , len_len : usize ) -> Result < ( ) , Error > {
38
- if cfg ! ( feature = "mls" ) {
39
- // ensure that len_len is minimal for the given length
40
- let min_len_len = length_encoding_bytes ( length as u64 ) ?;
41
- if min_len_len != len_len {
42
- return Err ( Error :: InvalidVectorLength ) ;
43
- }
44
- } ;
45
- Ok ( ( ) )
46
- }
28
+ const MAX_MLS_LEN : u64 = ( 1 << 30 ) - 1 ;
47
29
48
- #[ inline( always) ]
49
- fn calculate_length ( len_len_byte : u8 ) -> Result < ( usize , usize ) , Error > {
50
- let length: usize = ( len_len_byte & 0x3F ) . into ( ) ;
51
- let len_len_log = ( len_len_byte >> 6 ) . into ( ) ;
52
- if !cfg ! ( fuzzing) {
53
- debug_assert ! ( len_len_log <= MAX_LEN_LEN_LOG ) ;
54
- }
55
- if len_len_log > MAX_LEN_LEN_LOG {
56
- return Err ( Error :: InvalidVectorLength ) ;
57
- }
58
- let len_len = match len_len_log {
59
- 0 => 1 ,
60
- 1 => 2 ,
61
- 2 => 4 ,
62
- 3 => 8 ,
63
- _ => unreachable ! ( ) ,
64
- } ;
65
- Ok ( ( length, len_len) )
66
- }
67
-
68
- #[ inline( always) ]
69
- fn read_variable_length_bytes ( bytes : & [ u8 ] ) -> Result < ( ( usize , usize ) , & [ u8 ] ) , Error > {
70
- // The length is encoded in the first two bits of the first byte.
30
+ /// Thin wrapper around [`TlsVarInt`] representing the length of encoded vector content in bytes.
31
+ ///
32
+ /// When `mls` feature is enabled, the maximum length is limited to 30-bit. Otherwise, this type is
33
+ /// no-op.
34
+ struct ContentLength ( super :: TlsVarInt ) ;
71
35
72
- let ( len_len_byte, mut remainder) = u8:: tls_deserialize_bytes ( bytes) ?;
36
+ impl ContentLength {
37
+ #[ cfg( not( feature = "mls" ) ) ]
38
+ #[ allow( dead_code) ] // used in arbitrary
39
+ const MAX : u64 = crate :: TlsVarInt :: MAX ;
73
40
74
- let ( mut length, len_len) = calculate_length ( len_len_byte) ?;
41
+ #[ cfg( feature = "mls" ) ]
42
+ const MAX : u64 = MAX_MLS_LEN ;
75
43
76
- for _ in 1 ..len_len {
77
- let ( next, next_remainder) = u8:: tls_deserialize_bytes ( remainder) ?;
78
- remainder = next_remainder;
79
- length = ( length << 8 ) + usize:: from ( next) ;
44
+ fn new ( value : super :: TlsVarInt ) -> Result < Self , Error > {
45
+ #[ cfg( feature = "mls" ) ]
46
+ if Self :: MAX < value. value ( ) {
47
+ return Err ( Error :: InvalidVectorLength ) ;
48
+ }
49
+ Ok ( Self ( value) )
80
50
}
81
51
82
- check_min_length ( length , len_len ) ? ;
83
-
84
- Ok ( ( ( length , len_len ) , remainder ) )
52
+ fn from_usize ( value : usize ) -> Result < Self , Error > {
53
+ Self :: new ( super :: TlsVarInt :: try_new ( value . try_into ( ) ? ) ? )
54
+ }
85
55
}
86
56
87
- #[ inline( always) ]
88
- fn length_encoding_bytes ( length : u64 ) -> Result < usize , Error > {
89
- if !cfg ! ( fuzzing) {
90
- debug_assert ! ( length <= MAX_LEN ) ;
91
- }
92
- if length > MAX_LEN {
93
- return Err ( Error :: InvalidVectorLength ) ;
57
+ impl Size for ContentLength {
58
+ fn tls_serialized_len ( & self ) -> usize {
59
+ self . 0 . tls_serialized_len ( )
94
60
}
95
-
96
- Ok ( if length <= 0x3f {
97
- 1
98
- } else if length <= 0x3fff {
99
- 2
100
- } else if length <= 0x3fff_ffff {
101
- 4
102
- } else {
103
- 8
104
- } )
105
61
}
106
62
107
- #[ inline( always) ]
108
- pub fn write_variable_length ( content_length : usize ) -> Result < Vec < u8 > , Error > {
109
- let len_len = length_encoding_bytes ( content_length. try_into ( ) ?) ?;
110
- if !cfg ! ( fuzzing) {
111
- debug_assert ! ( len_len <= 8 , "Invalid vector len_len {len_len}" ) ;
112
- }
113
- if len_len > 8 {
114
- return Err ( Error :: LibraryError ) ;
115
- }
116
- let mut length_bytes = vec ! [ 0u8 ; len_len] ;
117
- match len_len {
118
- 1 => length_bytes[ 0 ] = 0x00 ,
119
- 2 => length_bytes[ 0 ] = 0x40 ,
120
- 4 => length_bytes[ 0 ] = 0x80 ,
121
- 8 => length_bytes[ 0 ] = 0xc0 ,
122
- _ => {
123
- if !cfg ! ( fuzzing) {
124
- debug_assert ! ( false , "Invalid vector len_len {len_len}" ) ;
125
- }
126
- return Err ( Error :: InvalidVectorLength ) ;
127
- }
128
- }
129
- let mut len = content_length;
130
- for b in length_bytes. iter_mut ( ) . rev ( ) {
131
- * b |= ( len & 0xFF ) as u8 ;
132
- len >>= 8 ;
63
+ impl DeserializeBytes for ContentLength {
64
+ fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
65
+ let ( value, remainder) = super :: TlsVarInt :: tls_deserialize_bytes ( bytes) ?;
66
+ Ok ( ( Self ( value) , remainder) )
133
67
}
134
-
135
- Ok ( length_bytes)
136
68
}
137
69
138
70
impl < T : Size > Size for Vec < T > {
@@ -152,7 +84,9 @@ impl<T: Size> Size for &Vec<T> {
152
84
impl < T : DeserializeBytes > DeserializeBytes for Vec < T > {
153
85
#[ inline( always) ]
154
86
fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
155
- let ( ( length, len_len) , mut remainder) = read_variable_length_bytes ( bytes) ?;
87
+ let ( length, mut remainder) = ContentLength :: tls_deserialize_bytes ( bytes) ?;
88
+ let len_len = length. 0 . bytes_len ( ) ;
89
+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
156
90
157
91
if length == 0 {
158
92
// An empty vector.
@@ -178,11 +112,12 @@ impl<T: SerializeBytes> SerializeBytes for &[T] {
178
112
// This requires more computations but the other option would be to buffer
179
113
// the entire content, which can end up requiring a lot of memory.
180
114
let content_length = self . iter ( ) . fold ( 0 , |acc, e| acc + e. tls_serialized_len ( ) ) ;
181
- let mut length = write_variable_length ( content_length) ?;
182
- let len_len = length. len ( ) ;
115
+ let length = ContentLength :: from_usize ( content_length) ?;
116
+ let len_len = length. 0 . bytes_len ( ) ;
183
117
184
118
let mut out = Vec :: with_capacity ( content_length + len_len) ;
185
- out. append ( & mut length) ;
119
+ out. resize ( len_len, 0 ) ;
120
+ length. 0 . write_bytes ( & mut out) ?;
186
121
187
122
// Serialize the elements
188
123
for e in self . iter ( ) {
@@ -214,11 +149,13 @@ impl<T: Size> Size for &[T] {
214
149
#[ inline( always) ]
215
150
fn tls_serialized_len ( & self ) -> usize {
216
151
let content_length = self . iter ( ) . fold ( 0 , |acc, e| acc + e. tls_serialized_len ( ) ) ;
217
- let len_len = length_encoding_bytes ( content_length as u64 ) . unwrap_or ( {
218
- // We can't do anything about the error unless we change the trait.
219
- // Let's say there's no content for now.
220
- 0
221
- } ) ;
152
+ let len_len = ContentLength :: from_usize ( content_length)
153
+ . map ( |content_length| content_length. 0 . bytes_len ( ) )
154
+ . unwrap_or ( {
155
+ // We can't do anything about the error unless we change the trait.
156
+ // Let's say there's no content for now.
157
+ 0
158
+ } ) ;
222
159
content_length + len_len
223
160
}
224
161
}
@@ -327,10 +264,12 @@ impl From<VLBytes> for Vec<u8> {
327
264
#[ inline( always) ]
328
265
fn tls_serialize_bytes_len ( bytes : & [ u8 ] ) -> usize {
329
266
let content_length = bytes. len ( ) ;
330
- let len_len = length_encoding_bytes ( content_length as u64 ) . unwrap_or ( {
331
- // We can't do anything about the error. Let's say there's no content.
332
- 0
333
- } ) ;
267
+ let len_len = ContentLength :: from_usize ( content_length)
268
+ . map ( |content_length| content_length. 0 . bytes_len ( ) )
269
+ . unwrap_or ( {
270
+ // We can't do anything about the error. Let's say there's no content.
271
+ 0
272
+ } ) ;
334
273
content_length + len_len
335
274
}
336
275
@@ -344,22 +283,13 @@ impl Size for VLBytes {
344
283
impl DeserializeBytes for VLBytes {
345
284
#[ inline( always) ]
346
285
fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
347
- let ( ( length, _) , remainder) = read_variable_length_bytes ( bytes) ?;
286
+ let ( length, remainder) = ContentLength :: tls_deserialize_bytes ( bytes) ?;
287
+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
288
+
348
289
if length == 0 {
349
290
return Ok ( ( Self :: new ( vec ! [ ] ) , remainder) ) ;
350
291
}
351
292
352
- if !cfg ! ( fuzzing) {
353
- debug_assert ! (
354
- length <= MAX_LEN as usize ,
355
- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
356
- ) ;
357
- }
358
- if length > MAX_LEN as usize {
359
- return Err ( Error :: DecodingError ( format ! (
360
- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
361
- ) ) ) ;
362
- }
363
293
match remainder. get ( ..length) . ok_or ( Error :: EndOfStream ) {
364
294
Ok ( vec) => Ok ( ( Self { vec : vec. to_vec ( ) } , & remainder[ length..] ) ) ,
365
295
Err ( _e) => {
@@ -422,6 +352,19 @@ pub mod rw {
422
352
use super :: * ;
423
353
use crate :: { Deserialize , Serialize } ;
424
354
355
+ impl Deserialize for ContentLength {
356
+ fn tls_deserialize < R : std:: io:: Read > ( bytes : & mut R ) -> Result < Self , Error > {
357
+ ContentLength :: new ( crate :: TlsVarInt :: tls_deserialize ( bytes) ?)
358
+ }
359
+ }
360
+
361
+ impl Serialize for ContentLength {
362
+ #[ inline( always) ]
363
+ fn tls_serialize < W : std:: io:: Write > ( & self , writer : & mut W ) -> Result < usize , Error > {
364
+ self . 0 . tls_serialize ( writer)
365
+ }
366
+ }
367
+
425
368
/// Read the length of a variable-length vector.
426
369
///
427
370
/// This function assumes that the reader is at the start of a variable length
@@ -430,26 +373,9 @@ pub mod rw {
430
373
/// The length and number of bytes read are returned.
431
374
#[ inline]
432
375
pub fn read_length < R : std:: io:: Read > ( bytes : & mut R ) -> Result < ( usize , usize ) , Error > {
433
- // The length is encoded in the first two bits of the first byte.
434
- let mut len_len_byte = [ 0u8 ; 1 ] ;
435
- if bytes. read ( & mut len_len_byte) ? == 0 {
436
- // There must be at least one byte for the length.
437
- // If we don't even have a length byte, this is not a valid
438
- // variable-length encoded vector.
439
- return Err ( Error :: InvalidVectorLength ) ;
440
- }
441
- let len_len_byte = len_len_byte[ 0 ] ;
442
-
443
- let ( mut length, len_len) = calculate_length ( len_len_byte) ?;
444
-
445
- for _ in 1 ..len_len {
446
- let mut next = [ 0u8 ; 1 ] ;
447
- bytes. read_exact ( & mut next) ?;
448
- length = ( length << 8 ) + usize:: from ( next[ 0 ] ) ;
449
- }
450
-
451
- check_min_length ( length, len_len) ?;
452
-
376
+ let length = ContentLength :: tls_deserialize ( bytes) ?;
377
+ let len_len = length. 0 . bytes_len ( ) ;
378
+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
453
379
Ok ( ( length, len_len) )
454
380
}
455
381
@@ -479,10 +405,7 @@ pub mod rw {
479
405
writer : & mut W ,
480
406
content_length : usize ,
481
407
) -> Result < usize , Error > {
482
- let buf = super :: write_variable_length ( content_length) ?;
483
- let buf_len = buf. len ( ) ;
484
- writer. write_all ( & buf) ?;
485
- Ok ( buf_len)
408
+ ContentLength :: from_usize ( content_length) ?. tls_serialize ( writer)
486
409
}
487
410
488
411
impl < T : Serialize + std:: fmt:: Debug > Serialize for Vec < T > {
@@ -538,19 +461,7 @@ mod rw_bytes {
538
461
// large and write it out.
539
462
let content_length = bytes. len ( ) ;
540
463
541
- if !cfg ! ( fuzzing) {
542
- debug_assert ! (
543
- content_length as u64 <= MAX_LEN ,
544
- "Vector can't be encoded. It's too large. {content_length} >= {MAX_LEN}" ,
545
- ) ;
546
- }
547
- if content_length as u64 > MAX_LEN {
548
- return Err ( Error :: InvalidVectorLength ) ;
549
- }
550
-
551
- let length_bytes = write_variable_length ( content_length) ?;
552
- let len_len = length_bytes. len ( ) ;
553
- writer. write_all ( & length_bytes) ?;
464
+ let len_len = ContentLength :: from_usize ( content_length) ?. tls_serialize ( writer) ?;
554
465
555
466
// Now serialize the elements
556
467
writer. write_all ( bytes) ?;
@@ -574,24 +485,14 @@ mod rw_bytes {
574
485
575
486
impl Deserialize for VLBytes {
576
487
fn tls_deserialize < R : std:: io:: Read > ( bytes : & mut R ) -> Result < Self , Error > {
577
- let ( length, _) = rw:: read_length ( bytes) ?;
578
- if length == 0 {
488
+ let length = ContentLength :: tls_deserialize ( bytes) ?;
489
+
490
+ if length. 0 . value ( ) == 0 {
579
491
return Ok ( Self :: new ( vec ! [ ] ) ) ;
580
492
}
581
493
582
- if !cfg ! ( fuzzing) {
583
- debug_assert ! (
584
- length <= MAX_LEN as usize ,
585
- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
586
- ) ;
587
- }
588
- if length > MAX_LEN as usize {
589
- return Err ( Error :: DecodingError ( format ! (
590
- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
591
- ) ) ) ;
592
- }
593
494
let mut result = Self {
594
- vec : vec ! [ 0u8 ; length] ,
495
+ vec : vec ! [ 0u8 ; length. 0 . value ( ) . try_into ( ) ? ] ,
595
496
} ;
596
497
bytes. read_exact ( result. vec . as_mut_slice ( ) ) ?;
597
498
Ok ( result)
@@ -682,7 +583,7 @@ impl<'a> Arbitrary<'a> for VLBytes {
682
583
// We generate an arbitrary `Vec<u8>` ...
683
584
let mut vec = Vec :: arbitrary ( u) ?;
684
585
// ... and truncate it to `MAX_LEN`.
685
- vec. truncate ( MAX_LEN as usize ) ;
586
+ vec. truncate ( ContentLength :: MAX as usize ) ;
686
587
// We probably won't exceed `MAX_LEN` in practice, e.g., during fuzzing,
687
588
// but better make sure that we generate valid instances.
688
589
0 commit comments