Skip to content

Commit

Permalink
Make serialization correctly be fallible in the Rust Protobuf API.
Browse files Browse the repository at this point in the history
This doesn't _actually_ make the C++ Kernel path ever fail yet, but now that the API is a Result<SerializedData, SerializeError> it can be fixed as an implementation detail.

PiperOrigin-RevId: 638329136
  • Loading branch information
protobuf-github-bot authored and copybara-github committed May 29, 2024
1 parent 18da465 commit fdc7f65
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 36 deletions.
12 changes: 6 additions & 6 deletions conformance/conformance_rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn read_request_from_stdin() -> Option<ConformanceRequest> {
}

fn write_response_to_stdout(resp: &ConformanceResponse) {
let bytes = resp.serialize();
let bytes = resp.serialize().unwrap();
let len = bytes.len() as u32;
let mut handle = io::stdout();
handle.write_all(&len.to_le_bytes()).unwrap();
Expand Down Expand Up @@ -76,39 +76,39 @@ fn do_test(req: &ConformanceRequest) -> ConformanceResponse {
let serialized = match message_type.as_bytes() {
b"protobuf_test_messages.proto2.TestAllTypesProto2" => {
if let Ok(msg) = TestAllTypesProto2::parse(bytes) {
msg.serialize()
msg.serialize().unwrap()
} else {
resp.set_parse_error("failed to parse bytes");
return resp;
}
}
b"protobuf_test_messages.proto3.TestAllTypesProto3" => {
if let Ok(msg) = TestAllTypesProto3::parse(bytes) {
msg.serialize()
msg.serialize().unwrap()
} else {
resp.set_parse_error("failed to parse bytes");
return resp;
}
}
b"protobuf_test_messages.editions.TestAllTypesEdition2023" => {
if let Ok(msg) = TestAllTypesEdition2023::parse(bytes) {
msg.serialize()
msg.serialize().unwrap()
} else {
resp.set_parse_error("failed to parse bytes");
return resp;
}
}
b"protobuf_test_messages.editions.proto2.TestAllTypesProto2" => {
if let Ok(msg) = EditionsTestAllTypesProto2::parse(bytes) {
msg.serialize()
msg.serialize().unwrap()
} else {
resp.set_parse_error("failed to parse bytes");
return resp;
}
}
b"protobuf_test_messages.editions.proto3.TestAllTypesProto3" => {
if let Ok(msg) = EditionsTestAllTypesProto3::parse(bytes) {
msg.serialize()
msg.serialize().unwrap()
} else {
resp.set_parse_error("failed to parse bytes");
return resp;
Expand Down
16 changes: 14 additions & 2 deletions rust/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub mod __public {
ProxiedInRepeated, Repeated, RepeatedIter, RepeatedMut, RepeatedView,
};
pub use crate::string::ProtoStr;
pub use crate::ParseError;
pub use crate::{ParseError, SerializeError};
}
pub use __public::*;

Expand Down Expand Up @@ -61,7 +61,7 @@ mod proxied;
mod repeated;
mod string;

/// An error that happened during deserialization.
/// An error that happened during parsing.
#[derive(Debug, Clone)]
pub struct ParseError;

Expand All @@ -72,3 +72,15 @@ impl fmt::Display for ParseError {
write!(f, "Couldn't deserialize given bytes into a proto")
}
}

/// An error that happened during serialization.
#[derive(Debug, Clone)]
pub struct SerializeError;

impl std::error::Error for SerializeError {}

impl fmt::Display for SerializeError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Couldn't serialize proto into bytes (depth too deep or missing required fields)")
}
}
8 changes: 4 additions & 4 deletions rust/test/cpp/interop/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fn deserialize_in_cpp() {
let mut msg1 = TestAllTypes::new();
msg1.set_optional_int64(-1);
msg1.set_optional_bytes(b"some cool data I guess");
let data = msg1.serialize();
let data = msg1.serialize().unwrap();

let msg2 = unsafe {
TestAllTypes::__unstable_wrap_cpp_grant_permission_to_break(DeserializeTestAllTypes(
Expand All @@ -93,7 +93,7 @@ fn deserialize_in_cpp_into_mut() {
let mut msg1 = TestAllTypes::new();
msg1.set_optional_int64(-1);
msg1.set_optional_bytes(b"some cool data I guess");
let data = msg1.serialize();
let data = msg1.serialize().unwrap();

let mut raw_msg = unsafe { DeserializeTestAllTypes((*data).as_ptr(), data.len()) };
let msg2 = TestAllTypesMut::__unstable_wrap_cpp_grant_permission_to_break(&mut raw_msg);
Expand All @@ -111,7 +111,7 @@ fn deserialize_in_cpp_into_view() {
let mut msg1 = TestAllTypes::new();
msg1.set_optional_int64(-1);
msg1.set_optional_bytes(b"some cool data I guess");
let data = msg1.serialize();
let data = msg1.serialize().unwrap();

let raw_msg = unsafe { DeserializeTestAllTypes((*data).as_ptr(), data.len()) };
let msg2 = TestAllTypesView::__unstable_wrap_cpp_grant_permission_to_break(&raw_msg);
Expand All @@ -131,7 +131,7 @@ fn smuggle_extension() {
let msg1 = unsafe {
TestAllExtensions::__unstable_wrap_cpp_grant_permission_to_break(NewWithExtension())
};
let data = msg1.serialize();
let data = msg1.serialize().unwrap();

let mut msg2 = TestAllExtensions::parse(&data).unwrap();
let bytes = unsafe {
Expand Down
4 changes: 2 additions & 2 deletions rust/test/shared/child_parent_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ fn test_canonical_types() {

#[test]
fn test_parent_serialization() {
assert_that!(*parent_rust_proto::Parent::new().serialize(), empty());
assert_that!(*parent_rust_proto::Parent::new().serialize().unwrap(), empty());
}

#[test]
fn test_child_serialization() {
assert_that!(*child_rust_proto::Child::new().serialize(), empty());
assert_that!(*child_rust_proto::Child::new().serialize().unwrap(), empty());
}
14 changes: 7 additions & 7 deletions rust/test/shared/serialization_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ use unittest_rust_proto::TestAllTypes;
fn serialize_zero_length() {
let mut msg = TestAllTypes::new();

let serialized = msg.serialize();
let serialized = msg.serialize().unwrap();
assert_that!(serialized.len(), eq(0));

let serialized = msg.as_view().serialize();
let serialized = msg.as_view().serialize().unwrap();
assert_that!(serialized.len(), eq(0));

let serialized = msg.as_mut().serialize();
let serialized = msg.as_mut().serialize().unwrap();
assert_that!(serialized.len(), eq(0));
}

Expand All @@ -29,7 +29,7 @@ fn serialize_deserialize_message() {
msg.set_optional_bool(true);
msg.set_optional_bytes(b"serialize deserialize test");

let serialized = msg.serialize();
let serialized = msg.serialize().unwrap();

let msg2 = TestAllTypes::parse(&serialized).unwrap();
assert_that!(msg.optional_int64(), eq(msg2.optional_int64()));
Expand All @@ -53,8 +53,8 @@ fn set_bytes_with_serialized_data() {
msg.set_optional_int64(42);
msg.set_optional_bool(true);
let mut msg2 = TestAllTypes::new();
msg2.set_optional_bytes(msg.serialize());
assert_that!(msg2.optional_bytes(), eq(msg.serialize().as_ref()));
msg2.set_optional_bytes(msg.serialize().unwrap());
assert_that!(msg2.optional_bytes(), eq(msg.serialize().unwrap().as_ref()));
}

#[test]
Expand All @@ -64,7 +64,7 @@ fn deserialize_on_previously_allocated_message() {
msg.set_optional_bool(true);
msg.set_optional_bytes(b"serialize deserialize test");

let serialized = msg.serialize();
let serialized = msg.serialize().unwrap();

let mut msg2 = Box::new(TestAllTypes::new());
assert!(msg2.clear_and_parse(&serialized).is_ok());
Expand Down
9 changes: 3 additions & 6 deletions rust/test/shared/utf8/utf8_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ fn test_proto2() {
assert_that!(msg.my_field().as_bytes(), eq(NON_UTF8_BYTES));

// No error on serialization
// TODO: Add test assertion once serialize becomes fallible.
let serialized_nonutf8 = msg.serialize();
let serialized_nonutf8 = msg.serialize().expect("serialization should not fail");

// No error on parsing.
let parsed_result = NoFeaturesProto2::parse(&serialized_nonutf8);
Expand All @@ -64,8 +63,7 @@ fn test_proto3() {
assert_that!(msg.my_field().as_bytes(), eq(NON_UTF8_BYTES));

// No error on serialization
// TODO: Add test assertion once serialize becomes fallible.
let serialized_nonutf8 = msg.serialize();
let serialized_nonutf8 = msg.serialize().expect("serialization should not fail");

// Error on parsing.
let parsed_result = NoFeaturesProto3::parse(&serialized_nonutf8);
Expand All @@ -83,8 +81,7 @@ fn test_verify() {
assert_that!(msg.my_field().as_bytes(), eq(NON_UTF8_BYTES));

// No error on serialization
// TODO: Add test assertion once serialize becomes fallible.
let serialized_nonutf8 = msg.serialize();
let serialized_nonutf8 = msg.serialize().expect("serialization should not fail");

// Error on parsing.
let parsed_result = Verify::parse(&serialized_nonutf8);
Expand Down
17 changes: 8 additions & 9 deletions src/google/protobuf/compiler/rust/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
ctx.Emit({{"serialize_thunk", ThunkName(ctx, msg, "serialize")}}, R"rs(
unsafe { $serialize_thunk$(self.raw_msg()) }
//~ TODO: This should be fallible.
Ok(unsafe { $serialize_thunk$(self.raw_msg()) })
)rs");
return;

Expand All @@ -74,11 +75,9 @@ void MessageSerialize(Context& ctx, const Descriptor& msg) {
let encoded = unsafe {
$pbr$::wire::encode(self.raw_msg(), mini_table)
};
//~ TODO: Currently serialize() on the Rust API is an
//~ infallible fn, so if upb signals an error here we can only panic.
let serialized = encoded.expect("serialize is not allowed to fail");
serialized
//~ TODO: This discards the info we have about the reason
//~ of the failure, we should try to keep it instead.
encoded.map_err(|_| $pb$::SerializeError)
)rs");
return;
}
Expand Down Expand Up @@ -933,7 +932,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
self.msg
}
pub fn serialize(&self) -> $pbr$::SerializedData {
pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> {
$Msg::serialize$
}
Expand Down Expand Up @@ -1009,7 +1008,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
self.inner
}
pub fn serialize(&self) -> $pbr$::SerializedData {
pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> {
$pb$::ViewProxy::as_view(self).serialize()
}
Expand Down Expand Up @@ -1063,7 +1062,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
$raw_arena_getter_for_message$
pub fn serialize(&self) -> $pbr$::SerializedData {
pub fn serialize(&self) -> Result<$pbr$::SerializedData, $pb$::SerializeError> {
self.as_view().serialize()
}
#[deprecated = "Prefer Msg::parse(), or use the new name 'clear_and_parse' to parse into a pre-existing message."]
Expand Down

0 comments on commit fdc7f65

Please sign in to comment.