From 1f4fb1b0bc3776e6962b379edfbeb5327883e4d7 Mon Sep 17 00:00:00 2001 From: Mark Abspoel Date: Thu, 12 Sep 2024 18:18:09 +0200 Subject: [PATCH] Add base64 for ByteBuf when serializer is human-readable --- Cargo.toml | 1 + src/bytebuf.rs | 59 ++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 02e8748..9c6cd82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ alloc = ["serde/alloc"] [dependencies] serde = { version = "1.0.166", default-features = false } +base64 = "0.21.4" [dev-dependencies] bincode = "1.3.3" diff --git a/src/bytebuf.rs b/src/bytebuf.rs index 8dbb769..430401e 100644 --- a/src/bytebuf.rs +++ b/src/bytebuf.rs @@ -4,6 +4,9 @@ use core::fmt::{self, Debug}; use core::hash::{Hash, Hasher}; use core::ops::{Deref, DerefMut}; +pub use base64::DecodeError; +use base64::prelude::{Engine as _, BASE64_STANDARD}; + #[cfg(feature = "alloc")] use alloc::boxed::Box; #[cfg(feature = "alloc")] @@ -85,6 +88,10 @@ impl ByteBuf { pub fn into_iter(self) -> as IntoIterator>::IntoIter { self.bytes.into_iter() } + + fn decode(bytes: &[u8]) -> Result { + Ok(Self { bytes: BASE64_STANDARD.decode(bytes)? }) + } } impl Debug for ByteBuf { @@ -193,7 +200,12 @@ impl Serialize for ByteBuf { where S: Serializer, { - serializer.serialize_bytes(&self.bytes) + if serializer.is_human_readable() { + let encoded = BASE64_STANDARD.encode(&self.bytes); + String::serialize(&encoded, serializer) + } else { + serializer.serialize_bytes(&self.bytes) + } } } @@ -249,11 +261,54 @@ impl<'de> Visitor<'de> for ByteBufVisitor { } } +struct Base64Visitor; + +impl<'de> Visitor<'de> for Base64Visitor { + type Value = ByteBuf; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("base64-encoded string") + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: Error, + { + ByteBuf::decode(v).map_err(serde::de::Error::custom) + } + + fn visit_byte_buf(self, v: Vec) -> Result + where + E: Error, + { + ByteBuf::decode(v.as_slice()).map_err(serde::de::Error::custom) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + ByteBuf::decode(v.as_bytes()).map_err(serde::de::Error::custom) + } + + fn visit_string(self, v: String) -> Result + where + E: Error, + { + ByteBuf::decode(v.as_bytes()).map_err(serde::de::Error::custom) + } +} + + impl<'de> Deserialize<'de> for ByteBuf { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { - deserializer.deserialize_byte_buf(ByteBufVisitor) + if deserializer.is_human_readable() { + deserializer.deserialize_str(Base64Visitor) + } else { + deserializer.deserialize_byte_buf(ByteBufVisitor) + } } }