Skip to content

Commit 72fc9bc

Browse files
claudealex
authored andcommitted
feat: Add support for Arc<str> as a string type
This is modeled on the support for multiple types for bytes.
1 parent 1fae49e commit 72fc9bc

File tree

9 files changed

+373
-80
lines changed

9 files changed

+373
-80
lines changed

prost-build/src/code_generator.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use prost_types::{
1515
};
1616

1717
use crate::ast::{Comments, Method, Service};
18+
use crate::collections::StringType;
1819
use crate::context::Context;
1920
use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel};
2021
use crate::Config;
@@ -451,6 +452,17 @@ impl<'b> CodeGenerator<'_, 'b> {
451452
.push_str(&format!(" = {:?}", bytes_type.annotation()));
452453
}
453454

455+
if type_ == Type::String {
456+
let string_type = self
457+
.context
458+
.string_type(fq_message_name, field.descriptor.name());
459+
// Only emit the annotation if it's not the default type
460+
if string_type != StringType::String {
461+
self.buf
462+
.push_str(&format!(" = {:?}", string_type.annotation()));
463+
}
464+
}
465+
454466
match field.descriptor.label() {
455467
Label::Optional => {
456468
if optional {
@@ -982,7 +994,17 @@ impl<'b> CodeGenerator<'_, 'b> {
982994
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
983995
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
984996
Type::Bool => String::from("bool"),
985-
Type::String => format!("{}::alloc::string::String", self.context.prost_path()),
997+
Type::String => {
998+
let string_type = self.context.string_type(fq_message_name, field.name());
999+
match string_type {
1000+
StringType::String => {
1001+
format!("{}::alloc::string::String", self.context.prost_path())
1002+
}
1003+
StringType::ArcStr => {
1004+
format!("{}::alloc::sync::Arc<str>", self.context.prost_path())
1005+
}
1006+
}
1007+
}
9861008
Type::Bytes => self
9871009
.context
9881010
.bytes_type(fq_message_name, field.name())

prost-build/src/collections.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,24 @@ impl BytesType {
5555
}
5656
}
5757
}
58+
59+
/// The string type to output for Protobuf `string` fields.
60+
#[non_exhaustive]
61+
#[derive(Default, Clone, Copy, Debug, PartialEq)]
62+
pub(crate) enum StringType {
63+
/// The [`prost::alloc::string::String`] type.
64+
#[default]
65+
String,
66+
/// The [`std::sync::Arc<str>`] type.
67+
ArcStr,
68+
}
69+
70+
impl StringType {
71+
/// The `prost-derive` annotation type corresponding to the string type.
72+
pub fn annotation(&self) -> &'static str {
73+
match self {
74+
StringType::String => "string",
75+
StringType::ArcStr => "arc_str",
76+
}
77+
}
78+
}

prost-build/src/config.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::BytesType;
2323
use crate::MapType;
2424
use crate::Module;
2525
use crate::ServiceGenerator;
26+
use crate::StringType;
2627

