From acb3d7a6b1a71063dddd545506192a56f0fcdf56 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 8 Nov 2025 02:58:29 +0000 Subject: [PATCH] feat: Add support for Arc as a string type This is modeled on the support for multiple types for bytes. --- prost-build/src/code_generator.rs | 24 +++- prost-build/src/collections.rs | 21 +++ prost-build/src/config.rs | 62 +++++++++ prost-build/src/context.rs | 11 +- prost-build/src/lib.rs | 2 +- prost-derive/src/field/map.rs | 2 +- prost-derive/src/field/scalar.rs | 81 ++++++++++-- prost/src/encoding.rs | 206 +++++++++++++++++++++--------- tests/src/message_encoding.rs | 44 ++++++- 9 files changed, 373 insertions(+), 80 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 1f22acbe2..1d7002e0a 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -15,6 +15,7 @@ use prost_types::{ }; use crate::ast::{Comments, Method, Service}; +use crate::collections::StringType; use crate::context::Context; use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel}; use crate::Config; @@ -451,6 +452,17 @@ impl<'b> CodeGenerator<'_, 'b> { .push_str(&format!(" = {:?}", bytes_type.annotation())); } + if type_ == Type::String { + let string_type = self + .context + .string_type(fq_message_name, field.descriptor.name()); + // Only emit the annotation if it's not the default type + if string_type != StringType::String { + self.buf + .push_str(&format!(" = {:?}", string_type.annotation())); + } + } + match field.descriptor.label() { Label::Optional => { if optional { @@ -982,7 +994,17 @@ impl<'b> CodeGenerator<'_, 'b> { Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), Type::Bool => String::from("bool"), - Type::String => format!("{}::alloc::string::String", self.context.prost_path()), + Type::String => { + let string_type = self.context.string_type(fq_message_name, field.name()); + match string_type { + StringType::String => { + format!("{}::alloc::string::String", self.context.prost_path()) + } + StringType::ArcStr => { + format!("{}::alloc::sync::Arc", self.context.prost_path()) + } + } + } Type::Bytes => self .context .bytes_type(fq_message_name, field.name()) diff --git a/prost-build/src/collections.rs b/prost-build/src/collections.rs index e55625cf3..f7a833702 100644 --- a/prost-build/src/collections.rs +++ b/prost-build/src/collections.rs @@ -55,3 +55,24 @@ impl BytesType { } } } + +/// The string type to output for Protobuf `string` fields. +#[non_exhaustive] +#[derive(Default, Clone, Copy, Debug, PartialEq)] +pub(crate) enum StringType { + /// The [`prost::alloc::string::String`] type. + #[default] + String, + /// The [`std::sync::Arc`] type. + ArcStr, +} + +impl StringType { + /// The `prost-derive` annotation type corresponding to the string type. + pub fn annotation(&self) -> &'static str { + match self { + StringType::String => "string", + StringType::ArcStr => "arc_str", + } + } +} diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 8deba5c7a..e75785222 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -23,6 +23,7 @@ use crate::BytesType; use crate::MapType; use crate::Module; use crate::ServiceGenerator; +use crate::StringType; /// Configuration options for Protobuf code generation. /// @@ -32,6 +33,7 @@ pub struct Config { pub(crate) service_generator: Option>, pub(crate) map_type: PathMap, pub(crate) bytes_type: PathMap, + pub(crate) string_type: PathMap, pub(crate) type_attributes: PathMap, pub(crate) message_attributes: PathMap, pub(crate) enum_attributes: PathMap, @@ -184,6 +186,65 @@ impl Config { self } + /// Configure the code generator to generate Rust [`Arc`](std::sync::Arc) fields for Protobuf + /// [`string`][2] type fields. + /// + /// # Arguments + /// + /// **`paths`** - paths to specific fields, messages, or packages which should use a Rust + /// `Arc` for Protobuf `string` fields. Paths are specified in terms of the Protobuf type + /// name (not the generated Rust type name). Paths with a leading `.` are treated as fully + /// qualified names. Paths without a leading `.` are treated as relative, and are suffix + /// matched on the fully qualified field name. If a Protobuf string field matches any of the + /// paths, a Rust `Arc` field is generated instead of the default [`String`]. + /// + /// The matching is done on the Protobuf names, before converting to Rust-friendly casing + /// standards. + /// + /// # Examples + /// + /// ```rust + /// # let mut config = prost_build::Config::new(); + /// // Match a specific field in a message type. + /// config.arc_str(&[".my_messages.MyMessageType.my_string_field"]); + /// + /// // Match all string fields in a message type. + /// config.arc_str(&[".my_messages.MyMessageType"]); + /// + /// // Match all string fields in a package. + /// config.arc_str(&[".my_messages"]); + /// + /// // Match all string fields. + /// config.arc_str(&["."]); + /// + /// // Match all string fields in a nested message. + /// config.arc_str(&[".my_messages.MyMessageType.MyNestedMessageType"]); + /// + /// // Match all fields named 'my_string_field'. + /// config.arc_str(&["my_string_field"]); + /// + /// // Match all fields named 'my_string_field' in messages named 'MyMessageType', regardless of + /// // package or nesting. + /// config.arc_str(&["MyMessageType.my_string_field"]); + /// + /// // Match all fields named 'my_string_field', and all fields in the 'foo.bar' package. + /// config.arc_str(&["my_string_field", ".foo.bar"]); + /// ``` + /// + /// [2]: https://protobuf.dev/programming-guides/proto3/#scalar + pub fn arc_str(&mut self, paths: I) -> &mut Self + where + I: IntoIterator, + S: AsRef, + { + self.string_type.clear(); + for matcher in paths { + self.string_type + .insert(matcher.as_ref().to_string(), StringType::ArcStr); + } + self + } + /// Add additional attribute to matched fields. /// /// # Arguments @@ -1198,6 +1259,7 @@ impl default::Default for Config { service_generator: None, map_type: PathMap::default(), bytes_type: PathMap::default(), + string_type: PathMap::default(), type_attributes: PathMap::default(), message_attributes: PathMap::default(), enum_attributes: PathMap::default(), diff --git a/prost-build/src/context.rs b/prost-build/src/context.rs index 7fecc8dfc..38e01dccd 100644 --- a/prost-build/src/context.rs +++ b/prost-build/src/context.rs @@ -7,7 +7,7 @@ use prost_types::{ use crate::extern_paths::ExternPaths; use crate::message_graph::MessageGraph; -use crate::{BytesType, Config, MapType, ServiceGenerator}; +use crate::{BytesType, Config, MapType, ServiceGenerator, StringType}; /// The context providing all the global information needed to generate code. /// It also provides a more disciplined access to Config @@ -110,6 +110,15 @@ impl<'a> Context<'a> { .unwrap_or_default() } + /// Returns the string type configured for the named message field. + pub(crate) fn string_type(&self, fq_message_name: &str, field_name: &str) -> StringType { + self.config + .string_type + .get_first_field(fq_message_name, field_name) + .copied() + .unwrap_or_default() + } + /// Returns the map type configured for the named message field. pub(crate) fn map_type(&self, fq_message_name: &str, field_name: &str) -> MapType { self.config diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 599630a0f..ac1c2f131 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -143,7 +143,7 @@ mod ast; pub use crate::ast::{Comments, Method, Service}; mod collections; -pub(crate) use collections::{BytesType, MapType}; +pub(crate) use collections::{BytesType, MapType, StringType}; mod code_generator; mod context; diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index c5f36e23c..56f25467f 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -366,7 +366,7 @@ fn key_ty_from_str(s: &str) -> Result { | scalar::Ty::Sfixed32 | scalar::Ty::Sfixed64 | scalar::Ty::Bool - | scalar::Ty::String => Ok(ty), + | scalar::Ty::String(..) => Ok(ty), _ => bail!("invalid map key type: {s}"), } } diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index 25596cbc0..95141523d 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -118,8 +118,13 @@ impl Field { match self.kind { Kind::Plain(ref default) => { let default = default.typed(); + let comparison = if matches!(self.ty, Ty::String(StringTy::ArcStr)) { + quote!(#ident.as_ref() != #default) + } else { + quote!(#ident != #default) + }; quote! { - if #ident != #default { + if #comparison { #encode_fn(#tag, &#ident, buf); } } @@ -172,8 +177,13 @@ impl Field { match self.kind { Kind::Plain(ref default) => { let default = default.typed(); + let comparison = if matches!(self.ty, Ty::String(StringTy::ArcStr)) { + quote!(#ident.as_ref() != #default) + } else { + quote!(#ident != #default) + }; quote! { - if #ident != #default { + if #comparison { #encoded_len_fn(#tag, &#ident) } else { 0 @@ -194,7 +204,10 @@ impl Field { Kind::Plain(ref default) | Kind::Required(ref default) => { let default = default.typed(); match self.ty { - Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), + Ty::String(StringTy::ArcStr) => { + quote!(#ident = ::core::default::Default::default()) + } + Ty::String(..) | Ty::Bytes(..) => quote!(#ident.clear()), _ => quote!(#ident = #default), } } @@ -206,7 +219,17 @@ impl Field { /// Returns an expression which evaluates to the default value of the field. pub fn default(&self, prost_path: &Path) -> TokenStream { match self.kind { - Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(prost_path), + Kind::Plain(ref value) | Kind::Required(ref value) => { + // Special handling for Arc default value + if matches!(self.ty, Ty::String(StringTy::ArcStr)) { + if matches!(value, DefaultValue::String(s) if s.is_empty()) { + return quote!(#prost_path::alloc::sync::Arc::from("")); + } else if let DefaultValue::String(s) = value { + return quote!(#prost_path::alloc::sync::Arc::from(#s)); + } + } + value.owned(prost_path) + } Kind::Optional(_) => quote!(::core::option::Option::None), Kind::Repeated | Kind::Packed => quote!(#prost_path::alloc::vec::Vec::new()), } @@ -393,7 +416,7 @@ pub enum Ty { Sfixed32, Sfixed64, Bool, - String, + String(StringTy), Bytes(BytesTy), Enumeration(Path), } @@ -404,6 +427,12 @@ pub enum BytesTy { Bytes, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum StringTy { + String, + ArcStr, +} + impl BytesTy { fn try_from_str(s: &str) -> Result { match s { @@ -421,6 +450,23 @@ impl BytesTy { } } +impl StringTy { + fn try_from_str(s: &str) -> Result { + match s { + "string" => Ok(StringTy::String), + "arc_str" => Ok(StringTy::ArcStr), + _ => bail!("Invalid string type: {}", s), + } + } + + fn rust_type(&self, prost_path: &Path) -> TokenStream { + match self { + StringTy::String => quote! { #prost_path::alloc::string::String }, + StringTy::ArcStr => quote! { #prost_path::alloc::sync::Arc }, + } + } +} + impl Ty { pub fn from_attr(attr: &Meta) -> Result, Error> { let ty = match *attr { @@ -437,8 +483,17 @@ impl Ty { Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32, Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, - Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("string") => Ty::String(StringTy::String), Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), + Meta::NameValue(MetaNameValue { + ref path, + value: + Expr::Lit(ExprLit { + lit: Lit::Str(ref l), + .. + }), + .. + }) if path.is_ident("string") => Ty::String(StringTy::try_from_str(&l.value())?), Meta::NameValue(MetaNameValue { ref path, value: @@ -482,7 +537,7 @@ impl Ty { "sfixed32" => Ty::Sfixed32, "sfixed64" => Ty::Sfixed64, "bool" => Ty::Bool, - "string" => Ty::String, + "string" => Ty::String(StringTy::String), "bytes" => Ty::Bytes(BytesTy::Vec), s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { let s = &s[enumeration_len..].trim(); @@ -518,7 +573,7 @@ impl Ty { Ty::Sfixed32 => "sfixed32", Ty::Sfixed64 => "sfixed64", Ty::Bool => "bool", - Ty::String => "string", + Ty::String(..) => "string", Ty::Bytes(..) => "bytes", Ty::Enumeration(..) => "enum", } @@ -527,7 +582,7 @@ impl Ty { // TODO: rename to 'owned_type'. pub fn rust_type(&self, prost_path: &Path) -> TokenStream { match self { - Ty::String => quote!(#prost_path::alloc::string::String), + Ty::String(ty) => ty.rust_type(prost_path), Ty::Bytes(ty) => ty.rust_type(prost_path), _ => self.rust_ref_type(), } @@ -549,7 +604,7 @@ impl Ty { Ty::Sfixed32 => quote!(i32), Ty::Sfixed64 => quote!(i64), Ty::Bool => quote!(bool), - Ty::String => quote!(&str), + Ty::String(..) => quote!(&str), Ty::Bytes(..) => quote!(&[u8]), Ty::Enumeration(..) => quote!(i32), } @@ -564,7 +619,7 @@ impl Ty { /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). pub fn is_numeric(&self) -> bool { - !matches!(self, Ty::String | Ty::Bytes(..)) + !matches!(self, Ty::String(..) | Ty::Bytes(..)) } } @@ -660,7 +715,7 @@ impl DefaultValue { Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?), Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value), - Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()), + Lit::Str(ref lit) if matches!(*ty, Ty::String(..)) => DefaultValue::String(lit.value()), Lit::ByteStr(ref lit) if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) => { @@ -769,7 +824,7 @@ impl DefaultValue { Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0), Ty::Bool => DefaultValue::Bool(false), - Ty::String => DefaultValue::String(String::new()), + Ty::String(..) => DefaultValue::String(String::new()), Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), } diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index 1794b0bfb..7f9c81848 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -7,9 +7,9 @@ use alloc::collections::BTreeMap; use alloc::string::String; +use alloc::sync::Arc; use alloc::vec::Vec; -use core::mem; -use core::str; +use core::{mem, str}; use ::bytes::{Buf, BufMut, Bytes}; @@ -557,7 +557,7 @@ macro_rules! length_delimited { pub mod string { use super::*; - pub fn encode(tag: u32, value: &String, buf: &mut impl BufMut) { + pub fn encode(tag: u32, value: &impl StringAdapter, buf: &mut impl BufMut) { encode_key(tag, WireType::LengthDelimited, buf); encode_varint(value.len() as u64, buf); buf.put_slice(value.as_bytes()); @@ -565,46 +565,21 @@ pub mod string { pub fn merge( wire_type: WireType, - value: &mut String, + value: &mut impl StringAdapter, buf: &mut impl Buf, - ctx: DecodeContext, + _ctx: DecodeContext, ) -> Result<(), DecodeError> { - // ## Unsafety - // - // `string::merge` reuses `bytes::merge`, with an additional check of utf-8 - // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the - // string is cleared, so as to avoid leaking a string field with invalid data. - // - // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe - // alternative of temporarily swapping an empty `String` into the field, because it results - // in up to 10% better performance on the protobuf message decoding benchmarks. - // - // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into - // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or - // in the buf implementation, a drop guard is used. - unsafe { - struct DropGuard<'a>(&'a mut Vec); - impl Drop for DropGuard<'_> { - #[inline] - fn drop(&mut self) { - self.0.clear(); - } - } - - let drop_guard = DropGuard(value.as_mut_vec()); - bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?; - match str::from_utf8(drop_guard.0) { - Ok(_) => { - // Success; do not clear the bytes. - mem::forget(drop_guard); - Ok(()) - } - Err(_) => Err(DecodeErrorKind::InvalidString.into()), - } + check_wire_type(WireType::LengthDelimited, wire_type)?; + let len = decode_varint(buf)?; + if len > buf.remaining() as u64 { + return Err(DecodeErrorKind::BufferUnderflow.into()); } + let len = len as usize; + + value.replace_with(buf.take(len)) } - length_delimited!(String); + length_delimited!(impl StringAdapter); #[cfg(test)] mod test { @@ -615,16 +590,27 @@ pub mod string { proptest! { #[test] - fn check(value: String, tag in MIN_TAG..=MAX_TAG) { - super::test::check_type(value, tag, WireType::LengthDelimited, + fn check_string(value: String, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type::(value, tag, WireType::LengthDelimited, encode, merge, encoded_len)?; } #[test] - fn check_repeated(value: Vec, tag in MIN_TAG..=MAX_TAG) { + fn check_arc_str(value: Arc, tag in MIN_TAG..=MAX_TAG) { + super::test::check_type::, Arc>(value, tag, WireType::LengthDelimited, + encode, merge, encoded_len)?; + } + #[test] + fn check_repeated_string(value: Vec, tag in MIN_TAG..=MAX_TAG) { super::test::check_collection_type(value, tag, WireType::LengthDelimited, encode_repeated, merge_repeated, encoded_len_repeated)?; } + #[test] + fn check_repeated_arc_str(values: Vec>, tag in MIN_TAG..=MAX_TAG) { + super::test::check_collection_type(values, tag, WireType::LengthDelimited, + encode_repeated, merge_repeated, + encoded_len_repeated)?; + } } } } @@ -717,30 +703,10 @@ pub mod bytes { // [1]: https://protobuf.dev/programming-guides/encoding/#last-one-wins // // This is intended for A and B both being Bytes so it is zero-copy. - // Some combinations of A and B types may cause a double-copy, - // in which case merge_one_copy() should be used instead. value.replace_with(buf.copy_to_bytes(len)); Ok(()) } - pub(super) fn merge_one_copy( - wire_type: WireType, - value: &mut impl BytesAdapter, - buf: &mut impl Buf, - _ctx: DecodeContext, - ) -> Result<(), DecodeError> { - check_wire_type(WireType::LengthDelimited, wire_type)?; - let len = decode_varint(buf)?; - if len > buf.remaining() as u64 { - return Err(DecodeErrorKind::BufferUnderflow.into()); - } - let len = len as usize; - - // If we must copy, make sure to copy only once. - value.replace_with(buf.take(len)); - Ok(()) - } - length_delimited!(impl BytesAdapter); #[cfg(test)] @@ -782,6 +748,124 @@ pub mod bytes { } } +pub trait StringAdapter: sealed_string::StringAdapter {} + +mod sealed_string { + use super::{Buf, DecodeError}; + + pub trait StringAdapter: Default + Sized + 'static { + fn len(&self) -> usize; + + fn as_bytes(&self) -> &[u8]; + + /// Replace contents of this string with the contents from a buffer. + /// Returns an error if the buffer contains invalid UTF-8. + fn replace_with(&mut self, buf: impl Buf) -> Result<(), DecodeError>; + + fn is_empty(&self) -> bool { + self.len() == 0 + } + } +} + +impl StringAdapter for String {} + +impl sealed_string::StringAdapter for String { + fn len(&self) -> usize { + String::len(self) + } + + fn as_bytes(&self) -> &[u8] { + String::as_bytes(self) + } + + fn replace_with(&mut self, buf: impl Buf) -> Result<(), DecodeError> { + // ## Unsafety + // + // `string::merge` reuses `bytes::merge`, with an additional check of utf-8 + // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the + // string is cleared, so as to avoid leaking a string field with invalid data. + // + // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe + // alternative of temporarily swapping an empty `String` into the field, because it results + // in up to 10% better performance on the protobuf message decoding benchmarks. + // + // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into + // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or + // in the buf implementation, a drop guard is used. + unsafe { + struct DropGuard<'a>(&'a mut alloc::vec::Vec); + impl Drop for DropGuard<'_> { + #[inline] + fn drop(&mut self) { + self.0.clear(); + } + } + + let vec = self.as_mut_vec(); + let drop_guard = DropGuard(vec); + drop_guard.0.clear(); + drop_guard.0.reserve(buf.remaining()); + drop_guard.0.put(buf); + + match core::str::from_utf8(drop_guard.0) { + Ok(_) => { + core::mem::forget(drop_guard); + Ok(()) + } + Err(_) => Err(DecodeErrorKind::InvalidString.into()), + } + } + } +} + +impl StringAdapter for Arc {} + +impl sealed_string::StringAdapter for Arc { + fn len(&self) -> usize { + ::len(self) + } + + fn as_bytes(&self) -> &[u8] { + ::as_bytes(self) + } + + fn replace_with(&mut self, mut buf: impl Buf) -> Result<(), DecodeError> { + // We're about to do some serious contortions to ensure we perform only + // a single allocation and a single copy. Various parts of this will be + // more ergonomic once more helpers for zero-allocation and working + // with [MaybeUninit] are stabalized. + + // Allocate space for `b` + let mut arc: Arc<[mem::MaybeUninit]> = Arc::new_uninit_slice(buf.remaining()); + // We just created the `Arc`, so we can get a `&mut` to the data. + let data = Arc::get_mut(&mut arc).unwrap(); + // Zero initialize, so that we can safely create a slice of this data. + for el in data.iter_mut() { + el.write(0); + } + // SAFETY: we just zero-initialized `data` and now the arc is fully + // filled in. + let arc = unsafe { + buf.copy_to_slice(core::slice::from_raw_parts_mut( + data.as_mut_ptr() as *mut u8, + data.len(), + )); + arc.assume_init() + }; + if str::from_utf8(&arc).is_err() { + return Err(DecodeErrorKind::InvalidString.into()); + } + + // SAFETY: [u8] and str have the same representation, therefore we are + // allowed to convert an `Arc<[u8]>` to an `Arc`, provided we have + // verified the contents are valid UTF-8 (which we just did). + *self = unsafe { Arc::from_raw(Arc::into_raw(arc) as *const str) }; + + Ok(()) + } +} + pub mod message { use super::*; diff --git a/tests/src/message_encoding.rs b/tests/src/message_encoding.rs index cdce05b06..19bcd469a 100644 --- a/tests/src/message_encoding.rs +++ b/tests/src/message_encoding.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use prost::alloc::collections::BTreeMap; +use prost::alloc::sync::Arc; use prost::alloc::vec; #[cfg(not(feature = "std"))] use prost::alloc::{borrow::ToOwned, string::String, vec::Vec}; @@ -68,9 +69,11 @@ pub struct ScalarTypes { pub _bool: bool, #[prost(string, tag = "014")] pub string: String, - #[prost(bytes = "vec", tag = "015")] + #[prost(string = "arc_str", tag = "015")] + pub string_arc: Arc, + #[prost(bytes = "vec", tag = "016")] pub bytes_vec: Vec, - #[prost(bytes = "bytes", tag = "016")] + #[prost(bytes = "bytes", tag = "017")] pub bytes_buf: Bytes, #[prost(int32, required, tag = "101")] @@ -432,3 +435,40 @@ fn roundtrip() { }; check_message(&msg); } + +#[derive(Clone, PartialEq, Message)] +pub struct ArcStrMessage { + #[prost(string = "arc_str", tag = "1")] + pub name: Arc, + + #[prost(string = "arc_str", optional, tag = "2")] + pub optional_name: Option>, + + #[prost(string = "arc_str", repeated, tag = "3")] + pub repeated_names: Vec>, +} + +#[test] +fn test_arc_str() { + let values = [ + ArcStrMessage { + name: Arc::from("hello world"), + optional_name: Some(Arc::from("optional")), + repeated_names: vec![Arc::from("first"), Arc::from("second"), Arc::from("third")], + }, + ArcStrMessage { + name: Arc::from(""), + optional_name: None, + repeated_names: Vec::new(), + }, + ArcStrMessage { + name: Arc::from("Hello 世界! 🦀"), + optional_name: Some(Arc::from("Unicode: é, ñ, ö")), + repeated_names: vec![Arc::from("日本語"), Arc::from("中文"), Arc::from("한글")], + }, + ]; + + for v in values { + check_message(&v); + } +}