Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: derive Eq and Hash trait for messages where possible #1175

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,17 @@ impl CodeGenerator<'_> {
self.append_message_attributes(&fq_message_name);
self.push_indent();
self.buf.push_str(&format!(
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
"#[derive(Clone, {}PartialEq, {}{}::Message)]\n",
if self.message_graph.can_message_derive_copy(&fq_message_name) {
"Copy, "
} else {
""
},
if self.message_graph.can_message_derive_eq(&fq_message_name) {
"Eq, Hash, "
} else {
""
},
prost_path(self.config)
));
self.append_skip_debug(&fq_message_name);
Expand Down Expand Up @@ -619,9 +624,18 @@ impl CodeGenerator<'_> {
self.message_graph
.can_field_derive_copy(fq_message_name, &field.descriptor)
});
let can_oneof_derive_eq = oneof.fields.iter().all(|field| {
self.message_graph
.can_field_derive_eq(fq_message_name, &field.descriptor)
});
self.buf.push_str(&format!(
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
"#[derive(Clone, {}PartialEq, {}{}::Oneof)]\n",
if can_oneof_derive_copy { "Copy, " } else { "" },
if can_oneof_derive_eq {
"Eq, Hash, "
} else {
""
},
prost_path(self.config)
));
self.append_skip_debug(fq_message_name);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
// This file is @generated by prost-build.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Container {
#[prost(oneof="container::Data", tags="1, 2")]
pub data: ::core::option::Option<container::Data>,
}
/// Nested message and enum types in `Container`.
pub mod container {
#[derive(Clone, PartialEq, ::prost::Oneof)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
pub enum Data {
#[prost(message, tag="1")]
Foo(::prost::alloc::boxed::Box<super::Foo>),
#[prost(message, tag="2")]
Bar(super::Bar),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Foo {
#[prost(string, tag="1")]
pub foo: ::prost::alloc::string::String,
}
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag="1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Qux {
}
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
// This file is @generated by prost-build.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Container {
#[prost(oneof = "container::Data", tags = "1, 2")]
pub data: ::core::option::Option<container::Data>,
}
/// Nested message and enum types in `Container`.
pub mod container {
#[derive(Clone, PartialEq, ::prost::Oneof)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
pub enum Data {
#[prost(message, tag = "1")]
Foo(::prost::alloc::boxed::Box<super::Foo>),
#[prost(message, tag = "2")]
Bar(super::Bar),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Foo {
#[prost(string, tag = "1")]
pub foo: ::prost::alloc::string::String,
}
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Bar {
#[prost(message, optional, boxed, tag = "1")]
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
}
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Qux {}
4 changes: 2 additions & 2 deletions prost-build/src/fixtures/helloworld/_expected_helloworld.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// This file is @generated by prost-build.
#[derive(derive_builder::Builder)]
#[derive(custom_proto::Input)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Message {
#[prost(string, tag="1")]
pub say: ::prost::alloc::string::String,
}
#[derive(derive_builder::Builder)]
#[derive(custom_proto::Output)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Response {
#[prost(string, tag="1")]
pub say: ::prost::alloc::string::String,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// This file is @generated by prost-build.
#[derive(derive_builder::Builder)]
#[derive(custom_proto::Input)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Message {
#[prost(string, tag = "1")]
pub say: ::prost::alloc::string::String,
}
#[derive(derive_builder::Builder)]
#[derive(custom_proto::Output)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Response {
#[prost(string, tag = "1")]
pub say: ::prost::alloc::string::String,
Expand Down
43 changes: 43 additions & 0 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,47 @@ impl MessageGraph {
)
}
}

/// Returns `true` if this message can automatically derive Eq trait.
pub fn can_message_derive_eq(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);

let msg = self.messages.get(fq_message_name).unwrap();
msg.field
.iter()
.all(|field| self.can_field_derive_eq(fq_message_name, field))
}

/// Returns `true` if the type of this field allows deriving the Eq trait.
pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.r#type() == Type::Message {
if field.label() == Label::Repeated
|| self.is_nested(field.type_name(), fq_message_name)
{
false
} else {
self.can_message_derive_eq(field.type_name())
}
} else {
matches!(
field.r#type(),
Type::Int32
| Type::Int64
| Type::Uint32
| Type::Uint64
| Type::Sint32
| Type::Sint64
| Type::Fixed32
| Type::Fixed64
| Type::Sfixed32
| Type::Sfixed64
| Type::Bool
| Type::Enum
| Type::String
| Type::Bytes
)
}
}
}
2 changes: 1 addition & 1 deletion prost-types/src/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is @generated by prost-build.
/// The version number of protocol compiler.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Version {
#[prost(int32, optional, tag = "1")]
pub major: ::core::option::Option<i32>,
Expand Down
8 changes: 0 additions & 8 deletions prost-types/src/duration.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
use super::*;

#[cfg(feature = "std")]
impl std::hash::Hash for Duration {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.seconds.hash(state);
self.nanos.hash(state);
}
}

