Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use prost_types::{
};

use crate::ast::{Comments, Method, Service};
use crate::collections::StringType;
use crate::context::Context;
use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel};
use crate::Config;
Expand Down Expand Up @@ -451,6 +452,17 @@ impl<'b> CodeGenerator<'_, 'b> {
.push_str(&format!(" = {:?}", bytes_type.annotation()));
}

if type_ == Type::String {
let string_type = self
.context
.string_type(fq_message_name, field.descriptor.name());
// Only emit the annotation if it's not the default type
if string_type != StringType::String {
self.buf
.push_str(&format!(" = {:?}", string_type.annotation()));
}
}

match field.descriptor.label() {
Label::Optional => {
if optional {
Expand Down Expand Up @@ -982,7 +994,17 @@ impl<'b> CodeGenerator<'_, 'b> {
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => format!("{}::alloc::string::String", self.context.prost_path()),
Type::String => {
let string_type = self.context.string_type(fq_message_name, field.name());
match string_type {
StringType::String => {
format!("{}::alloc::string::String", self.context.prost_path())
}
StringType::ArcStr => {
format!("{}::alloc::sync::Arc<str>", self.context.prost_path())
}
}
}
Type::Bytes => self
.context
.bytes_type(fq_message_name, field.name())
Expand Down
21 changes: 21 additions & 0 deletions prost-build/src/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,24 @@ impl BytesType {
}
}
}

/// The string type to output for Protobuf `string` fields.
#[non_exhaustive]
#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub(crate) enum StringType {
/// The [`prost::alloc::string::String`] type.
#[default]
String,
/// The [`std::sync::Arc<str>`] type.
ArcStr,
}

impl StringType {
/// The `prost-derive` annotation type corresponding to the string type.
pub fn annotation(&self) -> &'static str {
match self {
StringType::String => "string",
StringType::ArcStr => "arc_str",
}
}
}
62 changes: 62 additions & 0 deletions prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::BytesType;
use crate::MapType;
use crate::Module;
use crate::ServiceGenerator;
use crate::StringType;

/// Configuration options for Protobuf code generation.
///
Expand All @@ -32,6 +33,7 @@ pub struct Config {
pub(crate) service_generator: Option<Box<dyn ServiceGenerator>>,
pub(crate) map_type: PathMap<MapType>,
pub(crate) bytes_type: PathMap<BytesType>,
pub(crate) string_type: PathMap<StringType>,
pub(crate) type_attributes: PathMap<String>,
pub(crate) message_attributes: PathMap<String>,
pub(crate) enum_attributes: PathMap<String>,
Expand Down Expand Up @@ -184,6 +186,65 @@ impl Config {
self
}

/// Configure the code generator to generate Rust [`Arc<str>`](std::sync::Arc) fields for Protobuf
/// [`string`][2] type fields.
///
/// # Arguments
///
/// **`paths`** - paths to specific fields, messages, or packages which should use a Rust
/// `Arc<str>` for Protobuf `string` fields. Paths are specified in terms of the Protobuf type
/// name (not the generated Rust type name). Paths with a leading `.` are treated as fully
/// qualified names. Paths without a leading `.` are treated as relative, and are suffix
/// matched on the fully qualified field name. If a Protobuf string field matches any of the
/// paths, a Rust `Arc<str>` field is generated instead of the default [`String`].
///
/// The matching is done on the Protobuf names, before converting to Rust-friendly casing
/// standards.
///
/// # Examples
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// // Match a specific field in a message type.
/// config.arc_str(&[".my_messages.MyMessageType.my_string_field"]);
///
/// // Match all string fields in a message type.
/// config.arc_str(&[".my_messages.MyMessageType"]);
///
/// // Match all string fields in a package.
/// config.arc_str(&[".my_messages"]);
///
/// // Match all string fields.
/// config.arc_str(&["."]);
///
/// // Match all string fields in a nested message.
/// config.arc_str(&[".my_messages.MyMessageType.MyNestedMessageType"]);
///
/// // Match all fields named 'my_string_field'.
/// config.arc_str(&["my_string_field"]);
///
/// // Match all fields named 'my_string_field' in messages named 'MyMessageType', regardless of
/// // package or nesting.
/// config.arc_str(&["MyMessageType.my_string_field"]);
///
/// // Match all fields named 'my_string_field', and all fields in the 'foo.bar' package.
/// config.arc_str(&["my_string_field", ".foo.bar"]);
/// ```
///
/// [2]: https://protobuf.dev/programming-guides/proto3/#scalar
pub fn arc_str<I, S>(&mut self, paths: I) -> &mut Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
self.string_type.clear();
for matcher in paths {
self.string_type
.insert(matcher.as_ref().to_string(), StringType::ArcStr);
}
self
}

