diff --git a/conformance/conformance_rust.rs b/conformance/conformance_rust.rs index 122587b610b21..adfe8578ab830 100644 --- a/conformance/conformance_rust.rs +++ b/conformance/conformance_rust.rs @@ -48,7 +48,7 @@ fn read_request_from_stdin() -> Option { } fn write_response_to_stdout(resp: &ConformanceResponse) { - let bytes = resp.serialize(); + let bytes = resp.serialize().unwrap(); let len = bytes.len() as u32; let mut handle = io::stdout(); handle.write_all(&len.to_le_bytes()).unwrap(); @@ -76,7 +76,7 @@ fn do_test(req: &ConformanceRequest) -> ConformanceResponse { let serialized = match message_type.as_bytes() { b"protobuf_test_messages.proto2.TestAllTypesProto2" => { if let Ok(msg) = TestAllTypesProto2::parse(bytes) { - msg.serialize() + msg.serialize().unwrap() } else { resp.set_parse_error("failed to parse bytes"); return resp; @@ -84,7 +84,7 @@ fn do_test(req: &ConformanceRequest) -> ConformanceResponse { } b"protobuf_test_messages.proto3.TestAllTypesProto3" => { if let Ok(msg) = TestAllTypesProto3::parse(bytes) { - msg.serialize() + msg.serialize().unwrap() } else { resp.set_parse_error("failed to parse bytes"); return resp; @@ -92,7 +92,7 @@ fn do_test(req: &ConformanceRequest) -> ConformanceResponse { } b"protobuf_test_messages.editions.TestAllTypesEdition2023" => { if let Ok(msg) = TestAllTypesEdition2023::parse(bytes) { - msg.serialize() + msg.serialize().unwrap() } else { resp.set_parse_error("failed to parse bytes"); return resp; @@ -100,7 +100,7 @@ fn do_test(req: &ConformanceRequest) -> ConformanceResponse { } b"protobuf_test_messages.editions.proto2.TestAllTypesProto2" => { if let Ok(msg) = EditionsTestAllTypesProto2::parse(bytes) { - msg.serialize() + msg.serialize().unwrap() } else { resp.set_parse_error("failed to parse bytes"); return resp; @@ -108,7 +108,7 @@ fn do_test(req: &ConformanceRequest) -> ConformanceResponse { } b"protobuf_test_messages.editions.proto3.TestAllTypesProto3" => { if let Ok(msg) = EditionsTestAllTypesProto3::parse(bytes) { - msg.serialize() + msg.serialize().unwrap() } else { resp.set_parse_error("failed to parse bytes"); return resp; diff --git a/rust/shared.rs b/rust/shared.rs index 9c839fa47a28f..1a0aeef91d1a3 100644 --- a/rust/shared.rs +++ b/rust/shared.rs @@ -31,7 +31,7 @@ pub mod __public { ProxiedInRepeated, Repeated, RepeatedIter, RepeatedMut, RepeatedView, }; pub use crate::string::ProtoStr; - pub use crate::ParseError; + pub use crate::{ParseError, SerializeError}; } pub use __public::*; @@ -61,7 +61,7 @@ mod proxied; mod repeated; mod string; -/// An error that happened during deserialization. +/// An error that happened during parsing. #[derive(Debug, Clone)] pub struct ParseError; @@ -72,3 +72,15 @@ impl fmt::Display for ParseError { write!(f, "Couldn't deserialize given bytes into a proto") } } + +/// An error that happened during serialization. +#[derive(Debug, Clone)] +pub struct SerializeError; + +impl std::error::Error for SerializeError {} + +impl fmt::Display for SerializeError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Couldn't serialize proto into bytes (depth too deep or missing required fields)") + } +} diff --git a/rust/test/cpp/interop/main.rs b/rust/test/cpp/interop/main.rs index 36998257174a4..812dab3a59bb4 100644 --- a/rust/test/cpp/interop/main.rs +++ b/rust/test/cpp/interop/main.rs @@ -76,7 +76,7 @@ fn deserialize_in_cpp() { let mut msg1 = TestAllTypes::new(); msg1.set_optional_int64(-1); msg1.set_optional_bytes(b"some cool data I guess"); - let data = msg1.serialize(); + let data = msg1.serialize().unwrap(); let msg2 = unsafe { TestAllTypes::__unstable_wrap_cpp_grant_permission_to_break(DeserializeTestAllTypes( @@ -93,7 +93,7 @@ fn deserialize_in_cpp_into_mut() { let mut msg1 = TestAllTypes::new(); msg1.set_optional_int64(-1); msg1.set_optional_bytes(b"some cool data I guess"); - let data = msg1.serialize(); + let data = msg1.serialize().unwrap(); let mut raw_msg = unsafe { DeserializeTestAllTypes((*data).as_ptr(), data.len()) }; let msg2 = TestAllTypesMut::__unstable_wrap_cpp_grant_permission_to_break(&mut raw_msg); @@ -111,7 +111,7 @@ fn deserialize_in_cpp_into_view() { let mut msg1 = TestAllTypes::new(); msg1.set_optional_int64(-1); msg1.set_optional_bytes(b"some cool data I guess"); - let data = msg1.serialize(); + let data = msg1.serialize().unwrap(); let raw_msg = unsafe { DeserializeTestAllTypes((*data).as_ptr(), data.len()) }; let msg2 = TestAllTypesView::__unstable_wrap_cpp_grant_permission_to_break(&raw_msg); @@ -131,7 +131,7 @@ fn smuggle_extension() { let msg1 = unsafe { TestAllExtensions::__unstable_wrap_cpp_grant_permission_to_break(NewWithExtension()) }; - let data = msg1.serialize(); + let data = msg1.serialize().unwrap(); let mut msg2 = TestAllExtensions::parse(&data).unwrap(); let bytes = unsafe { diff --git a/rust/test/shared/child_parent_test.rs b/rust/test/shared/child_parent_test.rs index 80b659dc730cf..1a269ec0521d4 100644 --- a/rust/test/shared/child_parent_test.rs +++ b/rust/test/shared/child_parent_test.rs @@ -18,10 +18,10 @@ fn test_canonical_types() { #[test] fn test_parent_serialization() { - assert_that!(*parent_rust_proto::Parent::new().serialize(), empty()); + assert_that!(*parent_rust_proto::Parent::new().serialize().unwrap(), empty()); } #[test] fn test_child_serialization() { - assert_that!(*child_rust_proto::Child::new().serialize(), empty()); + assert_that!(*child_rust_proto::Child::new().serialize().unwrap(), empty()); } diff --git a/rust/test/shared/serialization_test.rs b/rust/test/shared/serialization_test.rs index 436823bcbd288..4acd57ea5720b 100644 --- a/rust/test/shared/serialization_test.rs +++ b/rust/test/shared/serialization_test.rs @@ -12,13 +12,13 @@ use unittest_rust_proto::TestAllTypes; fn serialize_zero_length() { let mut msg = TestAllTypes::new(); - let serialized = msg.serialize(); + let serialized = msg.serialize().unwrap(); assert_that!(serialized.len(), eq(0)); - let serialized = msg.as_view().serialize(); + let serialized = msg.as_view().serialize().unwrap(); assert_that!(serialized.len(), eq(0)); - let serialized = msg.as_mut().serialize(); + let serialized = msg.as_mut().serialize().unwrap(); assert_that!(serialized.len(), eq(0)); } @@ -29,7 +29,7 @@ fn serialize_deserialize_message() { msg.set_optional_bool(true); msg.set_optional_bytes(b"serialize deserialize test"); - let serialized = msg.serialize(); + let serialized = msg.serialize().unwrap(); let msg2 = TestAllTypes::parse(&serialized).unwrap(); assert_that!(msg.optional_int64(), eq(msg2.optional_int64())); @@ -53,8 +53,8 @@ fn set_bytes_with_serialized_data() { msg.set_optional_int64(42); msg.set_optional_bool(true); let mut msg2 = TestAllTypes::new(); - msg2.set_optional_bytes(msg.serialize()); - assert_that!(msg2.optional_bytes(), eq(msg.serialize().as_ref())); + msg2.set_optional_bytes(msg.serialize().unwrap()); + assert_that!(msg2.optional_bytes(), eq(msg.serialize().unwrap().as_ref())); } #[test] @@ -64,7 +64,7 @@ fn deserialize_on_previously_allocated_message() { msg.set_optional_bool(true); msg.set_optional_bytes(b"serialize deserialize test"); - let serialized = msg.serialize(); + let serialized = msg.serialize().unwrap(); let mut msg2 = Box::new(TestAllTypes::new()); assert!(msg2.clear_and_parse(&serialized).is_ok()); diff --git a/rust/test/shared/utf8/utf8_test.rs b/rust/test/shared/utf8/utf8_test.rs index ed5da95960048..8b46cf86ddd4d 100644 --- a/rust/test/shared/utf8/utf8_test.rs +++ b/rust/test/shared/utf8/utf8_test.rs @@ -45,8 +45,7 @@ fn test_proto2() { assert_that!(msg.my_field().as_bytes(), eq(NON_UTF8_BYTES)); // No error on serialization - // TODO: Add test assertion once serialize becomes fallible. - let serialized_nonutf8 = msg.serialize(); + let serialized_nonutf8 = msg.serialize().expect("serialization should not fail"); // No error on parsing. let parsed_result = NoFeaturesProto2::parse(&serialized_nonutf8); @@ -64,8 +63,7 @@ fn test_proto3() { assert_that!(msg.my_field().as_bytes(), eq(NON_UTF8_BYTES)); // No error on serialization - // TODO: Add test assertion once serialize becomes fallible. - let serialized_nonutf8 = msg.serialize(); + let serialized_nonutf8 = msg.serialize().expect("serialization should not fail"); // Error on parsing. let parsed_result = NoFeaturesProto3::parse(&serialized_nonutf8); @@ -83,8 +81,7 @@ fn test_verify() { assert_that!(msg.my_field().as_bytes(), eq(NON_UTF8_BYTES)); // No error on serialization - // TODO: Add test assertion once serialize becomes fallible. - let serialized_nonutf8 = msg.serialize(); + let serialized_nonutf8 = msg.serialize().expect("serialization should not fail"); // Error on parsing. let parsed_result = Verify::parse(&serialized_nonutf8); diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 4c6e8a5fbad7e..d7443d03f0c32 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -61,7 +61,8 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: ctx.Emit({{"serialize_thunk", ThunkName(ctx, msg, "serialize")}}, R"rs( - unsafe { $serialize_thunk$(self.raw_msg()) } + //~ TODO: This should be fallible. + Ok(unsafe { $serialize_thunk$(self.raw_msg()) }) )rs"); return; @@ -74,11 +75,9 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) { let encoded = unsafe { $pbr$::wire::encode(self.raw_msg(), mini_table) }; - - //~ TODO: Currently serialize() on the Rust API is an - //~ infallible fn, so if upb signals an error here we can only panic. - let serialized = encoded.expect("serialize is not allowed to fail"); - serialized + //~ TODO: This discards the info we have about the reason + //~ of the failure, we should try to keep it instead. + encoded.map_err(|_| $pb$::SerializeError) )rs"); return; } @@ -933,7 +932,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { self.msg } - pub fn serialize(&self) -> $pbr$::SerializedData { + pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> { $Msg::serialize$ } @@ -1009,7 +1008,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { self.inner } - pub fn serialize(&self) -> $pbr$::SerializedData { + pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> { $pb$::ViewProxy::as_view(self).serialize() } @@ -1063,7 +1062,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { $raw_arena_getter_for_message$ - pub fn serialize(&self) -> $pbr$::SerializedData { + pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> { self.as_view().serialize() } #[deprecated = "Prefer Msg::parse(), or use the new name 'clear_and_parse' to parse into a pre-existing message."]