impl Duration {
/// Normalizes the duration to a canonical format.
///
Expand Down
24 changes: 12 additions & 12 deletions prost-types/src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub mod descriptor_proto {
/// Range of reserved tag numbers. Reserved tag numbers may not be used by
/// fields or extension ranges in the same message. Reserved ranges may
/// not overlap.
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct ReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -350,7 +350,7 @@ pub mod enum_descriptor_proto {
/// Note that this is distinct from DescriptorProto.ReservedRange in that it
/// is inclusive such that it can appropriately represent the entire int32
/// domain.
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct EnumReservedRange {
/// Inclusive.
#[prost(int32, optional, tag = "1")]
Expand Down Expand Up @@ -961,7 +961,7 @@ pub mod uninterpreted_option {
/// extension (denoted with parentheses in options specs in .proto files).
/// E.g.,{ \["foo", false\], \["bar.baz", true\], \["qux", false\] } represents
/// "foo.(bar.baz).qux".
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct NamePart {
#[prost(string, required, tag = "1")]
pub name_part: ::prost::alloc::string::String,
Expand Down Expand Up @@ -1022,7 +1022,7 @@ pub struct SourceCodeInfo {
}
/// Nested message and enum types in `SourceCodeInfo`.
pub mod source_code_info {
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Location {
/// Identifies which part of the FileDescriptorProto was defined at this
/// location.
Expand Down Expand Up @@ -1125,7 +1125,7 @@ pub struct GeneratedCodeInfo {
}
/// Nested message and enum types in `GeneratedCodeInfo`.
pub mod generated_code_info {
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Annotation {
/// Identifies the element in the original source .proto file. This field
/// is formatted the same as SourceCodeInfo.Location.path.
Expand Down Expand Up @@ -1238,7 +1238,7 @@ pub mod generated_code_info {
/// "value": "1.212s"
/// }
/// ```
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Any {
/// A URL/resource name that uniquely identifies the type of the serialized
/// protocol buffer message. This string must contain at least
Expand Down Expand Up @@ -1275,7 +1275,7 @@ pub struct Any {
}
/// `SourceContext` represents information about the source of a
/// protobuf element, like the file in which it is defined.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct SourceContext {
/// The path-qualified name of the .proto file that contained the associated
/// protobuf element. For example: `"google/protobuf/source_context.proto"`.
Expand Down Expand Up @@ -1531,7 +1531,7 @@ pub struct EnumValue {
}
/// A protocol buffer option, which can be attached to a message, field,
/// enumeration, etc.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Option {
/// The option's name. For protobuf built-in options (options defined in
/// descriptor.proto), this is the short name. For example, `"map_entry"`.
Expand Down Expand Up @@ -1741,7 +1741,7 @@ pub struct Method {
/// ...
/// }
/// ```
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Mixin {
/// The fully qualified name of the interface which is included.
#[prost(string, tag = "1")]
Expand Down Expand Up @@ -1815,7 +1815,7 @@ pub struct Mixin {
/// encoded in JSON format as "3s", while 3 seconds and 1 nanosecond should
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
/// microsecond should be expressed in JSON format as "3.000001s".
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Duration {
/// Signed seconds of the span of time. Must be from -315,576,000,000
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
Expand Down Expand Up @@ -2053,7 +2053,7 @@ pub struct Duration {
/// The implementation of any API method which has a FieldMask type field in the
/// request should verify the included field paths, and return an
/// `INVALID_ARGUMENT` error if any path is unmappable.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct FieldMask {
/// The set of field mask paths.
#[prost(string, repeated, tag = "1")]
Expand Down Expand Up @@ -2249,7 +2249,7 @@ impl NullValue {
/// [`strftime`](<https://docs.python.org/2/library/time.html#time.strftime>) with
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
pub struct Timestamp {
/// Represents seconds of UTC time since Unix epoch
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to
Expand Down
13 changes: 0 additions & 13 deletions prost-types/src/timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,6 @@ impl Name for Timestamp {
}
}

/// Implements the unstable/naive version of `Eq`: a basic equality check on the internal fields of the `Timestamp`.
/// This implies that `normalized_ts != non_normalized_ts` even if `normalized_ts == non_normalized_ts.normalized()`.
#[cfg(feature = "std")]
impl Eq for Timestamp {}

#[cfg(feature = "std")]
impl std::hash::Hash for Timestamp {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.seconds.hash(state);
self.nanos.hash(state);
}
}

#[cfg(feature = "std")]
impl From<std::time::SystemTime> for Timestamp {
fn from(system_time: std::time::SystemTime) -> Timestamp {
Expand Down
2 changes: 1 addition & 1 deletion tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn main() {
config.type_attribute("Foo.Custom.Attrs.AnotherEnum", "/// Oneof docs");
config.type_attribute(
"Foo.Custom.OneOfAttrs.Msg.field",
"#[derive(Eq, PartialOrd, Ord)]",
"#[derive(PartialOrd, Ord)]",
);
config.field_attribute("Foo.Custom.Attrs.AnotherEnum.C", "/// The C docs");
config.field_attribute("Foo.Custom.Attrs.AnotherEnum.D", "/// The D docs");
Expand Down
2 changes: 1 addition & 1 deletion tests/single-include/src/outdir/outdir.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// This file is @generated by prost-build.
#[derive(Clone, PartialEq, ::prost::Message)]
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
pub struct OutdirRequest {
#[prost(string, tag = "1")]
pub query: ::prost::alloc::string::String,
Expand Down