Skip to content

Commit

Permalink
Fix a python bug that UPB and Python C++ extension assume MessageSet …
Browse files Browse the repository at this point in the history
…extensions are ordered first

PiperOrigin-RevId: 698430014
  • Loading branch information
anandolee authored and copybara-github committed Nov 20, 2024
1 parent f5a2937 commit 3781f45
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 11 deletions.
7 changes: 7 additions & 0 deletions python/google/protobuf/internal/message_set_extensions.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ message TestMessageSet {
}

message TestMessageSetExtension1 {
extend TestMessageSet {
optional TestExtension first_extension = 2534113;
}
extend TestMessageSet {
optional TestMessageSetExtension1 message_set_extension = 98418603;
}
Expand All @@ -36,6 +39,10 @@ message TestMessageSetExtension3 {
optional string text = 35;
}

message TestExtension {
optional string str = 1;
}

extend TestMessageSet {
optional TestMessageSetExtension3 message_set_extension3 = 98418655;
}
Expand Down
17 changes: 17 additions & 0 deletions python/google/protobuf/internal/text_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,23 @@ def testPrintMessageSet(self):
' text: \"bar\"\n'
'}\n')

def testMessageSetExtensionNotFirst(self):
desc = message_set_extensions_pb2.TestMessageSetExtension1.DESCRIPTOR
self.assertEqual('first_extension', desc.extensions[0].name)
self.assertEqual('message_set_extension', desc.extensions[1].name)
message = message_set_extensions_pb2.TestMessageSet()
ext = (
message_set_extensions_pb2.TestMessageSetExtension1.message_set_extension
)
message.Extensions[ext].i = 123
expected_str = (
'[google.protobuf.internal.TestMessageSetExtension1] {\n i: 123\n}\n'
)
self.CompareToGoldenText(text_format.MessageToString(message), expected_str)
parsed = message_set_extensions_pb2.TestMessageSet()
text_format.Parse(expected_str, parsed)
self.CompareToGoldenText(text_format.MessageToString(parsed), expected_str)

def testPrintMessageSetByFieldNumber(self):
out = text_format.TextWriter(False)
message = unittest_mset_pb2.TestMessageSetContainer()
Expand Down
25 changes: 17 additions & 8 deletions python/google/protobuf/pyext/extension_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,21 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
return 0;
}

static const FieldDescriptor* FindMessageSetExtension(
const Descriptor* message_descriptor) {
for (int i = 0; i < message_descriptor->extension_count(); i++) {
const FieldDescriptor* extension = message_descriptor->extension(i);
if (extension->is_extension() &&
extension->containing_type()->options().message_set_wire_format() &&
extension->type() == FieldDescriptor::TYPE_MESSAGE &&
extension->label() == FieldDescriptor::LABEL_OPTIONAL &&
extension->message_type() == message_descriptor) {
return extension;
}
}
return nullptr;
}

PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
char* name;
Py_ssize_t name_size;
Expand All @@ -221,14 +236,8 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* arg) {
// Is is the name of a message set extension?
const Descriptor* message_descriptor =
pool->pool->FindMessageTypeByName(absl::string_view(name, name_size));
if (message_descriptor && message_descriptor->extension_count() > 0) {
const FieldDescriptor* extension = message_descriptor->extension(0);
if (extension->is_extension() &&
extension->containing_type()->options().message_set_wire_format() &&
extension->type() == FieldDescriptor::TYPE_MESSAGE &&
extension->label() == FieldDescriptor::LABEL_OPTIONAL) {
message_extension = extension;
}
if (message_descriptor) {
message_extension = FindMessageSetExtension(message_descriptor);
}
}
if (message_extension == nullptr) {
Expand Down
13 changes: 10 additions & 3 deletions upb/reflection/def_pool.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "upb/hash/str_table.h"
#include "upb/mem/alloc.h"
#include "upb/mem/arena.h"
#include "upb/reflection/def.h"
#include "upb/reflection/def_type.h"
#include "upb/reflection/file_def.h"
#include "upb/reflection/internal/def_builder.h"
Expand Down Expand Up @@ -236,9 +237,15 @@ const upb_FieldDef* upb_DefPool_FindExtensionByNameWithSize(
return _upb_DefType_Unpack(v, UPB_DEFTYPE_FIELD);
case UPB_DEFTYPE_MSG: {
const upb_MessageDef* m = _upb_DefType_Unpack(v, UPB_DEFTYPE_MSG);
return _upb_MessageDef_InMessageSet(m)
? upb_MessageDef_NestedExtension(m, 0)
: NULL;
if (_upb_MessageDef_InMessageSet(m)) {
for (int i = 0; i < upb_MessageDef_NestedExtensionCount(m); i++) {
const upb_FieldDef* ext = upb_MessageDef_NestedExtension(m, i);
if (upb_FieldDef_MessageSubDef(ext) == m) {
return ext;
}
}
}
return NULL;
}
default:
break;
Expand Down

0 comments on commit 3781f45

Please sign in to comment.