From c1cdd11904136a763e6258a2e89a8f618d9e9618 Mon Sep 17 00:00:00 2001
From: ben avrahami
Date: Mon, 10 Feb 2025 20:20:29 +0200
Subject: [PATCH] allow re-reading from deserialized that failed due to type mismatch

---
 rmp-serde/src/decode.rs    | 107 +++++++++++++++++++++++++++++++++++--
 rmpv-tests/tests/decode.rs |  58 +++++++++++++++++++-
 2 files changed, 160 insertions(+), 5 deletions(-)

diff --git a/rmp-serde/src/decode.rs b/rmp-serde/src/decode.rs
index 4f3bee3..864dcb0 100644
--- a/rmp-serde/src/decode.rs
+++ b/rmp-serde/src/decode.rs
@@ -332,7 +332,10 @@ fn read_i128_marker<'de, R: ReadSlice<'de>>(marker: Marker, rd: &mut R) -> Resul
             read_128_buf(rd, len)?
         },
         Marker::FixArray(len) => read_128_buf(rd, len)?,
-        marker => return Err(Error::TypeMismatch(marker)),
+        marker => {
+            consume_unexpected_value(rd, marker)?;
+            return Err(Error::TypeMismatch(marker))
+        },
     })
 }
 
@@ -396,7 +399,7 @@ fn read_u32<R: Read>(rd: &mut R) -> Result<u32, Error> {
         .map_err(Error::InvalidDataRead)
 }
 
-fn ext_len<R: Read>(rd: &mut R, marker: Marker) -> Result<u32, Error> {
+fn ext_len<'de, R: Read + ReadSlice<'de>>(rd: &mut R, marker: Marker) -> Result<u32, Error> {
     Ok(match marker {
         Marker::FixExt1 => 1,
         Marker::FixExt2 => 2,
@@ -406,7 +409,10 @@ fn ext_len<R: Read>(rd: &mut R, marker: Marker) -> Result<u32, Error> {
         Marker::Ext8 => u32::from(read_u8(rd)?),
         Marker::Ext16 => u32::from(read_u16(rd)?),
         Marker::Ext32 => read_u32(rd)?,
-        _ => return Err(Error::TypeMismatch(marker)),
+        _ => {
+            consume_unexpected_value(rd, marker)?;
+            return Err(Error::TypeMismatch(marker))
+        }
     })
 }
 
@@ -521,10 +527,103 @@ fn any_num<'de, R: ReadSlice<'de>, V: Visitor<'de>>(rd: &mut R, visitor: V, mark
         Marker::I64 => visitor.visit_i64(rd.read_data_i64()?),
         Marker::F32 => visitor.visit_f32(rd.read_data_f32()?),
         Marker::F64 => visitor.visit_f64(rd.read_data_f64()?),
-        other_marker => Err(Error::TypeMismatch(other_marker)),
+        other_marker => {
+            consume_unexpected_value(rd, other_marker)?;
+            Err(Error::TypeMismatch(other_marker))
+        },
     }
 }
 