/// Add additional attribute to matched fields.
///
/// # Arguments
Expand Down Expand Up @@ -1198,6 +1259,7 @@ impl default::Default for Config {
service_generator: None,
map_type: PathMap::default(),
bytes_type: PathMap::default(),
string_type: PathMap::default(),
type_attributes: PathMap::default(),
message_attributes: PathMap::default(),
enum_attributes: PathMap::default(),
Expand Down
11 changes: 10 additions & 1 deletion prost-build/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use prost_types::{

use crate::extern_paths::ExternPaths;
use crate::message_graph::MessageGraph;
use crate::{BytesType, Config, MapType, ServiceGenerator};
use crate::{BytesType, Config, MapType, ServiceGenerator, StringType};

/// The context providing all the global information needed to generate code.
/// It also provides a more disciplined access to Config
Expand Down Expand Up @@ -110,6 +110,15 @@ impl<'a> Context<'a> {
.unwrap_or_default()
}

/// Returns the string type configured for the named message field.
pub(crate) fn string_type(&self, fq_message_name: &str, field_name: &str) -> StringType {
self.config
.string_type
.get_first_field(fq_message_name, field_name)
.copied()
.unwrap_or_default()
}

/// Returns the map type configured for the named message field.
pub(crate) fn map_type(&self, fq_message_name: &str, field_name: &str) -> MapType {
self.config
Expand Down
2 changes: 1 addition & 1 deletion prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ mod ast;
pub use crate::ast::{Comments, Method, Service};

mod collections;
pub(crate) use collections::{BytesType, MapType};
pub(crate) use collections::{BytesType, MapType, StringType};

mod code_generator;
mod context;
Expand Down
2 changes: 1 addition & 1 deletion prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
| scalar::Ty::Sfixed32
| scalar::Ty::Sfixed64
| scalar::Ty::Bool
| scalar::Ty::String => Ok(ty),
| scalar::Ty::String(..) => Ok(ty),
_ => bail!("invalid map key type: {s}"),
}
}
Expand Down
81 changes: 68 additions & 13 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,13 @@ impl Field {
match self.kind {
Kind::Plain(ref default) => {
let default = default.typed();
let comparison = if matches!(self.ty, Ty::String(StringTy::ArcStr)) {
quote!(#ident.as_ref() != #default)
} else {
quote!(#ident != #default)
};
quote! {
if #ident != #default {
if #comparison {
#encode_fn(#tag, &#ident, buf);
}
}
Expand Down Expand Up @@ -172,8 +177,13 @@ impl Field {
match self.kind {
Kind::Plain(ref default) => {
let default = default.typed();
let comparison = if matches!(self.ty, Ty::String(StringTy::ArcStr)) {
quote!(#ident.as_ref() != #default)
} else {
quote!(#ident != #default)
};
quote! {
if #ident != #default {
if #comparison {
#encoded_len_fn(#tag, &#ident)
} else {
0
Expand All @@ -194,7 +204,10 @@ impl Field {
Kind::Plain(ref default) | Kind::Required(ref default) => {
let default = default.typed();
match self.ty {
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
Ty::String(StringTy::ArcStr) => {
quote!(#ident = ::core::default::Default::default())
}
Ty::String(..) | Ty::Bytes(..) => quote!(#ident.clear()),
_ => quote!(#ident = #default),
}
}
Expand All @@ -206,7 +219,17 @@ impl Field {
/// Returns an expression which evaluates to the default value of the field.
pub fn default(&self, prost_path: &Path) -> TokenStream {
match self.kind {
Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(prost_path),
Kind::Plain(ref value) | Kind::Required(ref value) => {
// Special handling for Arc<str> default value
if matches!(self.ty, Ty::String(StringTy::ArcStr)) {
if matches!(value, DefaultValue::String(s) if s.is_empty()) {
return quote!(#prost_path::alloc::sync::Arc::from(""));
} else if let DefaultValue::String(s) = value {
return quote!(#prost_path::alloc::sync::Arc::from(#s));
}
}
value.owned(prost_path)
}
Kind::Optional(_) => quote!(::core::option::Option::None),
Kind::Repeated | Kind::Packed => quote!(#prost_path::alloc::vec::Vec::new()),
}
Expand Down Expand Up @@ -393,7 +416,7 @@ pub enum Ty {
Sfixed32,
Sfixed64,
Bool,
String,
String(StringTy),
Bytes(BytesTy),
Enumeration(Path),
}
Expand All @@ -404,6 +427,12 @@ pub enum BytesTy {
Bytes,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum StringTy {
String,
ArcStr,
}

impl BytesTy {
fn try_from_str(s: &str) -> Result<Self, Error> {
match s {
Expand All @@ -421,6 +450,23 @@ impl BytesTy {
}
}

impl StringTy {
fn try_from_str(s: &str) -> Result<Self, Error> {
match s {
"string" => Ok(StringTy::String),
"arc_str" => Ok(StringTy::ArcStr),
_ => bail!("Invalid string type: {}", s),
}
}

fn rust_type(&self, prost_path: &Path) -> TokenStream {
match self {
StringTy::String => quote! { #prost_path::alloc::string::String },
StringTy::ArcStr => quote! { #prost_path::alloc::sync::Arc<str> },
}
}
}

impl Ty {
pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
let ty = match *attr {
Expand All @@ -437,8 +483,17 @@ impl Ty {
Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32,
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
Meta::Path(ref name) if name.is_ident("string") => Ty::String(StringTy::String),
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
Meta::NameValue(MetaNameValue {
ref path,
value:
Expr::Lit(ExprLit {
lit: Lit::Str(ref l),
..
}),
..
}) if path.is_ident("string") => Ty::String(StringTy::try_from_str(&l.value())?),
Meta::NameValue(MetaNameValue {
ref path,
value:
Expand Down Expand Up @@ -482,7 +537,7 @@ impl Ty {
"sfixed32" => Ty::Sfixed32,
"sfixed64" => Ty::Sfixed64,
"bool" => Ty::Bool,
"string" => Ty::String,
"string" => Ty::String(StringTy::String),
"bytes" => Ty::Bytes(BytesTy::Vec),
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
let s = &s[enumeration_len..].trim();
Expand Down Expand Up @@ -518,7 +573,7 @@ impl Ty {
Ty::Sfixed32 => "sfixed32",
Ty::Sfixed64 => "sfixed64",
Ty::Bool => "bool",
Ty::String => "string",
Ty::String(..) => "string",
Ty::Bytes(..) => "bytes",
Ty::Enumeration(..) => "enum",
}
Expand All @@ -527,7 +582,7 @@ impl Ty {
// TODO: rename to 'owned_type'.
pub fn rust_type(&self, prost_path: &Path) -> TokenStream {
match self {
Ty::String => quote!(#prost_path::alloc::string::String),
Ty::String(ty) => ty.rust_type(prost_path),
Ty::Bytes(ty) => ty.rust_type(prost_path),
_ => self.rust_ref_type(),
}
Expand All @@ -549,7 +604,7 @@ impl Ty {
Ty::Sfixed32 => quote!(i32),
Ty::Sfixed64 => quote!(i64),
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::String(..) => quote!(&str),
Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
}
Expand All @@ -564,7 +619,7 @@ impl Ty {

/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
pub fn is_numeric(&self) -> bool {
!matches!(self, Ty::String | Ty::Bytes(..))
!matches!(self, Ty::String(..) | Ty::Bytes(..))
}
}

Expand Down Expand Up @@ -660,7 +715,7 @@ impl DefaultValue {
Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?),

Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
Lit::Str(ref lit) if matches!(*ty, Ty::String(..)) => DefaultValue::String(lit.value()),
Lit::ByteStr(ref lit)
if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
{
Expand Down Expand Up @@ -769,7 +824,7 @@ impl DefaultValue {
Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0),

Ty::Bool => DefaultValue::Bool(false),
Ty::String => DefaultValue::String(String::new()),
Ty::String(..) => DefaultValue::String(String::new()),
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
}
Expand Down
Loading
Loading