diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 4855cc5c6..a8ccf0555 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result { | scalar::Ty::Sfixed32 | scalar::Ty::Sfixed64 | scalar::Ty::Bool + | scalar::Ty::CowStr | scalar::Ty::String => Ok(ty), _ => bail!("invalid map key type: {}", s), } diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index c2e870524..d6cae075d 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -194,6 +194,8 @@ impl Field { Kind::Plain(ref default) | Kind::Required(ref default) => { let default = default.typed(); match self.ty { + Ty::CowStr => quote!(#ident = ::prost::alloc::borrow::Cow::Borrowed("")), + Ty::CowBytes => quote!(#ident = ::prost::alloc::borrow::Cow::Borrowed(&[])), Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), _ => quote!(#ident = #default), } @@ -398,6 +400,8 @@ pub enum Ty { Sfixed64, Bool, String, + CowStr, + CowBytes, Bytes(BytesTy), Enumeration(Path), } @@ -442,6 +446,8 @@ impl Ty { Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("cow_str") => Ty::CowStr, + Meta::Path(ref name) if name.is_ident("cow_bytes") => Ty::CowBytes, Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), Meta::NameValue(MetaNameValue { ref path, @@ -486,6 +492,8 @@ impl Ty { "sfixed32" => Ty::Sfixed32, "sfixed64" => Ty::Sfixed64, "bool" => Ty::Bool, + "cow_str" => Ty::CowStr, + "cow_bytes" => Ty::CowBytes, "string" => Ty::String, "bytes" => Ty::Bytes(BytesTy::Vec), s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { @@ -523,6 +531,8 @@ impl Ty { Ty::Sfixed64 => "sfixed64", Ty::Bool => "bool", Ty::String => "string", + Ty::CowStr => "cow_str", + Ty::CowBytes => "cow_bytes", Ty::Bytes(..) => "bytes", Ty::Enumeration(..) => "enum", } @@ -531,6 +541,8 @@ impl Ty { // TODO: rename to 'owned_type'. pub fn rust_type(&self) -> TokenStream { match self { + Ty::CowStr => quote!(::prost::alloc::borrow::Cow<'a, str>), + Ty::CowBytes => quote!(::prost::alloc::borrow::Cow<'a, [u8]>), Ty::String => quote!(::prost::alloc::string::String), Ty::Bytes(ty) => ty.rust_type(), _ => self.rust_ref_type(), @@ -553,8 +565,8 @@ impl Ty { Ty::Sfixed32 => quote!(i32), Ty::Sfixed64 => quote!(i64), Ty::Bool => quote!(bool), - Ty::String => quote!(&str), - Ty::Bytes(..) => quote!(&[u8]), + Ty::CowStr | Ty::String => quote!(&str), + Ty::CowBytes | Ty::Bytes(..) => quote!(&[u8]), Ty::Enumeration(..) => quote!(i32), } } @@ -568,7 +580,7 @@ impl Ty { /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). pub fn is_numeric(&self) -> bool { - !matches!(self, Ty::String | Ty::Bytes(..)) + !matches!(self, Ty::CowStr | Ty::String | Ty::CowBytes | Ty::Bytes(..)) } } @@ -610,6 +622,8 @@ pub enum DefaultValue { U64(u64), Bool(bool), String(String), + CowStr(std::borrow::Cow<'static, str>), + CowBytes(Vec), Bytes(Vec), Enumeration(TokenStream), Path(Path), @@ -774,6 +788,8 @@ impl DefaultValue { Ty::Bool => DefaultValue::Bool(false), Ty::String => DefaultValue::String(String::new()), + Ty::CowStr => DefaultValue::CowStr(std::borrow::Cow::Borrowed("")), + Ty::CowBytes => DefaultValue::CowBytes(Vec::new()), Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), } @@ -785,6 +801,17 @@ impl DefaultValue { quote!(::prost::alloc::string::String::new()) } DefaultValue::String(ref value) => quote!(#value.into()), + DefaultValue::CowStr(ref value) if value.is_empty() => { + quote!(::prost::alloc::borrow::Cow::Borrowed("")) + } + DefaultValue::CowStr(ref value) => quote!(#value.into()), + DefaultValue::CowBytes(ref value) if value.is_empty() => { + quote!(::core::default::Default::default()) + } + DefaultValue::CowBytes(ref value) => { + let lit = LitByteStr::new(value, Span::call_site()); + quote!(#lit.as_ref().into()) + } DefaultValue::Bytes(ref value) if value.is_empty() => { quote!(::core::default::Default::default()) } @@ -817,6 +844,11 @@ impl ToTokens for DefaultValue { DefaultValue::U64(value) => value.to_tokens(tokens), DefaultValue::Bool(value) => value.to_tokens(tokens), DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::CowStr(ref value) => value.to_tokens(tokens), + DefaultValue::CowBytes(ref value) => { + let byte_str = LitByteStr::new(value, Span::call_site()); + tokens.append_all(quote!(#byte_str as &[u8])); + } DefaultValue::Bytes(ref value) => { let byte_str = LitByteStr::new(value, Span::call_site()); tokens.append_all(quote!(#byte_str as &[u8]));