Skip to content

Commit f43ac49

Browse files
authored
Merge pull request #10542 from deannagarcia/3.18.x
Apply patch
2 parents 5b37c91 + 9efdf55 commit f43ac49

File tree

4 files changed

+150
-35
lines changed

4 files changed

+150
-35
lines changed

Diff for: src/google/protobuf/extension_set_inl.h

+18-9
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,21 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
206206
const char* ptr, const Msg* extendee, internal::InternalMetadata* metadata,
207207
internal::ParseContext* ctx) {
208208
std::string payload;
209-
uint32_t type_id = 0;
210-
bool payload_read = false;
209+
uint32_t type_id;
210+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
211+
State state = State::kNoTag;
212+
211213
while (!ctx->Done(&ptr)) {
212214
uint32_t tag = static_cast<uint8_t>(*ptr++);
213215
if (tag == WireFormatLite::kMessageSetTypeIdTag) {
214216
uint64_t tmp;
215217
ptr = ParseBigVarint(ptr, &tmp);
216218
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
217-
type_id = tmp;
218-
if (payload_read) {
219+
if (state == State::kNoTag) {
220+
type_id = tmp;
221+
state = State::kHasType;
222+
} else if (state == State::kHasPayload) {
223+
type_id = tmp;
219224
ExtensionInfo extension;
220225
bool was_packed_on_wire;
221226
if (!FindExtension(2, type_id, extendee, ctx, &extension,
@@ -241,20 +246,24 @@ const char* ExtensionSet::ParseMessageSetItemTmpl(
241246
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
242247
tmp_ctx.EndedAtLimit());
243248
}
244-
type_id = 0;
249+
state = State::kDone;
245250
}
246251
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
247-
if (type_id != 0) {
252+
if (state == State::kHasType) {
248253
ptr = ParseFieldMaybeLazily(static_cast<uint64_t>(type_id) * 8 + 2, ptr,
249254
extendee, metadata, ctx);
250255
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
251-
type_id = 0;
256+
state = State::kDone;
252257
} else {
258+
std::string tmp;
253259
int32_t size = ReadSize(&ptr);
254260
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
255-
ptr = ctx->ReadString(ptr, size, &payload);
261+
ptr = ctx->ReadString(ptr, size, &tmp);
256262
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
257-
payload_read = true;
263+
if (state == State::kNoTag) {
264+
payload = std::move(tmp);
265+
state = State::kHasPayload;
266+
}
258267
}
259268
} else {
260269
ptr = ReadTag(ptr - 1, &tag);

Diff for: src/google/protobuf/wire_format.cc

+18-8
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,11 @@ struct WireFormat::MessageSetParser {
657657
const char* _InternalParse(const char* ptr, internal::ParseContext* ctx) {
658658
// Parse a MessageSetItem
659659
auto metadata = reflection->MutableInternalMetadata(msg);
660+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
661+
State state = State::kNoTag;
662+
660663
std::string payload;
661664
uint32_t type_id = 0;
662-
bool payload_read = false;
663665
while (!ctx->Done(&ptr)) {
664666
// We use 64 bit tags in order to allow typeid's that span the whole
665667
// range of 32 bit numbers.
@@ -668,8 +670,11 @@ struct WireFormat::MessageSetParser {
668670
uint64_t tmp;
669671
ptr = ParseBigVarint(ptr, &tmp);
670672
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
671-
type_id = tmp;
672-
if (payload_read) {
673+
if (state == State::kNoTag) {
674+
type_id = tmp;
675+
state = State::kHasType;
676+
} else if (state == State::kHasPayload) {
677+
type_id = tmp;
673678
const FieldDescriptor* field;
674679
if (ctx->data().pool == nullptr) {
675680
field = reflection->FindKnownExtensionByNumber(type_id);
@@ -696,17 +701,17 @@ struct WireFormat::MessageSetParser {
696701
GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
697702
tmp_ctx.EndedAtLimit());
698703
}
699-
type_id = 0;
704+
state = State::kDone;
700705
}
701706
continue;
702707
} else if (tag == WireFormatLite::kMessageSetMessageTag) {
703-
if (type_id == 0) {
708+
if (state == State::kNoTag) {
704709
int32_t size = ReadSize(&ptr);
705710
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
706711
ptr = ctx->ReadString(ptr, size, &payload);
707712
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
708-
payload_read = true;
709-
} else {
713+
state = State::kHasPayload;
714+
} else if (state == State::kHasType) {
710715
// We're now parsing the payload
711716
const FieldDescriptor* field = nullptr;
712717
if (descriptor->IsExtensionNumber(type_id)) {
@@ -720,7 +725,12 @@ struct WireFormat::MessageSetParser {
720725
ptr = WireFormat::_InternalParseAndMergeField(
721726
msg, ptr, ctx, static_cast<uint64_t>(type_id) * 8 + 2, reflection,
722727
field);
723-
type_id = 0;
728+
state = State::kDone;
729+
} else {
730+
int32_t size = ReadSize(&ptr);
731+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
732+
ptr = ctx->Skip(ptr, size);
733+
GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
724734
}
725735
} else {
726736
// An unknown field in MessageSetItem.

Diff for: src/google/protobuf/wire_format_lite.h

+18-9
Original file line numberDiff line numberDiff line change
@@ -1845,6 +1845,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18451845
// we can parse it later.
18461846
std::string message_data;
18471847

1848+
enum class State { kNoTag, kHasType, kHasPayload, kDone };
1849+
State state = State::kNoTag;
1850+
18481851
while (true) {
18491852
const uint32_t tag = input->ReadTagNoLastTag();
18501853
if (tag == 0) return false;
@@ -1853,26 +1856,34 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18531856
case WireFormatLite::kMessageSetTypeIdTag: {
18541857
uint32_t type_id;
18551858
if (!input->ReadVarint32(&type_id)) return false;
1856-
last_type_id = type_id;
1857-
1858-
if (!message_data.empty()) {
1859+
if (state == State::kNoTag) {
1860+
last_type_id = type_id;
1861+
state = State::kHasType;
1862+
} else if (state == State::kHasPayload) {
18591863
// We saw some message data before the type_id. Have to parse it
18601864
// now.
18611865
io::CodedInputStream sub_input(
18621866
reinterpret_cast<const uint8_t*>(message_data.data()),
18631867
static_cast<int>(message_data.size()));
18641868
sub_input.SetRecursionLimit(input->RecursionBudget());
1865-
if (!ms.ParseField(last_type_id, &sub_input)) {
1869+
if (!ms.ParseField(type_id, &sub_input)) {
18661870
return false;
18671871
}
18681872
message_data.clear();
1873+
state = State::kDone;
18691874
}
18701875

18711876
break;
18721877
}
18731878

18741879
case WireFormatLite::kMessageSetMessageTag: {
1875-
if (last_type_id == 0) {
1880+
if (state == State::kHasType) {
1881+
// Already saw type_id, so we can parse this directly.
1882+
if (!ms.ParseField(last_type_id, input)) {
1883+
return false;
1884+
}
1885+
state = State::kDone;
1886+
} else if (state == State::kNoTag) {
18761887
// We haven't seen a type_id yet. Append this data to message_data.
18771888
uint32_t length;
18781889
if (!input->ReadVarint32(&length)) return false;
@@ -1883,11 +1894,9 @@ bool ParseMessageSetItemImpl(io::CodedInputStream* input, MS ms) {
18831894
auto ptr = reinterpret_cast<uint8_t*>(&message_data[0]);
18841895
ptr = io::CodedOutputStream::WriteVarint32ToArray(length, ptr);
18851896
if (!input->ReadRaw(ptr, length)) return false;
1897+
state = State::kHasPayload;
18861898
} else {
1887-
// Already saw type_id, so we can parse this directly.
1888-
if (!ms.ParseField(last_type_id, input)) {
1889-
return false;
1890-
}
1899+
if (!ms.SkipField(tag, input)) return false;
18911900
}
18921901

18931902
break;

Diff for: src/google/protobuf/wire_format_unittest.inc

+96-9
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include <google/protobuf/stubs/casts.h>
5050
#include <google/protobuf/stubs/strutil.h>
5151
#include <google/protobuf/stubs/stl_util.h>
52+
#include <google/protobuf/dynamic_message.h>
5253

5354
// clang-format off
5455
#include <google/protobuf/port_def.inc>
@@ -580,28 +581,54 @@ TEST(WireFormatTest, ParseMessageSet) {
580581
EXPECT_EQ(message_set.DebugString(), dynamic_message_set.DebugString());
581582
}
582583

583-
TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
584+
namespace {
585+
std::string BuildMessageSetItemStart() {
584586
std::string data;
585587
{
586-
UNITTEST::TestMessageSetExtension1 message;
587-
message.set_i(123);
588-
// Build a MessageSet manually with its message content put before its
589-
// type_id.
590588
io::StringOutputStream output_stream(&data);
591589
io::CodedOutputStream coded_output(&output_stream);
592590
coded_output.WriteTag(WireFormatLite::kMessageSetItemStartTag);
591+
}
592+
return data;
593+
}
594+
std::string BuildMessageSetItemEnd() {
595+
std::string data;
596+
{
597+
io::StringOutputStream output_stream(&data);
598+
io::CodedOutputStream coded_output(&output_stream);
599+
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
600+
}
601+
return data;
602+
}
603+
std::string BuildMessageSetTestExtension1(int value = 123) {
604+
std::string data;
605+
{
606+
UNITTEST::TestMessageSetExtension1 message;
607+
message.set_i(value);
608+
io::StringOutputStream output_stream(&data);
609+
io::CodedOutputStream coded_output(&output_stream);
593610
// Write the message content first.
594611
WireFormatLite::WriteTag(WireFormatLite::kMessageSetMessageNumber,
595612
WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
596613
&coded_output);
597614
coded_output.WriteVarint32(message.ByteSizeLong());
598615
message.SerializeWithCachedSizes(&coded_output);
599-
// Write the type id.
600-
uint32 type_id = message.GetDescriptor()->extension(0)->number();
616+
}
617+
return data;
618+
}
619+
std::string BuildMessageSetItemTypeId(int extension_number) {
620+
std::string data;
621+
{
622+
io::StringOutputStream output_stream(&data);
623+
io::CodedOutputStream coded_output(&output_stream);
601624
WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
602-
type_id, &coded_output);
603-
coded_output.WriteTag(WireFormatLite::kMessageSetItemEndTag);
625+
extension_number, &coded_output);
604626
}
627+
return data;
628+
}
629+
void ValidateTestMessageSet(const std::string& test_case,
630+
const std::string& data) {
631+
SCOPED_TRACE(test_case);
605632
{
606633
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet message_set;
607634
ASSERT_TRUE(message_set.ParseFromString(data));
@@ -611,6 +638,11 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
611638
.GetExtension(
612639
UNITTEST::TestMessageSetExtension1::message_set_extension)
613640
.i());
641+
642+
// Make sure it does not contain anything else.
643+
message_set.ClearExtension(
644+
UNITTEST::TestMessageSetExtension1::message_set_extension);
645+
EXPECT_EQ(message_set.SerializeAsString(), "");
614646
}
615647
{
616648
// Test parse the message via Reflection.
@@ -626,6 +658,61 @@ TEST(WireFormatTest, ParseMessageSetWithReverseTagOrder) {
626658
UNITTEST::TestMessageSetExtension1::message_set_extension)
627659
.i());
628660
}
661+
{
662+
// Test parse the message via DynamicMessage.
663+
DynamicMessageFactory factory;
664+
std::unique_ptr<Message> msg(
665+
factory
666+
.GetPrototype(
667+
PROTO2_WIREFORMAT_UNITTEST::TestMessageSet::descriptor())
668+
->New());
669+
msg->ParseFromString(data);
670+
auto* reflection = msg->GetReflection();
671+
std::vector<const FieldDescriptor*> fields;
672+
reflection->ListFields(*msg, &fields);
673+
ASSERT_EQ(fields.size(), 1);
674+
const auto& sub = reflection->GetMessage(*msg, fields[0]);
675+
reflection = sub.GetReflection();
676+
EXPECT_EQ(123, reflection->GetInt32(
677+
sub, sub.GetDescriptor()->FindFieldByName("i")));
678+
}
679+
}
680+
} // namespace
681+
682+
TEST(WireFormatTest, ParseMessageSetWithAnyTagOrder) {
683+
std::string start = BuildMessageSetItemStart();
684+
std::string end = BuildMessageSetItemEnd();
685+
std::string id = BuildMessageSetItemTypeId(
686+
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
687+
std::string message = BuildMessageSetTestExtension1();
688+
689+
ValidateTestMessageSet("id + message", start + id + message + end);
690+
ValidateTestMessageSet("message + id", start + message + id + end);
691+
}
692+
693+
TEST(WireFormatTest, ParseMessageSetWithDuplicateTags) {
694+
std::string start = BuildMessageSetItemStart();
695+
std::string end = BuildMessageSetItemEnd();
696+
std::string id = BuildMessageSetItemTypeId(
697+
UNITTEST::TestMessageSetExtension1::descriptor()->extension(0)->number());
698+
std::string other_id = BuildMessageSetItemTypeId(123456);
699+
std::string message = BuildMessageSetTestExtension1();
700+
std::string other_message = BuildMessageSetTestExtension1(321);
701+
702+
// Double id
703+
ValidateTestMessageSet("id + other_id + message",
704+
start + id + other_id + message + end);
705+
ValidateTestMessageSet("id + message + other_id",
706+
start + id + message + other_id + end);
707+
ValidateTestMessageSet("message + id + other_id",
708+
start + message + id + other_id + end);
709+
// Double message
710+
ValidateTestMessageSet("id + message + other_message",
711+
start + id + message + other_message + end);
712+
ValidateTestMessageSet("message + id + other_message",
713+
start + message + id + other_message + end);
714+
ValidateTestMessageSet("message + other_message + id",
715+
start + message + other_message + id + end);
629716
}
630717

631718
void SerializeReverseOrder(

0 commit comments

Comments
 (0)