From 45567342df850e405906419117d1cfaabb673e88 Mon Sep 17 00:00:00 2001 From: Jon Doron Date: Mon, 2 Dec 2024 14:42:51 +0200 Subject: [PATCH] ProstBuild: CodeGen: Add support for adding Cow types Signed-off-by: Jon Doron --- prost-build/src/code_generator.rs | 109 +++++++++++++++++++++--------- prost-build/src/message_graph.rs | 33 +++++++-- 2 files changed, 106 insertions(+), 36 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index ceca5f467..c87011593 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -241,6 +241,9 @@ impl CodeGenerator<'_> { self.push_indent(); self.buf.push_str("pub struct "); self.buf.push_str(&to_upper_camel(&message_name)); + if self.message_graph.message_has_lifetime(&fq_message_name) { + self.buf.push_str("<'a>"); + } self.buf.push_str(" {\n"); self.depth += 1; @@ -406,13 +409,15 @@ impl CodeGenerator<'_> { let deprecated = self.deprecated(&field.descriptor); let optional = self.optional(&field.descriptor); let boxed = self.boxed(&field.descriptor, fq_message_name, None); - let ty = self.resolve_type(&field.descriptor, fq_message_name); + let cowed = self.cowed(&field.descriptor, fq_message_name, None); + let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed); debug!( - " field: {:?}, type: {:?}, boxed: {}", + " field: {:?}, type: {:?}, boxed: {} cowed: {}", field.descriptor.name(), ty, - boxed + boxed, + cowed ); self.append_doc(fq_message_name, Some(field.descriptor.name())); @@ -424,10 +429,10 @@ impl CodeGenerator<'_> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field.descriptor); + let type_tag = self.field_type_tag(&field.descriptor, cowed); self.buf.push_str(&type_tag); - if type_ == Type::Bytes { + if !cowed && type_ == Type::Bytes { let bytes_type = self .config .bytes_type @@ -532,8 +537,10 @@ impl CodeGenerator<'_> { key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { - let 
key_ty = self.resolve_type(key, fq_message_name); - let value_ty = self.resolve_type(value, fq_message_name); + let key_cowed = self.cowed(key, fq_message_name, None); + let key_ty = self.resolve_type(key, fq_message_name, key_cowed); + let value_cowed = self.cowed(value, fq_message_name, None); + let value_ty = self.resolve_type(value, fq_message_name, value_cowed); debug!( " map field: {:?}, key type: {:?}, value type: {:?}", @@ -551,8 +558,8 @@ impl CodeGenerator<'_> { .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); - let key_tag = self.field_type_tag(key); - let value_tag = self.map_value_type_tag(value); + let key_tag = self.field_type_tag(key, key_cowed); + let value_tag = self.map_value_type_tag(value, value_cowed); self.buf.push_str(&format!( "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", @@ -597,9 +604,11 @@ impl CodeGenerator<'_> { self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( - "pub {}: ::core::option::Option<{}>,\n", + "pub {}: ::core::option::Option<{}{}>,\n", oneof.rust_name(), - type_name + type_name, + if self.message_graph + .message_has_lifetime(fq_message_name) { "<'a>" } else { "" }, )); } @@ -628,6 +637,9 @@ impl CodeGenerator<'_> { self.push_indent(); self.buf.push_str("pub enum "); self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); + if self.message_graph.message_has_lifetime(fq_message_name) { + self.buf.push_str("<'a>"); + } self.buf.push_str(" {\n"); self.path.push(2); @@ -637,8 +649,14 @@ impl CodeGenerator<'_> { self.append_doc(fq_message_name, Some(field.descriptor.name())); self.path.pop(); + let cowed = self.cowed( + &field.descriptor, + fq_message_name, + Some(oneof.descriptor.name()), + ); + self.push_indent(); - let ty_tag = self.field_type_tag(&field.descriptor); + let ty_tag = self.field_type_tag(&field.descriptor, cowed); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, @@ -647,7 +665,7 @@ 
impl CodeGenerator<'_> { self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); - let ty = self.resolve_type(&field.descriptor, fq_message_name); + let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed); let boxed = self.boxed( &field.descriptor, @@ -656,10 +674,11 @@ impl CodeGenerator<'_> { ); debug!( - " oneof: {:?}, type: {:?}, boxed: {}", + " oneof: {:?}, type: {:?}, boxed: {} cowed: {}", field.descriptor.name(), ty, - boxed + boxed, + cowed, ); if boxed { @@ -883,8 +902,8 @@ impl CodeGenerator<'_> { let name = method.name.take().unwrap(); let input_proto_type = method.input_type.take().unwrap(); let output_proto_type = method.output_type.take().unwrap(); - let input_type = self.resolve_ident(&input_proto_type); - let output_type = self.resolve_ident(&output_proto_type); + let input_type = self.resolve_ident(&input_proto_type).0; + let output_type = self.resolve_ident(&output_proto_type).0; let client_streaming = method.client_streaming(); let server_streaming = method.server_streaming(); @@ -947,7 +966,12 @@ impl CodeGenerator<'_> { self.buf.push_str("}\n"); } - fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { + fn resolve_type( + &self, + field: &FieldDescriptorProto, + fq_message_name: &str, + cowed: bool, + ) -> String { match field.r#type() { Type::Float => String::from("f32"), Type::Double => String::from("f64"), @@ -956,7 +980,13 @@ impl CodeGenerator<'_> { Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), Type::Bool => String::from("bool"), + Type::String if cowed => { + format!("{}::alloc::borrow::Cow<'a, str>", prost_path(self.config)) + } Type::String => format!("{}::alloc::string::String", prost_path(self.config)), + Type::Bytes if cowed => { + format!("{}::alloc::borrow::Cow<'a, [u8]>", prost_path(self.config)) + } Type::Bytes => self .config .bytes_type @@ -965,16 
+995,28 @@ impl CodeGenerator<'_> { .unwrap_or_default() .rust_type() .to_owned(), - Type::Group | Type::Message => self.resolve_ident(field.type_name()), + Type::Group | Type::Message => { + let (mut s, is_extern) = self.resolve_ident(field.type_name()); + if !is_extern + && cowed + && self + .message_graph + .field_has_lifetime(fq_message_name, field) + { + s.push_str("<'a>"); + } + s + } } } - fn resolve_ident(&self, pb_ident: &str) -> String { + /// Returns the identifier and a bool indicating whether it is an extern + fn resolve_ident(&self, pb_ident: &str) -> (String, bool) { // protoc should always give fully qualified identifiers. assert_eq!(".", &pb_ident[..1]); if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) { - return proto_ident; + return (proto_ident, true); } let mut local_path = self @@ -1000,14 +1042,15 @@ impl CodeGenerator<'_> { ident_path.next(); } - local_path + let s = local_path .map(|_| "super".to_string()) .chain(ident_path.map(to_snake)) .chain(iter::once(to_upper_camel(ident_type))) - .join("::") + .join("::"); + (s, false) } - fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + fn field_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> { match field.r#type() { Type::Float => Cow::Borrowed("float"), Type::Double => Cow::Borrowed("double"), @@ -1022,24 +1065,26 @@ impl CodeGenerator<'_> { Type::Sfixed32 => Cow::Borrowed("sfixed32"), Type::Sfixed64 => Cow::Borrowed("sfixed64"), Type::Bool => Cow::Borrowed("bool"), + Type::String if cowed => Cow::Borrowed("cow_str"), Type::String => Cow::Borrowed("string"), + Type::Bytes if cowed => Cow::Borrowed("cow_bytes"), Type::Bytes => Cow::Borrowed("bytes"), Type::Group => Cow::Borrowed("group"), Type::Message => Cow::Borrowed("message"), Type::Enum => Cow::Owned(format!( "enumeration={:?}", - self.resolve_ident(field.type_name()) + self.resolve_ident(field.type_name()).0 )), } } - fn map_value_type_tag(&self, field: 
&FieldDescriptorProto) -> Cow<'static, str> { + fn map_value_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> { match field.r#type() { Type::Enum => Cow::Owned(format!( "enumeration({})", - self.resolve_ident(field.type_name()) + self.resolve_ident(field.type_name()).0 )), - _ => self.field_type_tag(field), + _ => self.field_type_tag(field, cowed), } } @@ -1111,7 +1156,10 @@ impl CodeGenerator<'_> { let fd_type = field.r#type(); // We only support Cow for Bytes and String - if !matches!(fd_type, Type::Bytes | Type::String) { + if !matches!( + fd_type, + Type::Message | Type::Group | Type::Bytes | Type::String + ) { return false; } @@ -1119,8 +1167,7 @@ impl CodeGenerator<'_> { None => Cow::Borrowed(fq_message_name), Some(ooname) => Cow::Owned(format!("{fq_message_name}.{ooname}")), }; - self - .config + self.config .cowed .get_first_field(&config_path, field.name()) .is_some() diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index a7cbe98f2..b8666ad29 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use petgraph::algo::has_path_connecting; use petgraph::graph::NodeIndex; @@ -157,20 +157,38 @@ impl MessageGraph { } } - pub fn message_has_lifetime(&self, fq_message_name: &str) -> bool { + fn message_has_lifetime_internal( + &self, + fq_message_name: &str, + visited: &mut HashSet, + ) -> bool { + visited.insert(fq_message_name.to_string()); assert_eq!(".", &fq_message_name[..1]); self.get_message(fq_message_name) .unwrap() .field .iter() - .any(|field| self.field_has_lifetime(fq_message_name, field)) + .any(|field| self.field_has_lifetime_internal(fq_message_name, field, visited)) } - pub fn field_has_lifetime(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool { + pub fn message_has_lifetime(&self, fq_message_name: &str) -> bool { + let mut visited = 
Default::default(); + self.message_has_lifetime_internal(fq_message_name, &mut visited) + } + + fn field_has_lifetime_internal( + &self, + fq_message_name: &str, + field: &FieldDescriptorProto, + visited: &mut HashSet, + ) -> bool { assert_eq!(".", &fq_message_name[..1]); if field.r#type() == Type::Message { - self.message_has_lifetime(field.type_name()) + if visited.contains(field.type_name()) { + return false; + } + self.message_has_lifetime_internal(field.type_name(), visited) } else { matches!(field.r#type(), Type::Bytes | Type::String) && self @@ -179,4 +197,9 @@ impl MessageGraph { .is_some() } } + + pub fn field_has_lifetime(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool { + let mut visited = Default::default(); + self.field_has_lifetime_internal(fq_message_name, field, &mut visited) + } }