Skip to content

Commit

Permalink
fix(prost-build): Remove derived(Copy) on boxed fields
Browse files Browse the repository at this point in the history
  • Loading branch information
ldm0 committed Sep 13, 2024
1 parent 644c328 commit ddd6a33
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 54 deletions.
64 changes: 59 additions & 5 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl<'a> CodeGenerator<'a> {
self.push_indent();
self.buf.push_str(&format!(
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
if self.message_graph.can_message_derive_copy(&fq_message_name) {
if self.can_message_derive_copy(&fq_message_name) {
"Copy, "
} else {
""
Expand Down Expand Up @@ -615,10 +615,10 @@ impl<'a> CodeGenerator<'a> {
self.append_enum_attributes(&oneof_name);
self.push_indent();

let can_oneof_derive_copy = oneof.fields.iter().all(|field| {
self.message_graph
.can_field_derive_copy(fq_message_name, &field.descriptor)
});
let can_oneof_derive_copy = oneof
.fields
.iter()
.all(|field| self.can_field_derive_copy(fq_message_name, &field.descriptor));
self.buf.push_str(&format!(
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
if can_oneof_derive_copy { "Copy, " } else { "" },
Expand Down Expand Up @@ -1120,6 +1120,60 @@ impl<'a> CodeGenerator<'a> {
message_name,
)
}

/// Returns `true` if this message can automatically derive Copy trait.
fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
self.message_graph
.get_message(fq_message_name)
.unwrap()
.field
.iter()
.all(|field| self.can_field_derive_copy(fq_message_name, field))
}

/// Returns `true` if the type of this field allows deriving the Copy trait.
fn can_field_derive_copy(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
assert_eq!(".", &fq_message_name[..1]);

// repeated field cannot derive Copy
if field.label() == Label::Repeated {
false
} else if field.r#type() == Type::Message {
// nested and boxed messages cannot derive Copy
if self
.message_graph
.is_nested(field.type_name(), fq_message_name)
|| self
.config
.boxed
.get_first_field(fq_message_name, field.name())
.is_some()
{
false
} else {
self.can_message_derive_copy(field.type_name())
}
} else {
matches!(
field.r#type(),
Type::Float
| Type::Double
| Type::Int32
| Type::Int64
| Type::Uint32
| Type::Uint64
| Type::Sint32
| Type::Sint64
| Type::Fixed32
| Type::Fixed64
| Type::Sfixed32
| Type::Sfixed64
| Type::Bool
| Type::Enum
)
}
}
}

/// Returns `true` if the repeated field type can be packed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Foo {
#[prost(string, tag="1")]
pub foo: ::prost::alloc::string::String,
}
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag="1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Foo {
#[prost(string, tag = "1")]
pub foo: ::prost::alloc::string::String,
}
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
Expand Down
53 changes: 6 additions & 47 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use petgraph::Graph;

use prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto, FileDescriptorProto,
DescriptorProto, FileDescriptorProto,
};

/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
Expand Down Expand Up @@ -74,6 +74,11 @@ impl MessageGraph {
}
}

/// Try get a message descriptor from current message graph
pub fn get_message(&self, message: &str) -> Option<&DescriptorProto> {
self.messages.get(message)
}

/// Returns true if message type `inner` is nested in message type `outer`.
pub fn is_nested(&self, outer: &str, inner: &str) -> bool {
let outer = match self.index.get(outer) {
Expand All @@ -87,50 +92,4 @@ impl MessageGraph {

has_path_connecting(&self.graph, outer, inner, None)
}

/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
let msg = self.messages.get(fq_message_name).unwrap();
msg.field
.iter()
.all(|field| self.can_field_derive_copy(fq_message_name, field))
}

/// Returns `true` if the type of this field allows deriving the Copy trait.
pub fn can_field_derive_copy(
&self,
fq_message_name: &str,
field: &FieldDescriptorProto,
) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.label() == Label::Repeated {
false
} else if field.r#type() == Type::Message {
if self.is_nested(field.type_name(), fq_message_name) {
false
} else {
self.can_message_derive_copy(field.type_name())
}
} else {
matches!(
field.r#type(),
Type::Float
| Type::Double
| Type::Int32
| Type::Int64
| Type::Uint32
| Type::Uint64
| Type::Sint32
| Type::Sint64
| Type::Fixed32
| Type::Fixed64
| Type::Sfixed32
| Type::Sfixed64
| Type::Bool
| Type::Enum
)
}
}
}

0 comments on commit ddd6a33

Please sign in to comment.