2728
/// Configuration options for Protobuf code generation.
2829
///
@@ -32,6 +33,7 @@ pub struct Config {
3233
pub(crate) service_generator: Option<Box<dyn ServiceGenerator>>,
3334
pub(crate) map_type: PathMap<MapType>,
3435
pub(crate) bytes_type: PathMap<BytesType>,
36+
pub(crate) string_type: PathMap<StringType>,
3537
pub(crate) type_attributes: PathMap<String>,
3638
pub(crate) message_attributes: PathMap<String>,
3739
pub(crate) enum_attributes: PathMap<String>,
@@ -184,6 +186,65 @@ impl Config {
184186
self
185187
}
186188

189+
/// Configure the code generator to generate Rust [`Arc<str>`](std::sync::Arc) fields for Protobuf
190+
/// [`string`][2] type fields.
191+
///
192+
/// # Arguments
193+
///
194+
/// **`paths`** - paths to specific fields, messages, or packages which should use a Rust
195+
/// `Arc<str>` for Protobuf `string` fields. Paths are specified in terms of the Protobuf type
196+
/// name (not the generated Rust type name). Paths with a leading `.` are treated as fully
197+
/// qualified names. Paths without a leading `.` are treated as relative, and are suffix
198+
/// matched on the fully qualified field name. If a Protobuf string field matches any of the
199+
/// paths, a Rust `Arc<str>` field is generated instead of the default [`String`].
200+
///
201+
/// The matching is done on the Protobuf names, before converting to Rust-friendly casing
202+
/// standards.
203+
///
204+
/// # Examples
205+
///
206+
/// ```rust
207+
/// # let mut config = prost_build::Config::new();
208+
/// // Match a specific field in a message type.
209+
/// config.arc_str(&[".my_messages.MyMessageType.my_string_field"]);
210+
///
211+
/// // Match all string fields in a message type.
212+
/// config.arc_str(&[".my_messages.MyMessageType"]);
213+
///
214+
/// // Match all string fields in a package.
215+
/// config.arc_str(&[".my_messages"]);
216+
///
217+
/// // Match all string fields.
218+
/// config.arc_str(&["."]);
219+
///
220+
/// // Match all string fields in a nested message.
221+
/// config.arc_str(&[".my_messages.MyMessageType.MyNestedMessageType"]);
222+
///
223+
/// // Match all fields named 'my_string_field'.
224+
/// config.arc_str(&["my_string_field"]);
225+
///
226+
/// // Match all fields named 'my_string_field' in messages named 'MyMessageType', regardless of
227+
/// // package or nesting.
228+
/// config.arc_str(&["MyMessageType.my_string_field"]);
229+
///
230+
/// // Match all fields named 'my_string_field', and all fields in the 'foo.bar' package.
231+
/// config.arc_str(&["my_string_field", ".foo.bar"]);
232+
/// ```
233+
///
234+
/// [2]: https://protobuf.dev/programming-guides/proto3/#scalar
235+
pub fn arc_str<I, S>(&mut self, paths: I) -> &mut Self
236+
where
237+
I: IntoIterator<Item = S>,
238+
S: AsRef<str>,
239+
{
240+
self.string_type.clear();
241+
for matcher in paths {
242+
self.string_type
243+
.insert(matcher.as_ref().to_string(), StringType::ArcStr);
244+
}
245+
self
246+
}
247+
187248
/// Add additional attribute to matched fields.
188249
///
189250
/// # Arguments
@@ -1198,6 +1259,7 @@ impl default::Default for Config {
11981259
service_generator: None,
11991260
map_type: PathMap::default(),
12001261
bytes_type: PathMap::default(),
1262+
string_type: PathMap::default(),
12011263
type_attributes: PathMap::default(),
12021264
message_attributes: PathMap::default(),
12031265
enum_attributes: PathMap::default(),

prost-build/src/context.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use prost_types::{
77

88
use crate::extern_paths::ExternPaths;
99
use crate::message_graph::MessageGraph;
10-
use crate::{BytesType, Config, MapType, ServiceGenerator};
10+
use crate::{BytesType, Config, MapType, ServiceGenerator, StringType};
1111

1212
/// The context providing all the global information needed to generate code.
1313
/// It also provides a more disciplined access to Config
@@ -110,6 +110,15 @@ impl<'a> Context<'a> {
110110
.unwrap_or_default()
111111
}
112112

113+
/// Returns the string type configured for the named message field.
114+
pub(crate) fn string_type(&self, fq_message_name: &str, field_name: &str) -> StringType {
115+
self.config
116+
.string_type
117+
.get_first_field(fq_message_name, field_name)
118+
.copied()
119+
.unwrap_or_default()
120+
}
121+
113122
/// Returns the map type configured for the named message field.
114123
pub(crate) fn map_type(&self, fq_message_name: &str, field_name: &str) -> MapType {
115124
self.config

prost-build/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ mod ast;
143143
pub use crate::ast::{Comments, Method, Service};
144144

145145
mod collections;
146-
pub(crate) use collections::{BytesType, MapType};
146+
pub(crate) use collections::{BytesType, MapType, StringType};
147147

148148
mod code_generator;
149149
mod context;

prost-derive/src/field/map.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
366366
| scalar::Ty::Sfixed32
367367
| scalar::Ty::Sfixed64
368368
| scalar::Ty::Bool
369-
| scalar::Ty::String => Ok(ty),
369+
| scalar::Ty::String(..) => Ok(ty),
370370
_ => bail!("invalid map key type: {}", s),
371371
}
372372
}

prost-derive/src/field/scalar.rs

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,13 @@ impl Field {
118118
match self.kind {
119119
Kind::Plain(ref default) => {
120120
let default = default.typed();
121+
let comparison = if matches!(self.ty, Ty::String(StringTy::ArcStr)) {
122+
quote!(#ident.as_ref() != #default)
123+
} else {
124+
quote!(#ident != #default)
125+
};
121126
quote! {
122-
if #ident != #default {
127+
if #comparison {
123128
#encode_fn(#tag, &#ident, buf);
124129
}
125130
}
@@ -172,8 +177,13 @@ impl Field {
172177
match self.kind {
173178
Kind::Plain(ref default) => {
174179
let default = default.typed();
180+
let comparison = if matches!(self.ty, Ty::String(StringTy::ArcStr)) {
181+
quote!(#ident.as_ref() != #default)
182+
} else {
183+
quote!(#ident != #default)
184+
};
175185
quote! {
176-
if #ident != #default {
186+
if #comparison {
177187
#encoded_len_fn(#tag, &#ident)
178188
} else {
179189
0
@@ -194,7 +204,10 @@ impl Field {
194204
Kind::Plain(ref default) | Kind::Required(ref default) => {
195205
let default = default.typed();
196206
match self.ty {
197-
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
207+
Ty::String(StringTy::ArcStr) => {
208+
quote!(#ident = ::core::default::Default::default())
209+
}
210+
Ty::String(..) | Ty::Bytes(..) => quote!(#ident.clear()),
198211
_ => quote!(#ident = #default),
199212
}
200213
}
@@ -206,7 +219,17 @@ impl Field {
206219
/// Returns an expression which evaluates to the default value of the field.
207220
pub fn default(&self, prost_path: &Path) -> TokenStream {
208221
match self.kind {
209-
Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(prost_path),
222+
Kind::Plain(ref value) | Kind::Required(ref value) => {
223+
// Special handling for Arc<str> default value
224+
if matches!(self.ty, Ty::String(StringTy::ArcStr)) {
225+
if matches!(value, DefaultValue::String(s) if s.is_empty()) {
226+
return quote!(#prost_path::alloc::sync::Arc::from(""));
227+
} else if let DefaultValue::String(s) = value {
228+
return quote!(#prost_path::alloc::sync::Arc::from(#s));
229+
}
230+
}
231+
value.owned(prost_path)
232+
}
210233
Kind::Optional(_) => quote!(::core::option::Option::None),
211234
Kind::Repeated | Kind::Packed => quote!(#prost_path::alloc::vec::Vec::new()),
212235
}
@@ -393,7 +416,7 @@ pub enum Ty {
393416
Sfixed32,
394417
Sfixed64,
395418
Bool,
396-
String,
419+
String(StringTy),
397420
Bytes(BytesTy),
398421
Enumeration(Path),
399422
}
@@ -404,6 +427,12 @@ pub enum BytesTy {
404427
Bytes,
405428
}
406429

430+
#[derive(Clone, Debug, PartialEq, Eq)]
431+
pub enum StringTy {
432+
String,
433+
ArcStr,
434+
}
435+
407436
impl BytesTy {
408437
fn try_from_str(s: &str) -> Result<Self, Error> {
409438
match s {
@@ -421,6 +450,23 @@ impl BytesTy {
421450
}
422451
}
423452

453+
impl StringTy {
454+
fn try_from_str(s: &str) -> Result<Self, Error> {
455+
match s {
456+
"string" => Ok(StringTy::String),
457+
"arc_str" => Ok(StringTy::ArcStr),
458+
_ => bail!("Invalid string type: {}", s),
459+
}
460+
}
461+
462+
fn rust_type(&self, prost_path: &Path) -> TokenStream {
463+
match self {
464+
StringTy::String => quote! { #prost_path::alloc::string::String },
465+
StringTy::ArcStr => quote! { #prost_path::alloc::sync::Arc<str> },
466+
}
467+
}
468+
}
469+
424470
impl Ty {
425471
pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
426472
let ty = match *attr {
@@ -437,8 +483,17 @@ impl Ty {
437483
Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32,
438484
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
439485
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
440-
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
486+
Meta::Path(ref name) if name.is_ident("string") => Ty::String(StringTy::String),
441487
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
488+
Meta::NameValue(MetaNameValue {
489+
ref path,
490+
value:
491+
Expr::Lit(ExprLit {
492+
lit: Lit::Str(ref l),
493+
..
494+
}),
495+
..
496+
}) if path.is_ident("string") => Ty::String(StringTy::try_from_str(&l.value())?),
442497
Meta::NameValue(MetaNameValue {
443498
ref path,
444499
value:
@@ -482,7 +537,7 @@ impl Ty {
482537
"sfixed32" => Ty::Sfixed32,
483538
"sfixed64" => Ty::Sfixed64,
484539
"bool" => Ty::Bool,
485-
"string" => Ty::String,
540+
"string" => Ty::String(StringTy::String),
486541
"bytes" => Ty::Bytes(BytesTy::Vec),
487542
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
488543
let s = &s[enumeration_len..].trim();
@@ -518,7 +573,7 @@ impl Ty {
518573
Ty::Sfixed32 => "sfixed32",
519574
Ty::Sfixed64 => "sfixed64",
520575
Ty::Bool => "bool",
521-
Ty::String => "string",
576+
Ty::String(..) => "string",
522577
Ty::Bytes(..) => "bytes",
523578
Ty::Enumeration(..) => "enum",
524579
}
@@ -527,7 +582,7 @@ impl Ty {
527582
// TODO: rename to 'owned_type'.
528583
pub fn rust_type(&self, prost_path: &Path) -> TokenStream {
529584
match self {
530-
Ty::String => quote!(#prost_path::alloc::string::String),
585+
Ty::String(ty) => ty.rust_type(prost_path),
531586
Ty::Bytes(ty) => ty.rust_type(prost_path),
532587
_ => self.rust_ref_type(),
533588
}
@@ -549,7 +604,7 @@ impl Ty {
549604
Ty::Sfixed32 => quote!(i32),
550605
Ty::Sfixed64 => quote!(i64),
551606
Ty::Bool => quote!(bool),
552-
Ty::String => quote!(&str),
607+
Ty::String(..) => quote!(&str),
553608
Ty::Bytes(..) => quote!(&[u8]),
554609
Ty::Enumeration(..) => quote!(i32),
555610
}
@@ -564,7 +619,7 @@ impl Ty {
564619

565620
/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
566621
pub fn is_numeric(&self) -> bool {
567-
!matches!(self, Ty::String | Ty::Bytes(..))
622+
!matches!(self, Ty::String(..) | Ty::Bytes(..))
568623
}
569624
}
570625

@@ -660,7 +715,7 @@ impl DefaultValue {
660715
Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?),
661716

662717
Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
663-
Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
718+
Lit::Str(ref lit) if matches!(*ty, Ty::String(..)) => DefaultValue::String(lit.value()),
664719
Lit::ByteStr(ref lit)
665720
if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
666721
{
@@ -769,7 +824,7 @@ impl DefaultValue {
769824
Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0),
770825

771826
Ty::Bool => DefaultValue::Bool(false),
772-
Ty::String => DefaultValue::String(String::new()),
827+
Ty::String(..) => DefaultValue::String(String::new()),
773828
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
774829
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
775830
}

0 commit comments

Comments
 (0)