Skip to content

Commit

Permalink
Breaking Change: Python setdefault behavior change for map field.
Browse files Browse the repository at this point in the history
  -setdefault will be similar with dict for ScalarMap. But both key and value must be set.
  -setdefault will be rejected for MessageMap.

PiperOrigin-RevId: 695768629
  • Loading branch information
anandolee authored and copybara-github committed Nov 12, 2024
1 parent 54dc8c7 commit 81da6b9
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 2 deletions.
13 changes: 13 additions & 0 deletions python/google/protobuf/internal/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,13 @@ def __iter__(self) -> Iterator[_K]:
def __repr__(self) -> str:
return repr(self._values)

def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
if value == None:
raise ValueError('The value for scalar map setdefault must be set.')
if key not in self._values:
self.__setitem__(key, value)
return self[key]

def MergeFrom(self, other: 'ScalarMap[_K, _V]') -> None:
self._values.update(other._values)
self._message_listener.Modified()
Expand Down Expand Up @@ -526,6 +533,12 @@ def __iter__(self) -> Iterator[_K]:
def __repr__(self) -> str:
return repr(self._values)

def setdefault(self, key: _K, value: Optional[_V] = None) -> _V:
raise NotImplementedError(
'Set message map value directly is not supported, call'
' my_map[key].foo = 5'
)

def MergeFrom(self, other: 'MessageMap[_K, _V]') -> None:
# pylint: disable=protected-access
for key in other._values:
Expand Down
28 changes: 28 additions & 0 deletions python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,13 +1900,41 @@ def testScalarMapComparison(self):

self.assertEqual(msg1.map_int32_int32, msg2.map_int32_int32)

def testScalarMapSetdefault(self):
msg = map_unittest_pb2.TestMap()
value = msg.map_int32_int32.setdefault(123, 888)
self.assertEqual(value, 888)
self.assertEqual(msg.map_int32_int32[123], 888)
value = msg.map_int32_int32.setdefault(123, 777)
self.assertEqual(value, 888)

with self.assertRaises(ValueError):
value = msg.map_int32_int32.setdefault(1001)
self.assertNotIn(1001, msg.map_int32_int32)
with self.assertRaises(TypeError):
value = msg.map_int32_int32.setdefault()
with self.assertRaises(TypeError):
value = msg.map_int32_int32.setdefault(1, 2, 3)
with self.assertRaises(TypeError):
value = msg.map_int32_int32.setdefault("1", 2)
with self.assertRaises(TypeError):
value = msg.map_int32_int32.setdefault(1, "2")

def testMessageMapComparison(self):
msg1 = map_unittest_pb2.TestMap()
msg2 = map_unittest_pb2.TestMap()

self.assertEqual(msg1.map_int32_foreign_message,
msg2.map_int32_foreign_message)

def testMessageMapSetdefault(self):
msg = map_unittest_pb2.TestMap()
msg.map_int32_foreign_message[123].c = 888
with self.assertRaises(NotImplementedError):
msg.map_int32_foreign_message.setdefault(
1, msg.map_int32_foreign_message[123]
)

def testMapGet(self):
# Need to test that get() properly returns the default, even though the dict
# has defaultdict-like semantics.
Expand Down
39 changes: 39 additions & 0 deletions python/google/protobuf/pyext/map_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,35 @@ int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
}
}

static PyObject* ScalarMapSetdefault(PyObject* self, PyObject* args) {
PyObject* key = nullptr;
PyObject* default_value = Py_None;

if (!PyArg_UnpackTuple(args, "setdefault", 1, 2, &key, &default_value)) {
return nullptr;
}

if (default_value == Py_None) {
PyErr_Format(PyExc_ValueError,
"The value for scalar map setdefault must be set.");
return nullptr;
}

ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
if (is_present == nullptr) {
return nullptr;
}
if (PyObject_IsTrue(is_present.get())) {
return MapReflectionFriend::ScalarMapGetItem(self, key);
}

if (MapReflectionFriend::ScalarMapSetItem(self, key, default_value) < 0) {
return nullptr;
}
Py_INCREF(default_value);
return default_value;
}

