Skip to content

Commit

Permalink
Merge pull request #10545 from deannagarcia/21.x
Browse files Browse the repository at this point in the history
Apply patch
  • Loading branch information
deannagarcia authored Sep 13, 2022
2 parents ea2f204 + cd0ee8f commit d88266c
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 35 deletions.
27 changes: 18 additions & 9 deletions src/google/protobuf/extension_set_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata,
internal::ParseContext* ctx) {
std::string payload;
uint32_t type_id = 0;
bool payload_read = false;
uint32_t type_id;
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (!ctx->Done(&ptr)) {
uint32_t tag = static_cast<uint8_t>(*ptr++);
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
ExtensionInfo extension;
bool was_packed_on_wire;
if (!FindExtension(2, type_id, extendee, ctx, &extension,
Expand All @@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id != 0) {
if (state == State::kHasType) {
ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
extendee, metadata, ctx);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
type_id = 0;
state = State::kDone;
} else {
std::string tmp;
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
ptr = ctx->ReadString(ptr, size, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
if (state == State::kNoTag) {
payload = std::move(tmp);
state = State::kHasPayload;
}
}
} else {
ptr = ReadTag(ptr - 1, &tag);
Expand Down
26 changes: 18 additions & 8 deletions src/google/protobuf/wire_format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
// Parse a MessageSetItem
auto metadata = reflection->MutableInternalMetadata(msg);
enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

std::string payload;
uint32_t type_id = 0;
bool payload_read = false;
while (!ctx->Done(&ptr)) {
// We use 64 bit tags in order to allow typeid's that span the whole
// range of 32 bit numbers.
Expand All @@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
uint64_t tmp;
ptr = ParseBigVarint(ptr, &tmp);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
type_id = tmp;
if (payload_read) {
if (state == State::kNoTag) {
type_id = tmp;
state = State::kHasType;
} else if (state == State::kHasPayload) {
type_id = tmp;
const FieldDescriptor* field;
if (ctx->data().pool == nullptr) {
field = reflection->FindKnownExtensionByNumber(type_id);
Expand All @@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
tmp_ctx.EndedAtLimit());
}
type_id = 0;
state = State::kDone;
}
continue;
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
if (type_id == 0) {
if (state == State::kNoTag) {
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->ReadString(ptr, size, &payload);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
payload_read = true;
} else {
state = State::kHasPayload;
} else if (state == State::kHasType) {
// We're now parsing the payload
const FieldDescriptor* field = nullptr;
if (descriptor->IsExtensionNumber(type_id)) {
Expand All @@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
ptr = WireFormat::_InternalParseAndMergeField(
msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
field);
type_id = 0;
state = State::kDone;
} else {
int32_t size = ReadSize(&ptr);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
ptr = ctx->Skip(ptr, size);
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
}
} else {
// An unknown field in MessageSetItem.
Expand Down
27 changes: 18 additions & 9 deletions src/google/protobuf/wire_format_lite.h
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
// we can parse it later.
std::string message_data;

enum class State { kNoTag, kHasType, kHasPayload, kDone };
State state = State::kNoTag;

while (true) {
const uint32_t tag = input->ReadTagNoLastTag();
if (tag == 0) return false;
Expand All @@ -1839,26 +1842,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
case WireFormatLite::kMessageSetTypeIdTag: {
uint32_t type_id;
if (!input->ReadVarint32(&type_id)) return false;
last_type_id = type_id;

if (!message_data.empty()) {
if (state == State::kNoTag) {
last_type_id = type_id;
state = State::kHasType;
} else if (state == State::kHasPayload) {
// We saw some message data before the type_id. Have to parse it
// now.
io::CodedInputStream sub_input(
reinterpret_cast<const uint8_t*>(message_data.data()),
static_cast<int>(message_data.size()));
sub_input.SetRecursionLimit(input->RecursionBudget());
if (!ms.ParseField(last_type_id, &sub_input)) {
if (!ms.ParseField(type_id, &sub_input)) {
return false;
}
message_data.clear();
state = State::kDone;
}

break;
}

case WireFormatLite::kMessageSetMessageTag: {
if (last_type_id == 0) {
if (state == State::kHasType) {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
state = State::kDone;
} else if (state == State::kNoTag) {
// We haven't seen a type_id yet. Append this data to message_data.
uint32_t length;
if (!input->ReadVarint32(&length)) return false;
Expand All @@ -1869,11 +1880,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]);
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
if (!input->ReadRaw(ptr, length)) return false;
state = State::kHasPayload;
} else {
// Already saw type_id, so we can parse this directly.
if (!ms.ParseField(last_type_id, input)) {
return false;
}
if (!ms.SkipField(tag, input)) return false;
}

break;
Expand Down
104 changes: 95 additions & 9 deletions src/google/protobuf/wire_format_unittest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -581,28 +581,54 @@ TEST(WireFormatTest, ParseMessageSet) {
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
}

TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
namespace {
std::string BuildMessageSetItemStart() {
std::string data;
{
UNITTEST::TestMessageSetExtension1 message;
message.set_i(123);
// Build a MessageSet manually with its message content put before its
// type_id.
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
}
return data;
}
std::string BuildMessageSetItemEnd() {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
}
return data;
}
std::string BuildMessageSetTestExtension1(int value = 123) {
std::string data;
{
UNITTEST::TestMessageSetExtension1 message;
message.set_i(value);
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
// Write the message content first.
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
&coded_output);
coded_output.WriteVarint32(message.ByteSizeLong());
message.SerializeWithCachedSizes(&coded_output);
// Write the type id.
uint32_t type_id = message.GetDescriptor()->extension(0)->number();
}
return data;
}
std::string BuildMessageSetItemTypeId(int extension_number) {
std::string data;
{
io::StringOutputStream output_stream(&data);
io::CodedOutputStream coded_output(&output_stream);
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
type_id, &coded_output);
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
extension_number, &coded_output);
}
return data;
}
void ValidateTestMessageSet(const std::string& test_case,
const std::string& data) {
SCOPED_TRACE(test_case);
{
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set;
ASSERT_TRUE(message_set.ParseFromString(data));
Expand All @@ -612,6 +638,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
.GetExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension)
.i());

// Make sure it does not contain anything else.
message_set.ClearExtension(
UNITTEST::TestMessageSetExtension1::message_set_extension);
EXPECT_EQ(message_set.SerializeAsString(), "");
}
{
// Test parse the message via Reflection.
Expand All @@ -627,6 +658,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
UNITTEST::TestMessageSetExtension1::message_set_extension)
.i());
}
{
// Test parse the message via DynamicMessage.
DynamicMessageFactory factory;
std::unique_ptr<Message> msg(
factory
.GetPrototype(
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor())
->New());
msg->ParseFromString(data);
auto* reflection = msg->GetReflection();
std::vector<const FieldDescriptor*> fields;
reflection->ListFields(*msg, &fields);
ASSERT_EQ(fields.size(), 1);
const auto& sub = reflection->GetMessage(*msg, fields[0]);
reflection = sub.GetReflection();
EXPECT_EQ(123, reflection->GetInt32(
sub, sub.GetDescriptor()->FindFieldByName("i")));
}
}
} // namespace

TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
std::string start = BuildMessageSetItemStart();
std::string end = BuildMessageSetItemEnd();
std::string id = BuildMessageSetItemTypeId(
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
std::string message = BuildMessageSetTestExtension1();

ValidateTestMessageSet("id + message", start + id + message + end);
ValidateTestMessageSet("message + id", start + message + id + end);
}

TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
std::string start = BuildMessageSetItemStart();
std::string end = BuildMessageSetItemEnd();
std::string id = BuildMessageSetItemTypeId(
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
std::string other_id = BuildMessageSetItemTypeId(123456);
std::string message = BuildMessageSetTestExtension1();
std::string other_message = BuildMessageSetTestExtension1(321);

// Double id
ValidateTestMessageSet("id + other_id + message",
start + id + other_id + message + end);
ValidateTestMessageSet("id + message + other_id",
start + id + message + other_id + end);
ValidateTestMessageSet("message + id + other_id",
start + message + id + other_id + end);
// Double message
ValidateTestMessageSet("id + message + other_message",
start + id + message + other_message + end);
ValidateTestMessageSet("message + id + other_message",
start + message + id + other_message + end);
ValidateTestMessageSet("message + other_message + id",
start + message + other_message + id + end);
}

void SerializeReverseOrder(
Expand Down

0 comments on commit d88266c

Please sign in to comment.