From bc16fe8921889319f9d617167ab470043125be4a Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Thu, 12 Dec 2024 20:31:21 -0800 Subject: [PATCH] Treat -0.0 as consistently distinct from 0.0 in pure python The previous treatment was a conformance violation, where implicit-presence float fields with a non-default value of -0.0 could get dropped. PiperOrigin-RevId: 705728806 --- .../text_format_failure_list_python.txt | 10 ------- python/google/protobuf/internal/decoder.py | 30 ++++++++++++++++--- .../google/protobuf/internal/decoder_test.py | 12 +++++++- .../google/protobuf/internal/message_test.py | 18 ++++++++++- .../protobuf/internal/python_message.py | 7 +++-- .../protobuf/internal/text_format_test.py | 14 +++++++++ python/google/protobuf/text_format.py | 4 ++- 7 files changed, 75 insertions(+), 20 deletions(-) diff --git a/conformance/text_format_failure_list_python.txt b/conformance/text_format_failure_list_python.txt index c44c7357d6404..2a0aed5d5dece 100644 --- a/conformance/text_format_failure_list_python.txt +++ b/conformance/text_format_failure_list_python.txt @@ -1,18 +1,8 @@ # This is the list of text format conformance tests that are known to fail right # now. # TODO: These should be fixed. 
-Required.*.TextFormatInput.FloatFieldNegativeZero.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0 -Required.*.TextFormatInput.FloatFieldNegativeZero.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0 -Required.*.TextFormatInput.FloatFieldNegativeZero_F.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0 -Required.*.TextFormatInput.FloatFieldNegativeZero_F.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0 -Required.*.TextFormatInput.FloatFieldNegativeZero_f.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0 -Required.*.TextFormatInput.FloatFieldNegativeZero_f.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0 Required.*.TextFormatInput.FloatFieldNoNegativeOctal # Should have failed to parse, but didn't. Required.*.TextFormatInput.FloatFieldNoOctal # Should have failed to parse, but didn't. 
-Required.*.TextFormatInput.NegDoubleFieldLargeNegativeExponentParsesAsNegZero.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_double: -0 -Required.*.TextFormatInput.NegDoubleFieldLargeNegativeExponentParsesAsNegZero.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_double: -0 -Required.*.TextFormatInput.NegFloatFieldLargeNegativeExponentParsesAsNegZero.ProtobufOutput # Output was not equivalent to reference message: deleted: optional_float: -0 -Required.*.TextFormatInput.NegFloatFieldLargeNegativeExponentParsesAsNegZero.TextFormatOutput # Output was not equivalent to reference message: deleted: optional_float: -0 Required.*.TextFormatInput.StringLiteralBasicEscapesBytes.ProtobufOutput # Output was not equivalent to reference message: modified: optional_bytes: "\007\010\014\n\r\t\013?\\\'\"" -> "\007\010\014\n\r\t Required.*.TextFormatInput.StringLiteralBasicEscapesBytes.TextFormatOutput # Output was not equivalent to reference message: modified: optional_bytes: "\007\010\014\n\r\t\013?\\\'\"" -> "\007\010\014\n\r\t Required.*.TextFormatInput.StringLiteralBasicEscapesString.ProtobufOutput # Output was not equivalent to reference message: modified: optional_string: "\007\010\014\n\r\t\013?\\\'\"" -> "\007\010\014\n\r\ diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index f9e45c53f11bb..91a4c7541e3dc 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -58,6 +58,7 @@ __author__ = 'kenton@google.com (Kenton Varda)' import math +import numbers import struct from google.protobuf import message @@ -71,6 +72,27 @@ _DecodeError = message.DecodeError +def IsDefaultScalarValue(value): + """Returns whether or not a scalar value is the default value of its type. + + Specifically, this should be used to determine presence of implicit-presence + fields, where we disallow custom defaults. 
+ + Args: + value: A scalar value to check. + + Returns: + True if the value is equivalent to a default value, False otherwise. + """ + if isinstance(value, numbers.Number) and math.copysign(1.0, value) < 0: + # Special case for negative zero, where "truthiness" fails to give the right + # answer. + return False + + # Normally, we can just use Python's boolean conversion. + return not value + + def _VarintDecoder(mask, result_type): """Return an encoder for a basic varint value (does not include tag). @@ -237,7 +259,7 @@ def DecodeField(buffer, pos, end, message, field_dict): (new_value, pos) = decode_value(buffer, pos) if pos > end: raise _DecodeError('Truncated message.') - if clear_if_default and not new_value: + if clear_if_default and IsDefaultScalarValue(new_value): field_dict.pop(key, None) else: field_dict[key] = new_value @@ -478,7 +500,7 @@ def DecodeField(buffer, pos, end, message, field_dict): (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) if pos > end: raise _DecodeError('Truncated message.') - if clear_if_default and not enum_value: + if clear_if_default and IsDefaultScalarValue(enum_value): field_dict.pop(key, None) return pos # pylint: disable=protected-access @@ -573,7 +595,7 @@ def DecodeField(buffer, pos, end, message, field_dict): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - if clear_if_default and not size: + if clear_if_default and IsDefaultScalarValue(size): field_dict.pop(key, None) else: field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) @@ -614,7 +636,7 @@ def DecodeField(buffer, pos, end, message, field_dict): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - if clear_if_default and not size: + if clear_if_default and IsDefaultScalarValue(size): field_dict.pop(key, None) else: field_dict[key] = buffer[pos:new_pos].tobytes() diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py index 
d1ea43baa2062..fc62e9f7f7239 100644 --- a/python/google/protobuf/internal/decoder_test.py +++ b/python/google/protobuf/internal/decoder_test.py @@ -18,13 +18,15 @@ from google.protobuf.internal import testing_refleaks from google.protobuf.internal import wire_format +from google.protobuf.internal import _parameterized + _INPUT_BYTES = b'\x84r\x12' _EXPECTED = (14596, 18) @testing_refleaks.TestCase -class DecoderTest(unittest.TestCase): +class DecoderTest(_parameterized.TestCase): def test_decode_varint_bytes(self): (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 0) @@ -128,6 +130,14 @@ def test_unknown_message_set_decoder_mismatched_end_group(self): memoryview(b'\054\014'), ) + @_parameterized.parameters(int(0), float(0.0), False, '') + def test_default_scalar(self, value): + self.assertTrue(decoder.IsDefaultScalarValue(value)) + + @_parameterized.parameters(int(1), float(-0.0), float(1.0), True, 'a') + def test_not_default_scalar(self, value): + self.assertFalse(decoder.IsDefaultScalarValue(value)) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 7212639db684b..c168256b36ea8 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1784,13 +1784,29 @@ def testProto3ParserDropDefaultScalar(self): message_proto2 = unittest_pb2.TestAllTypes() message_proto2.optional_int32 = 0 message_proto2.optional_string = '' + message_proto2.optional_float = 0.0 message_proto2.optional_bytes = b'' - self.assertEqual(len(message_proto2.ListFields()), 3) + self.assertEqual(len(message_proto2.ListFields()), 4) message_proto3 = unittest_proto3_arena_pb2.TestAllTypes() message_proto3.ParseFromString(message_proto2.SerializeToString()) self.assertEqual(len(message_proto3.ListFields()), 0) + def testProto3ParserKeepsNonDefaultScalar(self): + message_proto2 = unittest_pb2.TestAllTypes() + 
message_proto2.optional_int32 = 1 + message_proto2.optional_string = '\0' + message_proto2.optional_float = -0.0 + message_proto2.optional_double = -0.0 + message_proto2.optional_bytes = b'\0' + self.assertEqual(len(message_proto2.ListFields()), 5) + serialized = message_proto2.SerializeToString() + + message_proto3 = unittest_proto3_arena_pb2.TestAllTypes() + message_proto3.ParseFromString(serialized) + self.assertEqual(len(message_proto3.ListFields()), 5) + self.assertEqual(message_proto3.SerializeToString(), serialized) + def testProto3Optional(self): msg = test_proto3_optional_pb2.TestProto3Optional() self.assertFalse(msg.HasField('optional_int32')) diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 98dea6cb9252e..61766099a03ae 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -29,6 +29,7 @@ import datetime from io import BytesIO +import math import struct import sys import warnings @@ -716,14 +717,14 @@ def getter(self): def field_setter(self, new_value): # pylint: disable=protected-access - # Testing the value for truthiness captures all of the proto3 defaults - # (0, 0.0, enum 0, and False). + # Testing the value for truthiness captures all of the implicit presence + # defaults (0, 0.0, enum 0, and False), except for -0.0. 
try: new_value = type_checker.CheckValue(new_value) except TypeError as e: raise TypeError( 'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e)) - if not field.has_presence and not new_value: + if not field.has_presence and decoder.IsDefaultScalarValue(new_value): self._fields.pop(field, None) else: self._fields[field] = new_value diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index a08f6701f7c75..1f1b5eb8e7814 100644 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -979,6 +979,20 @@ def testParseDuplicateScalars(self, message_module): r'have multiple "optional_int32" fields.'), text_format.Parse, text, message) + def testParseDuplicateNegativeZero(self, message_module): + message = message_module.TestAllTypes() + text = 'optional_double: -0.0 optional_double: 3' + self.assertRaisesRegex( + text_format.ParseError, + ( + r'1:40 : Message type "\w+.TestAllTypes" should not ' + r'have multiple "optional_double" fields.' + ), + text_format.Parse, + text, + message, + ) + def testParseExistingScalarInMessage(self, message_module): message = message_module.TestAllTypes(optional_int32=42) text = 'optional_int32: 67' diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 1a5c588fea440..ae49e269dcfef 100644 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -1165,7 +1165,9 @@ def _MergeScalarField(self, tokenizer, message, field): else: # For field that doesn't represent presence, try best effort to # check multiple scalars by compare to default values. - duplicate_error = bool(getattr(message, field.name)) + duplicate_error = not decoder.IsDefaultScalarValue( + getattr(message, field.name) + ) if duplicate_error: raise tokenizer.ParseErrorPreviousToken(