static PyObject* ScalarMapGet(PyObject* self, PyObject* args,
PyObject* kwargs) {
static const char* kwlist[] = {"key", "default", nullptr};
Expand Down Expand Up @@ -512,6 +541,8 @@ static PyMethodDef ScalarMapMethods[] = {
"Tests whether a key is a member of the map."},
{"clear", (PyCFunction)Clear, METH_NOARGS,
"Removes all elements from the map."},
{"setdefault", (PyCFunction)ScalarMapSetdefault, METH_VARARGS,
"If the key does not exist, insert the key, with the specified value"},
{"get", (PyCFunction)ScalarMapGet, METH_VARARGS | METH_KEYWORDS,
"Gets the value for the given key if present, or otherwise a default"},
{"GetEntryClass", (PyCFunction)GetEntryClass, METH_NOARGS,
Expand Down Expand Up @@ -685,6 +716,12 @@ PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
return PyObject_Repr(dict.get());
}

static PyObject* MessageMapSetdefault(PyObject* self, PyObject* args) {
PyErr_Format(PyExc_NotImplementedError,
"Set message map value directly is not supported.");
return nullptr;
}

PyObject* MessageMapGet(PyObject* self, PyObject* args, PyObject* kwargs) {
static const char* kwlist[] = {"key", "default", nullptr};
PyObject* key;
Expand Down Expand Up @@ -729,6 +766,8 @@ static PyMethodDef MessageMapMethods[] = {
"Tests whether the map contains this element."},
{"clear", (PyCFunction)Clear, METH_NOARGS,
"Removes all elements from the map."},
{"setdefault", (PyCFunction)MessageMapSetdefault, METH_VARARGS,
"setdefault is disallowed in MessageMap."},
{"get", (PyCFunction)MessageMapGet, METH_VARARGS | METH_KEYWORDS,
"Gets the value for the given key if present, or otherwise a default"},
{"get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
Expand Down
51 changes: 49 additions & 2 deletions python/map.c
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,48 @@ static PyObject* PyUpb_MapContainer_Clear(PyObject* _self, PyObject* key) {
Py_RETURN_NONE;
}

static PyObject* PyUpb_ScalarMapContainer_Setdefault(PyObject* _self,
PyObject* args) {
PyObject* key;
PyObject* default_value = Py_None;

if (!PyArg_UnpackTuple(args, "setdefault", 1, 2, &key, &default_value)) {
return NULL;
}

if (default_value == Py_None) {
PyErr_Format(PyExc_ValueError,
"The value for scalar map setdefault must be set.");
return NULL;
}

PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self;
upb_Map* map = PyUpb_MapContainer_EnsureReified(_self);
const upb_FieldDef* f = PyUpb_MapContainer_GetField(self);
const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0);
const upb_FieldDef* val_f = upb_MessageDef_Field(entry_m, 1);
upb_MessageValue u_key, u_val;
if (!PyUpb_PyToUpb(key, key_f, &u_key, NULL)) return NULL;
if (upb_Map_Get(map, u_key, &u_val)) {
return PyUpb_UpbToPy(u_val, val_f, self->arena);
}

upb_Arena* arena = PyUpb_Arena_Get(self->arena);
if (!PyUpb_PyToUpb(default_value, val_f, &u_val, arena)) return NULL;
if (!PyUpb_MapContainer_Set(self, map, u_key, u_val, arena)) return NULL;

Py_INCREF(default_value);
return default_value;
}

static PyObject* PyUpb_MessageMapContainer_Setdefault(PyObject* self,
PyObject* args) {
PyErr_Format(PyExc_NotImplementedError,
"Set message map value directly is not supported.");
return NULL;
}

static PyObject* PyUpb_MapContainer_Get(PyObject* _self, PyObject* args,
PyObject* kwargs) {
PyUpb_MapContainer* self = (PyUpb_MapContainer*)_self;
Expand Down Expand Up @@ -331,6 +373,9 @@ PyObject* PyUpb_MapContainer_GetOrCreateWrapper(upb_Map* map,
static PyMethodDef PyUpb_ScalarMapContainer_Methods[] = {
{"clear", PyUpb_MapContainer_Clear, METH_NOARGS,
"Removes all elements from the map."},
{"setdefault", (PyCFunction)PyUpb_ScalarMapContainer_Setdefault,
METH_VARARGS,
"If the key does not exist, insert the key, with the specified value"},
{"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS,
"Gets the value for the given key if present, or otherwise a default"},
{"GetEntryClass", PyUpb_MapContainer_GetEntryClass, METH_NOARGS,
Expand Down Expand Up @@ -373,6 +418,8 @@ static PyType_Spec PyUpb_ScalarMapContainer_Spec = {
static PyMethodDef PyUpb_MessageMapContainer_Methods[] = {
{"clear", PyUpb_MapContainer_Clear, METH_NOARGS,
"Removes all elements from the map."},
{"setdefault", (PyCFunction)PyUpb_MessageMapContainer_Setdefault,
METH_VARARGS, "setdefault is disallowed in MessageMap."},
{"get", (PyCFunction)PyUpb_MapContainer_Get, METH_VARARGS | METH_KEYWORDS,
"Gets the value for the given key if present, or otherwise a default"},
{"get_or_create", PyUpb_MapContainer_Subscript, METH_O,
Expand Down Expand Up @@ -480,8 +527,8 @@ bool PyUpb_Map_Init(PyObject* m) {
PyObject* base = GetMutableMappingBase();
if (!base) return false;

const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__",
"pop", "popitem", "update", "setdefault", NULL};
const char* methods[] = {"keys", "items", "values", "__eq__", "__ne__",
"pop", "popitem", "update", NULL};

state->message_map_container_type = PyUpb_AddClassWithRegister(
m, &PyUpb_MessageMapContainer_Spec, base, methods);
Expand Down

0 comments on commit 81da6b9

Please sign in to comment.