@@ -24,115 +24,47 @@ use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
2424
2525use crate :: { DeserializeBytes , Error , SerializeBytes , Size } ;
2626
27- #[ cfg( not( feature = "mls" ) ) ]
28- const MAX_LEN : u64 = ( 1 << 62 ) - 1 ;
29- #[ cfg( not( feature = "mls" ) ) ]
30- const MAX_LEN_LEN_LOG : usize = 3 ;
3127#[ cfg( feature = "mls" ) ]
32- const MAX_LEN : u64 = ( 1 << 30 ) - 1 ;
33- #[ cfg( feature = "mls" ) ]
34- const MAX_LEN_LEN_LOG : usize = 2 ;
35-
36- #[ inline( always) ]
37- fn check_min_length ( length : usize , len_len : usize ) -> Result < ( ) , Error > {
38- if cfg ! ( feature = "mls" ) {
39- // ensure that len_len is minimal for the given length
40- let min_len_len = length_encoding_bytes ( length as u64 ) ?;
41- if min_len_len != len_len {
42- return Err ( Error :: InvalidVectorLength ) ;
43- }
44- } ;
45- Ok ( ( ) )
46- }
28+ const MAX_MLS_LEN : u64 = ( 1 << 30 ) - 1 ;
4729
48- #[ inline( always) ]
49- fn calculate_length ( len_len_byte : u8 ) -> Result < ( usize , usize ) , Error > {
50- let length: usize = ( len_len_byte & 0x3F ) . into ( ) ;
51- let len_len_log = ( len_len_byte >> 6 ) . into ( ) ;
52- if !cfg ! ( fuzzing) {
53- debug_assert ! ( len_len_log <= MAX_LEN_LEN_LOG ) ;
54- }
55- if len_len_log > MAX_LEN_LEN_LOG {
56- return Err ( Error :: InvalidVectorLength ) ;
57- }
58- let len_len = match len_len_log {
59- 0 => 1 ,
60- 1 => 2 ,
61- 2 => 4 ,
62- 3 => 8 ,
63- _ => unreachable ! ( ) ,
64- } ;
65- Ok ( ( length, len_len) )
66- }
67-
68- #[ inline( always) ]
69- fn read_variable_length_bytes ( bytes : & [ u8 ] ) -> Result < ( ( usize , usize ) , & [ u8 ] ) , Error > {
70- // The length is encoded in the first two bits of the first byte.
30+ /// Thin wrapper around [`TlsVarInt`] representing the length of encoded vector
31+ /// content in bytes.
32+ ///
33+ /// When `mls` feature is enabled, the maximum length is limited to 30-bit.
34+ /// Otherwise, this type is no-op.
35+ struct ContentLength ( super :: TlsVarInt ) ;
7136
72- let ( len_len_byte, mut remainder) = u8:: tls_deserialize_bytes ( bytes) ?;
37+ impl ContentLength {
38+ #[ cfg( all( not( feature = "mls" ) , feature = "arbitrary" ) ) ]
39+ const MAX : u64 = crate :: TlsVarInt :: MAX ;
7340
74- let ( mut length, len_len) = calculate_length ( len_len_byte) ?;
41+ #[ cfg( feature = "mls" ) ]
42+ const MAX : u64 = MAX_MLS_LEN ;
7543
76- for _ in 1 ..len_len {
77- let ( next, next_remainder) = u8:: tls_deserialize_bytes ( remainder) ?;
78- remainder = next_remainder;
79- length = ( length << 8 ) + usize:: from ( next) ;
44+ fn new ( value : super :: TlsVarInt ) -> Result < Self , Error > {
45+ #[ cfg( feature = "mls" ) ]
46+ if Self :: MAX < value. value ( ) {
47+ return Err ( Error :: InvalidVectorLength ) ;
48+ }
49+ Ok ( Self ( value) )
8050 }
8151
82- check_min_length ( length , len_len ) ? ;
83-
84- Ok ( ( ( length , len_len ) , remainder ) )
52+ fn from_usize ( value : usize ) -> Result < Self , Error > {
53+ Self :: new ( super :: TlsVarInt :: try_new ( value . try_into ( ) ? ) ? )
54+ }
8555}
8656
87- #[ inline( always) ]
88- fn length_encoding_bytes ( length : u64 ) -> Result < usize , Error > {
89- if !cfg ! ( fuzzing) {
90- debug_assert ! ( length <= MAX_LEN ) ;
91- }
92- if length > MAX_LEN {
93- return Err ( Error :: InvalidVectorLength ) ;
57+ impl Size for ContentLength {
58+ fn tls_serialized_len ( & self ) -> usize {
59+ self . 0 . tls_serialized_len ( )
9460 }
95-
96- Ok ( if length <= 0x3f {
97- 1
98- } else if length <= 0x3fff {
99- 2
100- } else if length <= 0x3fff_ffff {
101- 4
102- } else {
103- 8
104- } )
10561}
10662
107- #[ inline( always) ]
108- pub fn write_variable_length ( content_length : usize ) -> Result < Vec < u8 > , Error > {
109- let len_len = length_encoding_bytes ( content_length. try_into ( ) ?) ?;
110- if !cfg ! ( fuzzing) {
111- debug_assert ! ( len_len <= 8 , "Invalid vector len_len {len_len}" ) ;
112- }
113- if len_len > 8 {
114- return Err ( Error :: LibraryError ) ;
115- }
116- let mut length_bytes = vec ! [ 0u8 ; len_len] ;
117- match len_len {
118- 1 => length_bytes[ 0 ] = 0x00 ,
119- 2 => length_bytes[ 0 ] = 0x40 ,
120- 4 => length_bytes[ 0 ] = 0x80 ,
121- 8 => length_bytes[ 0 ] = 0xc0 ,
122- _ => {
123- if !cfg ! ( fuzzing) {
124- debug_assert ! ( false , "Invalid vector len_len {len_len}" ) ;
125- }
126- return Err ( Error :: InvalidVectorLength ) ;
127- }
128- }
129- let mut len = content_length;
130- for b in length_bytes. iter_mut ( ) . rev ( ) {
131- * b |= ( len & 0xFF ) as u8 ;
132- len >>= 8 ;
63+ impl DeserializeBytes for ContentLength {
64+ fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
65+ let ( value, remainder) = super :: TlsVarInt :: tls_deserialize_bytes ( bytes) ?;
66+ Ok ( ( Self ( value) , remainder) )
13367 }
134-
135- Ok ( length_bytes)
13668}
13769
13870impl < T : Size > Size for Vec < T > {
@@ -152,7 +84,9 @@ impl<T: Size> Size for &Vec<T> {
15284impl < T : DeserializeBytes > DeserializeBytes for Vec < T > {
15385 #[ inline( always) ]
15486 fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
155- let ( ( length, len_len) , mut remainder) = read_variable_length_bytes ( bytes) ?;
87+ let ( length, mut remainder) = ContentLength :: tls_deserialize_bytes ( bytes) ?;
88+ let len_len = length. 0 . bytes_len ( ) ;
89+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
15690
15791 if length == 0 {
15892 // An empty vector.
@@ -178,11 +112,12 @@ impl<T: SerializeBytes> SerializeBytes for &[T] {
178112 // This requires more computations but the other option would be to buffer
179113 // the entire content, which can end up requiring a lot of memory.
180114 let content_length = self . iter ( ) . fold ( 0 , |acc, e| acc + e. tls_serialized_len ( ) ) ;
181- let mut length = write_variable_length ( content_length) ?;
182- let len_len = length. len ( ) ;
115+ let length = ContentLength :: from_usize ( content_length) ?;
116+ let len_len = length. 0 . bytes_len ( ) ;
183117
184118 let mut out = Vec :: with_capacity ( content_length + len_len) ;
185- out. append ( & mut length) ;
119+ out. resize ( len_len, 0 ) ;
120+ length. 0 . write_bytes ( & mut out) ?;
186121
187122 // Serialize the elements
188123 for e in self . iter ( ) {
@@ -214,11 +149,13 @@ impl<T: Size> Size for &[T] {
214149 #[ inline( always) ]
215150 fn tls_serialized_len ( & self ) -> usize {
216151 let content_length = self . iter ( ) . fold ( 0 , |acc, e| acc + e. tls_serialized_len ( ) ) ;
217- let len_len = length_encoding_bytes ( content_length as u64 ) . unwrap_or ( {
218- // We can't do anything about the error unless we change the trait.
219- // Let's say there's no content for now.
220- 0
221- } ) ;
152+ let len_len = ContentLength :: from_usize ( content_length)
153+ . map ( |content_length| content_length. 0 . bytes_len ( ) )
154+ . unwrap_or ( {
155+ // We can't do anything about the error unless we change the
156+ // trait. Let's say there's no content for now.
157+ 0
158+ } ) ;
222159 content_length + len_len
223160 }
224161}
@@ -332,10 +269,13 @@ impl From<VLBytes> for Vec<u8> {
332269#[ inline( always) ]
333270fn tls_serialize_bytes_len ( bytes : & [ u8 ] ) -> usize {
334271 let content_length = bytes. len ( ) ;
335- let len_len = length_encoding_bytes ( content_length as u64 ) . unwrap_or ( {
336- // We can't do anything about the error. Let's say there's no content.
337- 0
338- } ) ;
272+ let len_len = ContentLength :: from_usize ( content_length)
273+ . map ( |content_length| content_length. 0 . bytes_len ( ) )
274+ . unwrap_or ( {
275+ // We can't do anything about the error. Let's say there's no
276+ // content.
277+ 0
278+ } ) ;
339279 content_length + len_len
340280}
341281
@@ -349,22 +289,13 @@ impl Size for VLBytes {
349289impl DeserializeBytes for VLBytes {
350290 #[ inline( always) ]
351291 fn tls_deserialize_bytes ( bytes : & [ u8 ] ) -> Result < ( Self , & [ u8 ] ) , Error > {
352- let ( ( length, _) , remainder) = read_variable_length_bytes ( bytes) ?;
292+ let ( length, remainder) = ContentLength :: tls_deserialize_bytes ( bytes) ?;
293+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
294+
353295 if length == 0 {
354296 return Ok ( ( Self :: new ( vec ! [ ] ) , remainder) ) ;
355297 }
356298
357- if !cfg ! ( fuzzing) {
358- debug_assert ! (
359- length <= MAX_LEN as usize ,
360- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
361- ) ;
362- }
363- if length > MAX_LEN as usize {
364- return Err ( Error :: DecodingError ( format ! (
365- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
366- ) ) ) ;
367- }
368299 match remainder. get ( ..length) . ok_or ( Error :: EndOfStream ) {
369300 Ok ( vec) => Ok ( ( Self { vec : vec. to_vec ( ) } , & remainder[ length..] ) ) ,
370301 Err ( _e) => {
@@ -477,6 +408,20 @@ pub mod rw {
477408 use super :: * ;
478409 use crate :: { Deserialize , Serialize } ;
479410
411+ impl Deserialize for ContentLength {
412+ #[ inline( always) ]
413+ fn tls_deserialize < R : std:: io:: Read > ( bytes : & mut R ) -> Result < Self , Error > {
414+ ContentLength :: new ( crate :: TlsVarInt :: tls_deserialize ( bytes) ?)
415+ }
416+ }
417+
418+ impl Serialize for ContentLength {
419+ #[ inline( always) ]
420+ fn tls_serialize < W : std:: io:: Write > ( & self , writer : & mut W ) -> Result < usize , Error > {
421+ self . 0 . tls_serialize ( writer)
422+ }
423+ }
424+
480425 /// Read the length of a variable-length vector.
481426 ///
482427 /// This function assumes that the reader is at the start of a variable length
@@ -485,26 +430,9 @@ pub mod rw {
485430 /// The length and number of bytes read are returned.
486431 #[ inline]
487432 pub fn read_length < R : std:: io:: Read > ( bytes : & mut R ) -> Result < ( usize , usize ) , Error > {
488- // The length is encoded in the first two bits of the first byte.
489- let mut len_len_byte = [ 0u8 ; 1 ] ;
490- if bytes. read ( & mut len_len_byte) ? == 0 {
491- // There must be at least one byte for the length.
492- // If we don't even have a length byte, this is not a valid
493- // variable-length encoded vector.
494- return Err ( Error :: InvalidVectorLength ) ;
495- }
496- let len_len_byte = len_len_byte[ 0 ] ;
497-
498- let ( mut length, len_len) = calculate_length ( len_len_byte) ?;
499-
500- for _ in 1 ..len_len {
501- let mut next = [ 0u8 ; 1 ] ;
502- bytes. read_exact ( & mut next) ?;
503- length = ( length << 8 ) + usize:: from ( next[ 0 ] ) ;
504- }
505-
506- check_min_length ( length, len_len) ?;
507-
433+ let length = ContentLength :: tls_deserialize ( bytes) ?;
434+ let len_len = length. 0 . bytes_len ( ) ;
435+ let length: usize = length. 0 . value ( ) . try_into ( ) ?;
508436 Ok ( ( length, len_len) )
509437 }
510438
@@ -534,10 +462,7 @@ pub mod rw {
534462 writer : & mut W ,
535463 content_length : usize ,
536464 ) -> Result < usize , Error > {
537- let buf = super :: write_variable_length ( content_length) ?;
538- let buf_len = buf. len ( ) ;
539- writer. write_all ( & buf) ?;
540- Ok ( buf_len)
465+ ContentLength :: from_usize ( content_length) ?. tls_serialize ( writer)
541466 }
542467
543468 impl < T : Serialize + std:: fmt:: Debug > Serialize for Vec < T > {
@@ -593,19 +518,7 @@ mod rw_bytes {
593518 // large and write it out.
594519 let content_length = bytes. len ( ) ;
595520
596- if !cfg ! ( fuzzing) {
597- debug_assert ! (
598- content_length as u64 <= MAX_LEN ,
599- "Vector can't be encoded. It's too large. {content_length} >= {MAX_LEN}" ,
600- ) ;
601- }
602- if content_length as u64 > MAX_LEN {
603- return Err ( Error :: InvalidVectorLength ) ;
604- }
605-
606- let length_bytes = write_variable_length ( content_length) ?;
607- let len_len = length_bytes. len ( ) ;
608- writer. write_all ( & length_bytes) ?;
521+ let len_len = ContentLength :: from_usize ( content_length) ?. tls_serialize ( writer) ?;
609522
610523 // Now serialize the elements
611524 writer. write_all ( bytes) ?;
@@ -629,24 +542,14 @@ mod rw_bytes {
629542
630543 impl Deserialize for VLBytes {
631544 fn tls_deserialize < R : std:: io:: Read > ( bytes : & mut R ) -> Result < Self , Error > {
632- let ( length, _) = rw:: read_length ( bytes) ?;
633- if length == 0 {
545+ let length = ContentLength :: tls_deserialize ( bytes) ?;
546+
547+ if length. 0 . value ( ) == 0 {
634548 return Ok ( Self :: new ( vec ! [ ] ) ) ;
635549 }
636550
637- if !cfg ! ( fuzzing) {
638- debug_assert ! (
639- length <= MAX_LEN as usize ,
640- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
641- ) ;
642- }
643- if length > MAX_LEN as usize {
644- return Err ( Error :: DecodingError ( format ! (
645- "Trying to allocate {length} bytes. Only {MAX_LEN} allowed." ,
646- ) ) ) ;
647- }
648551 let mut result = Self {
649- vec : vec ! [ 0u8 ; length] ,
552+ vec : vec ! [ 0u8 ; length. 0 . value ( ) . try_into ( ) ? ] ,
650553 } ;
651554 bytes. read_exact ( result. vec . as_mut_slice ( ) ) ?;
652555 Ok ( result)
@@ -737,7 +640,7 @@ impl<'a> Arbitrary<'a> for VLBytes {
737640 // We generate an arbitrary `Vec<u8>` ...
738641 let mut vec = Vec :: arbitrary ( u) ?;
739642 // ... and truncate it to `MAX_LEN`.
740- vec. truncate ( MAX_LEN as usize ) ;
643+ vec. truncate ( ContentLength :: MAX as usize ) ;
741644 // We probably won't exceed `MAX_LEN` in practice, e.g., during fuzzing,
742645 // but better make sure that we generate valid instances.
743646
0 commit comments