+/// Skips `count` complete MessagePack values, markers included.
+///
+/// Nested arrays and maps are tracked with a counter of outstanding values rather
+/// than by recursion, so stack usage stays constant however deeply the input
+/// nests its containers.
+fn consume_unexpected_values<'de, R: ReadSlice<'de>>(rd: &mut R, count: u64) -> Result<(), Error> {
+    let mut pending = count;
+    while pending > 0 {
+        let marker = rmp::decode::read_marker(rd)?;
+        // The value just started is accounted for; the children it announces are
+        // added. Saturating is enough: the input runs out long before a saturated
+        // counter could be worked off.
+        pending = (pending - 1).saturating_add(consume_unexpected_payload(rd, marker)?);
+    }
+    Ok(())
+}
+
+/// Skips the rest of a value whose `marker` has been read but whose type the
+/// caller did not expect, so that the next read starts at the following value.
+///
+/// The only errors expected to arise from here are invalid data reads, which the
+/// decoder is generally unable to recover from.
+fn consume_unexpected_value<'de, R: ReadSlice<'de>>(rd: &mut R, marker: Marker) -> Result<(), Error> {
+    let nested = consume_unexpected_payload(rd, marker)?;
+    consume_unexpected_values(rd, nested)
+}
+
+/// Skips `len` raw bytes of string, binary or extension payload.
+fn consume_unexpected_bytes<'de, R: ReadSlice<'de>>(rd: &mut R, len: usize) -> Result<(), Error> {
+    rd.read_slice(len).map_err(Error::InvalidDataRead)?;
+    Ok(())
+}
+
+/// Skips everything that belongs directly to `marker` (length prefix, scalar,
+/// string, binary or extension payload) and returns how many nested values still
+/// follow it: the element count for arrays, twice the entry count for maps, and
+/// zero for everything else.
+fn consume_unexpected_payload<'de, R: ReadSlice<'de>>(rd: &mut R, marker: Marker) -> Result<u64, Error> {
+    Ok(match marker {
+        Marker::Null | Marker::True | Marker::False | Marker::Reserved => 0,
+        Marker::FixPos(_) | Marker::FixNeg(_) => 0,
+        Marker::U8 => { rd.read_data_u8()?; 0 }
+        Marker::U16 => { rd.read_data_u16()?; 0 }
+        Marker::U32 => { rd.read_data_u32()?; 0 }
+        Marker::U64 => { rd.read_data_u64()?; 0 }
+        Marker::I8 => { rd.read_data_i8()?; 0 }
+        Marker::I16 => { rd.read_data_i16()?; 0 }
+        Marker::I32 => { rd.read_data_i32()?; 0 }
+        Marker::I64 => { rd.read_data_i64()?; 0 }
+        Marker::F32 => { rd.read_data_f32()?; 0 }
+        Marker::F64 => { rd.read_data_f64()?; 0 }
+        Marker::FixStr(len) => {
+            consume_unexpected_bytes(rd, usize::from(len))?;
+            0
+        }
+        Marker::Str8 | Marker::Bin8 => {
+            let len = rd.read_data_u8()?;
+            consume_unexpected_bytes(rd, usize::from(len))?;
+            0
+        }
+        Marker::Str16 | Marker::Bin16 => {
+            let len = rd.read_data_u16()?;
+            consume_unexpected_bytes(rd, usize::from(len))?;
+            0
+        }
+        Marker::Str32 | Marker::Bin32 => {
+            let len = rd.read_data_u32()?;
+            consume_unexpected_bytes(rd, len as usize)?;
+            0
+        }
+        Marker::FixArray(len) => u64::from(len),
+        Marker::Array16 => u64::from(rd.read_data_u16()?),
+        Marker::Array32 => u64::from(rd.read_data_u32()?),
+        // Maps hold a key and a value per entry; doing the doubling in `u64` keeps
+        // a 32-bit entry count from overflowing.
+        Marker::FixMap(len) => 2 * u64::from(len),
+        Marker::Map16 => 2 * u64::from(rd.read_data_u16()?),
+        Marker::Map32 => 2 * u64::from(rd.read_data_u32()?),
+        Marker::FixExt1 | Marker::FixExt2 | Marker::FixExt4 | Marker::FixExt8 | Marker::FixExt16 |
+        Marker::Ext8 | Marker::Ext16 | Marker::Ext32 => {
+            let len = ext_len(rd, marker)?;
+            // The one-byte extension type sits between the length and the payload;
+            // reading it separately avoids computing `len + 1`.
+            rd.read_data_i8()?;
+            consume_unexpected_bytes(rd, len as usize)?;
+            0
+        }
+    })
+}
+
 impl<'de, R: ReadSlice<'de>, C: SerializerConfig> Deserializer<R, C> {
     fn any_inner<V: Visitor<'de>>(&mut self, visitor: V, allow_bytes: bool) -> Result<V::Value, Error> {
         let marker = self.take_or_read_marker()?;
diff --git a/rmpv-tests/tests/decode.rs b/rmpv-tests/tests/decode.rs
index 8289b3e..00a6739 100644
--- a/rmpv-tests/tests/decode.rs
+++ b/rmpv-tests/tests/decode.rs
@@ -1,6 +1,6 @@
 use serde::Deserialize;
 use serde_bytes::ByteBuf;
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 
 use rmpv::decode;
 use rmpv::ext::from_value;
@@ -311,3 +311,59 @@ fn pass_tuple_struct_from_ext() {
         from_value(Value::Ext(42, vec![255])).unwrap()
     );
 }
+
+#[derive(Debug, PartialEq)]
+enum MightFail<T> {
+    Ok(T),
+    Failed,
+}
+
+impl<'de, T: serde::de::Deserialize<'de>> serde::de::Deserialize<'de> for MightFail<T> {
+    fn deserialize<D>(deserializer: D) -> Result<MightFail<T>, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        match T::deserialize(deserializer) {
+            Ok(v) => Ok(MightFail::Ok(v)),
+            Err(_) => Ok(MightFail::Failed),
+        }
+    }
+}
+
+#[test]
+fn pass_failing_elements() {
+    let buffer = rmp_serde::to_vec(&(42,
+        41,
+        "hi there",
+        43,
+        (1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16),
+        4.52,
+        4,
+        1u64 << 63,
+        "hi", // test fixed string
+        65,
+        (1,2,3), // test fixed arrays
+        &[0xcc, 0x80][..], // test bin
+        HashMap::from([("a", 1), ("b", 2), ("c", 3)]), // test fixed map
+        HashMap::from([("a", 1), ("b", 2), ("c", 3), ("d", 4), ("e", 5), ("f", 6), ("g", 7), ("h", 8), ("i", 9), ("j", 10), ("k", 11), ("l", 12), ("m", 13), ("n", 14), ("o", 15), ("p", 16)]), // test map
+        66,
+    )).unwrap();
+    let deserialized: Vec<MightFail<i32>> = rmp_serde::from_slice(&buffer).unwrap();
+    assert_eq!(deserialized, vec![
+        MightFail::Ok(42),
+        MightFail::Ok(41),
+        MightFail::Failed,
+        MightFail::Ok(43),
+        MightFail::Failed,
+        MightFail::Failed,
+        MightFail::Ok(4),
+        MightFail::Failed,
+        MightFail::Failed,
+        MightFail::Ok(65),
+        MightFail::Failed,
+        MightFail::Failed,
+        MightFail::Failed,
+        MightFail::Failed,
+        MightFail::Ok(66),
+    ]);
+}