Skip to content

Commit

Permalink
ProstDerive: Add support for CowStr and CowBytes
Browse files Browse the repository at this point in the history
Signed-off-by: Jon Doron <[email protected]>
  • Loading branch information
arilou committed Dec 10, 2024
1 parent 273ab33 commit 7a2ced4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
1 change: 1 addition & 0 deletions prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
| scalar::Ty::Sfixed32
| scalar::Ty::Sfixed64
| scalar::Ty::Bool
| scalar::Ty::CowStr
| scalar::Ty::String => Ok(ty),
_ => bail!("invalid map key type: {}", s),
}
Expand Down
38 changes: 35 additions & 3 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ impl Field {
Kind::Plain(ref default) | Kind::Required(ref default) => {
let default = default.typed();
match self.ty {
Ty::CowStr => quote!(#ident = ::prost::alloc::borrow::Cow::Borrowed("")),
Ty::CowBytes => quote!(#ident = ::prost::alloc::borrow::Cow::Borrowed(&[])),
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
_ => quote!(#ident = #default),
}
Expand Down Expand Up @@ -398,6 +400,8 @@ pub enum Ty {
Sfixed64,
Bool,
String,
CowStr,
CowBytes,
Bytes(BytesTy),
Enumeration(Path),
}
Expand Down Expand Up @@ -442,6 +446,8 @@ impl Ty {
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
Meta::Path(ref name) if name.is_ident("cow_str") => Ty::CowStr,
Meta::Path(ref name) if name.is_ident("cow_bytes") => Ty::CowBytes,
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
Meta::NameValue(MetaNameValue {
ref path,
Expand Down Expand Up @@ -486,6 +492,8 @@ impl Ty {
"sfixed32" => Ty::Sfixed32,
"sfixed64" => Ty::Sfixed64,
"bool" => Ty::Bool,
"cow_str" => Ty::CowStr,
"cow_bytes" => Ty::CowBytes,
"string" => Ty::String,
"bytes" => Ty::Bytes(BytesTy::Vec),
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
Expand Down Expand Up @@ -523,6 +531,8 @@ impl Ty {
Ty::Sfixed64 => "sfixed64",
Ty::Bool => "bool",
Ty::String => "string",
Ty::CowStr => "cow_str",
Ty::CowBytes => "cow_bytes",
Ty::Bytes(..) => "bytes",
Ty::Enumeration(..) => "enum",
}
Expand All @@ -531,6 +541,8 @@ impl Ty {
// TODO: rename to 'owned_type'.
pub fn rust_type(&self) -> TokenStream {
match self {
Ty::CowStr => quote!(::prost::alloc::borrow::Cow<'a, str>),
Ty::CowBytes => quote!(::prost::alloc::borrow::Cow<'a, [u8]>),
Ty::String => quote!(::prost::alloc::string::String),
Ty::Bytes(ty) => ty.rust_type(),
_ => self.rust_ref_type(),
Expand All @@ -553,8 +565,8 @@ impl Ty {
Ty::Sfixed32 => quote!(i32),
Ty::Sfixed64 => quote!(i64),
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::Bytes(..) => quote!(&[u8]),
Ty::CowStr | Ty::String => quote!(&str),
Ty::CowBytes | Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
}
}
Expand All @@ -568,7 +580,7 @@ impl Ty {

/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
pub fn is_numeric(&self) -> bool {
!matches!(self, Ty::String | Ty::Bytes(..))
!matches!(self, Ty::CowStr | Ty::String | Ty::CowBytes | Ty::Bytes(..))
}
}

Expand Down Expand Up @@ -610,6 +622,8 @@ pub enum DefaultValue {
U64(u64),
Bool(bool),
String(String),
CowStr(std::borrow::Cow<'static, str>),
CowBytes(Vec<u8>),
Bytes(Vec<u8>),
Enumeration(TokenStream),
Path(Path),
Expand Down Expand Up @@ -774,6 +788,8 @@ impl DefaultValue {

Ty::Bool => DefaultValue::Bool(false),
Ty::String => DefaultValue::String(String::new()),
Ty::CowStr => DefaultValue::CowStr(std::borrow::Cow::Borrowed("")),
Ty::CowBytes => DefaultValue::CowBytes(Vec::new()),
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
}
Expand All @@ -785,6 +801,17 @@ impl DefaultValue {
quote!(::prost::alloc::string::String::new())
}
DefaultValue::String(ref value) => quote!(#value.into()),
DefaultValue::CowStr(ref value) if value.is_empty() => {
quote!(::prost::alloc::borrow::Cow::Borrowed(""))
}
DefaultValue::CowStr(ref value) => quote!(#value.into()),
DefaultValue::CowBytes(ref value) if value.is_empty() => {
quote!(::core::default::Default::default())
}
DefaultValue::CowBytes(ref value) => {
let lit = LitByteStr::new(value, Span::call_site());
quote!(#lit.as_ref().into())
}
DefaultValue::Bytes(ref value) if value.is_empty() => {
quote!(::core::default::Default::default())
}
Expand Down Expand Up @@ -817,6 +844,11 @@ impl ToTokens for DefaultValue {
DefaultValue::U64(value) => value.to_tokens(tokens),
DefaultValue::Bool(value) => value.to_tokens(tokens),
DefaultValue::String(ref value) => value.to_tokens(tokens),
DefaultValue::CowStr(ref value) => value.to_tokens(tokens),
DefaultValue::CowBytes(ref value) => {
let byte_str = LitByteStr::new(value, Span::call_site());
tokens.append_all(quote!(#byte_str as &[u8]));
}
DefaultValue::Bytes(ref value) => {
let byte_str = LitByteStr::new(value, Span::call_site());
tokens.append_all(quote!(#byte_str as &[u8]));
Expand Down

0 comments on commit 7a2ced4

Please sign in to comment.