From 59d0645a99ae196d283de267e39bb358aa0c024f Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Wed, 19 Feb 2025 12:23:10 -0800 Subject: [PATCH] Add constructing version of Any unpacking to Python Protobuf library. The main benefit here is enabling unpacking in an expression context whereas previously multiple statements were required. The return/raise API is also more idiomatic to Python than the current one using an out parameter. PiperOrigin-RevId: 728769960 --- .../protobuf/internal/well_known_types.py | 119 ++++++--- .../internal/well_known_types_test.py | 233 ++++++++++-------- 2 files changed, 216 insertions(+), 136 deletions(-) diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index 0d95b4b98230b..7b20ac0cf0bf7 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -20,9 +20,12 @@ import calendar import collections.abc import datetime +from typing import TypeVar, Union import warnings +from google.protobuf import message from google.protobuf.internal import field_mask -from typing import Union + +_MessageT = TypeVar('_MessageT', bound=message.Message) FieldMask = field_mask.FieldMask @@ -48,8 +51,9 @@ class Any(object): __slots__ = () - def Pack(self, msg, type_url_prefix='type.googleapis.com/', - deterministic=None): + def Pack( + self, msg, type_url_prefix='type.googleapis.com/', deterministic=None + ): """Packs the specified message into current Any message.""" if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/': self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name) @@ -65,6 +69,28 @@ def Unpack(self, msg): msg.ParseFromString(self.value) return True + def UnpackAs(self, message_type: type[_MessageT]) -> _MessageT: + """Unpacks the Any into a newly created message of the specified type. + + Args: + message_type: + + Returns: + Unpacked message. + + Raises: + TypeError if the specified message type doesn't match the contained + message. + """ + unpacked = message_type() + if self.Unpack(unpacked): + return unpacked + else: + raise TypeError( + f'Attempted to unpack {self.TypeName()} to' + f' {message_type.__qualname__}' + ) + def TypeName(self): """Returns the protobuf type name of the inner message.""" # Only last part is to be used: b/25630112 @@ -113,8 +139,8 @@ def FromJsonString(self, value): Args: value: A date string. Any fractional digits (or none) and any offset are - accepted as long as they fit into nano-seconds precision. - Example of accepted format: '1972-01-01T10:00:20.021-05:00' + accepted as long as they fit into nano-seconds precision. Example of + accepted format: '1972-01-01T10:00:20.021-05:00' Raises: ValueError: On parsing problems. @@ -128,7 +154,8 @@ def FromJsonString(self, value): timezone_offset = value.rfind('-') if timezone_offset == -1: raise ValueError( - 'Failed to parse timestamp: missing valid timezone offset.') + 'Failed to parse timestamp: missing valid timezone offset.' + ) time_value = value[0:timezone_offset] # Parse datetime and nanos. point_position = time_value.find('.') @@ -137,18 +164,20 @@ def FromJsonString(self, value): nano_value = '' else: second_value = time_value[:point_position] - nano_value = time_value[point_position + 1:] + nano_value = time_value[point_position + 1 :] if 't' in second_value: raise ValueError( - 'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', ' - 'lowercase \'t\' is not accepted'.format(second_value)) + "time data '{0}' does not match format '%Y-%m-%dT%H:%M:%S', " + "lowercase 't' is not accepted".format(second_value) + ) date_object = datetime.datetime.strptime(second_value, _TIMESTAMPFORMAT) td = date_object - datetime.datetime(1970, 1, 1) seconds = td.seconds + td.days * _SECONDS_PER_DAY if len(nano_value) > 9: raise ValueError( 'Failed to parse Timestamp: nanos {0} more than ' - '9 fractional digits.'.format(nano_value)) + '9 fractional digits.'.format(nano_value) + ) if nano_value: nanos = round(float('0.' + nano_value) * 1e9) else: @@ -156,18 +185,20 @@ def FromJsonString(self, value): # Parse timezone offsets. if value[timezone_offset] == 'Z': if len(value) != timezone_offset + 1: - raise ValueError('Failed to parse timestamp: invalid trailing' - ' data {0}.'.format(value)) + raise ValueError( + 'Failed to parse timestamp: invalid trailing data {0}.'.format( + value + ) + ) else: timezone = value[timezone_offset:] pos = timezone.find(':') if pos == -1: - raise ValueError( - 'Invalid timezone offset value: {0}.'.format(timezone)) + raise ValueError('Invalid timezone offset value: {0}.'.format(timezone)) if timezone[0] == '+': - seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + seconds -= (int(timezone[1:pos]) * 60 + int(timezone[pos + 1 :])) * 60 else: - seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + seconds += (int(timezone[1:pos]) * 60 + int(timezone[pos + 1 :])) * 60 # Set seconds and nanos _CheckTimestampValid(seconds, nanos) self.seconds = int(seconds) @@ -185,14 +216,16 @@ def ToNanoseconds(self): def ToMicroseconds(self): """Converts Timestamp to microseconds since epoch.""" _CheckTimestampValid(self.seconds, self.nanos) - return (self.seconds * _MICROS_PER_SECOND + - self.nanos // _NANOS_PER_MICROSECOND) + return ( + self.seconds * _MICROS_PER_SECOND + self.nanos // _NANOS_PER_MICROSECOND + ) def ToMilliseconds(self): """Converts Timestamp to milliseconds since epoch.""" _CheckTimestampValid(self.seconds, self.nanos) - return (self.seconds * _MILLIS_PER_SECOND + - self.nanos // _NANOS_PER_MILLISECOND) + return ( + self.seconds * _MILLIS_PER_SECOND + self.nanos // _NANOS_PER_MILLISECOND + ) def ToSeconds(self): """Converts Timestamp to seconds since epoch.""" @@ -312,7 +345,8 @@ def _CheckTimestampValid(seconds, nanos): if nanos < 0 or nanos >= _NANOS_PER_SECOND: raise ValueError( 'Timestamp is not valid: Nanos {} must be in a range ' - '[0, 999999].'.format(nanos)) + '[0, 999999].'.format(nanos) + ) class Duration(object): @@ -332,7 +366,7 @@ def ToJsonString(self): _CheckDurationValid(self.seconds, self.nanos) if self.seconds < 0 or self.nanos < 0: result = '-' - seconds = - self.seconds + int((0 - self.nanos) // 1e9) + seconds = -self.seconds + int((0 - self.nanos) // 1e9) nanos = (0 - self.nanos) % 1e9 else: result = '' @@ -357,8 +391,8 @@ def FromJsonString(self, value): Args: value: A string to be converted. The string must end with 's'. Any - fractional digits (or none) are accepted as long as they fit into - precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s + fractional digits (or none) are accepted as long as they fit into + precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s Raises: ValueError: On parsing problems. @@ -366,8 +400,7 @@ def FromJsonString(self, value): if not isinstance(value, str): raise ValueError('Duration JSON value not a string: {!r}'.format(value)) if len(value) < 1 or value[-1] != 's': - raise ValueError( - 'Duration must end with letter "s": {0}.'.format(value)) + raise ValueError('Duration must end with letter "s": {0}.'.format(value)) try: pos = value.find('.') if pos == -1: @@ -376,15 +409,14 @@ def FromJsonString(self, value): else: seconds = int(value[:pos]) if value[0] == '-': - nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) + nanos = int(round(float('-0{0}'.format(value[pos:-1])) * 1e9)) else: - nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + nanos = int(round(float('0{0}'.format(value[pos:-1])) * 1e9)) _CheckDurationValid(seconds, nanos) self.seconds = seconds self.nanos = nanos except ValueError as e: - raise ValueError( - 'Couldn\'t parse duration: {0} : {1}.'.format(value, e)) + raise ValueError("Couldn't parse duration: {0} : {1}.".format(value, e)) def ToNanoseconds(self): """Converts a Duration to nanoseconds.""" @@ -406,20 +438,23 @@ def ToSeconds(self): def FromNanoseconds(self, nanos): """Converts nanoseconds to Duration.""" - self._NormalizeDuration(nanos // _NANOS_PER_SECOND, - nanos % _NANOS_PER_SECOND) + self._NormalizeDuration( + nanos // _NANOS_PER_SECOND, nanos % _NANOS_PER_SECOND + ) def FromMicroseconds(self, micros): """Converts microseconds to Duration.""" self._NormalizeDuration( micros // _MICROS_PER_SECOND, - (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND) + (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND, + ) def FromMilliseconds(self, millis): """Converts milliseconds to Duration.""" self._NormalizeDuration( millis // _MILLIS_PER_SECOND, - (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND) + (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND, + ) def FromSeconds(self, seconds): """Converts seconds to Duration.""" @@ -429,8 +464,9 @@ def FromSeconds(self, seconds): def ToTimedelta(self) -> datetime.timedelta: """Converts Duration to timedelta.""" return datetime.timedelta( - seconds=self.seconds, microseconds=_RoundTowardZero( - self.nanos, _NANOS_PER_MICROSECOND)) + seconds=self.seconds, + microseconds=_RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND), + ) def FromTimedelta(self, td): """Converts timedelta to Duration.""" @@ -472,14 +508,15 @@ def _CheckDurationValid(seconds, nanos): if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX: raise ValueError( 'Duration is not valid: Seconds {0} must be in range ' - '[-315576000000, 315576000000].'.format(seconds)) + '[-315576000000, 315576000000].'.format(seconds) + ) if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND: raise ValueError( 'Duration is not valid: Nanos {0} must be in range ' - '[-999999999, 999999999].'.format(nanos)) + '[-999999999, 999999999].'.format(nanos) + ) if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0): - raise ValueError( - 'Duration is not valid: Sign mismatch.') + raise ValueError('Duration is not valid: Sign mismatch.') def _RoundTowardZero(value, divider): @@ -601,6 +638,7 @@ def update(self, dictionary): # pylint: disable=invalid-name for key, value in dictionary.items(): _SetStructValue(self.fields[key], value) + collections.abc.MutableMapping.register(Struct) @@ -663,6 +701,7 @@ def add_list(self): list_value.Clear() return list_value + collections.abc.MutableSequence.register(ListValue) diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 683c2a5eee389..fb0251c2f80dd 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -29,6 +29,7 @@ try: # New module in Python 3.9: import zoneinfo # pylint:disable=g-import-not-at-top + _TZ_JAPAN = zoneinfo.ZoneInfo('Japan') _TZ_PACIFIC = zoneinfo.ZoneInfo('US/Pacific') has_zoneinfo = True @@ -124,43 +125,35 @@ def testDurationSerializeAndParse(self): def testTimestampIntegerConversion(self): message = timestamp_pb2.Timestamp() message.FromNanoseconds(1) - self.assertEqual('1970-01-01T00:00:00.000000001Z', - message.ToJsonString()) + self.assertEqual('1970-01-01T00:00:00.000000001Z', message.ToJsonString()) self.assertEqual(1, message.ToNanoseconds()) message.FromNanoseconds(-1) - self.assertEqual('1969-12-31T23:59:59.999999999Z', - message.ToJsonString()) + self.assertEqual('1969-12-31T23:59:59.999999999Z', message.ToJsonString()) self.assertEqual(-1, message.ToNanoseconds()) message.FromMicroseconds(1) - self.assertEqual('1970-01-01T00:00:00.000001Z', - message.ToJsonString()) + self.assertEqual('1970-01-01T00:00:00.000001Z', message.ToJsonString()) self.assertEqual(1, message.ToMicroseconds()) message.FromMicroseconds(-1) - self.assertEqual('1969-12-31T23:59:59.999999Z', - message.ToJsonString()) + self.assertEqual('1969-12-31T23:59:59.999999Z', message.ToJsonString()) self.assertEqual(-1, message.ToMicroseconds()) message.FromMilliseconds(1) - self.assertEqual('1970-01-01T00:00:00.001Z', - message.ToJsonString()) + self.assertEqual('1970-01-01T00:00:00.001Z', message.ToJsonString()) self.assertEqual(1, message.ToMilliseconds()) message.FromMilliseconds(-1) - self.assertEqual('1969-12-31T23:59:59.999Z', - message.ToJsonString()) + self.assertEqual('1969-12-31T23:59:59.999Z', message.ToJsonString()) self.assertEqual(-1, message.ToMilliseconds()) message.FromSeconds(1) - self.assertEqual('1970-01-01T00:00:01Z', - message.ToJsonString()) + self.assertEqual('1970-01-01T00:00:01Z', message.ToJsonString()) self.assertEqual(1, message.ToSeconds()) message.FromSeconds(-1) - self.assertEqual('1969-12-31T23:59:59Z', - message.ToJsonString()) + self.assertEqual('1969-12-31T23:59:59Z', message.ToJsonString()) self.assertEqual(-1, message.ToSeconds()) message.FromNanoseconds(1999) @@ -175,33 +168,27 @@ def testTimestampIntegerConversion(self): def testDurationIntegerConversion(self): message = duration_pb2.Duration() message.FromNanoseconds(1) - self.assertEqual('0.000000001s', - message.ToJsonString()) + self.assertEqual('0.000000001s', message.ToJsonString()) self.assertEqual(1, message.ToNanoseconds()) message.FromNanoseconds(-1) - self.assertEqual('-0.000000001s', - message.ToJsonString()) + self.assertEqual('-0.000000001s', message.ToJsonString()) self.assertEqual(-1, message.ToNanoseconds()) message.FromMicroseconds(1) - self.assertEqual('0.000001s', - message.ToJsonString()) + self.assertEqual('0.000001s', message.ToJsonString()) self.assertEqual(1, message.ToMicroseconds()) message.FromMicroseconds(-1) - self.assertEqual('-0.000001s', - message.ToJsonString()) + self.assertEqual('-0.000001s', message.ToJsonString()) self.assertEqual(-1, message.ToMicroseconds()) message.FromMilliseconds(1) - self.assertEqual('0.001s', - message.ToJsonString()) + self.assertEqual('0.001s', message.ToJsonString()) self.assertEqual(1, message.ToMilliseconds()) message.FromMilliseconds(-1) - self.assertEqual('-0.001s', - message.ToJsonString()) + self.assertEqual('-0.001s', message.ToJsonString()) self.assertEqual(-1, message.ToMilliseconds()) message.FromSeconds(1) @@ -209,8 +196,7 @@ def testDurationIntegerConversion(self): self.assertEqual(1, message.ToSeconds()) message.FromSeconds(-1) - self.assertEqual('-1s', - message.ToJsonString()) + self.assertEqual('-1s', message.ToJsonString()) self.assertEqual(-1, message.ToSeconds()) # Test truncation behavior. @@ -239,8 +225,9 @@ def testTimezoneNaiveDatetimeConversionNearEpoch(self): self.assertEqual(1, message.seconds) self.assertEqual(999_000_000, message.nanos) - self.assertEqual(datetime.datetime(1970, 1, 1, 0, 0, 1, 999000), - message.ToDatetime()) + self.assertEqual( + datetime.datetime(1970, 1, 1, 0, 0, 1, 999000), message.ToDatetime() + ) def testTimezoneNaiveDatetimeConversionWhereTimestampLosesPrecision(self): ts = timestamp_pb2.Timestamp() @@ -298,7 +285,8 @@ def testTimezoneAwareDatetimeConversion(self, date_parts, tzinfo): self.assertEqual(original_datetime, aware_datetime) self.assertEqual( datetime.datetime(1970, 1, 1, 2, tzinfo=datetime.timezone.utc), - aware_datetime) + aware_datetime, + ) self.assertEqual(tzinfo, aware_datetime.tzinfo) @unittest.skipIf( @@ -476,13 +464,11 @@ def testTimestampSub(self, old_time, time_delta, expected_sec, expected_nano): def testNanosOneSecond(self): tz = _TZ_PACIFIC ts = timestamp_pb2.Timestamp(nanos=1_000_000_000) - self.assertRaisesRegex(ValueError, 'Timestamp is not valid', - ts.ToDatetime) + self.assertRaisesRegex(ValueError, 'Timestamp is not valid', ts.ToDatetime) def testNanosNegativeOneSecond(self): ts = timestamp_pb2.Timestamp(nanos=-1_000_000_000) - self.assertRaisesRegex(ValueError, 'Timestamp is not valid', - ts.ToDatetime) + self.assertRaisesRegex(ValueError, 'Timestamp is not valid', ts.ToDatetime) def testTimedeltaConversion(self): message = duration_pb2.Duration() @@ -509,27 +495,43 @@ def testTimedeltaConversion(self): def testInvalidTimestamp(self): message = timestamp_pb2.Timestamp() self.assertRaisesRegex( - ValueError, 'Failed to parse timestamp: missing valid timezone offset.', - message.FromJsonString, '') + ValueError, + 'Failed to parse timestamp: missing valid timezone offset.', + message.FromJsonString, + '', + ) self.assertRaisesRegex( - ValueError, 'Failed to parse timestamp: invalid trailing data ' - '1970-01-01T00:00:01Ztrail.', message.FromJsonString, - '1970-01-01T00:00:01Ztrail') + ValueError, + 'Failed to parse timestamp: invalid trailing data ' + '1970-01-01T00:00:01Ztrail.', + message.FromJsonString, + '1970-01-01T00:00:01Ztrail', + ) self.assertRaisesRegex( - ValueError, 'time data \'10000-01-01T00:00:00\' does not match' - ' format \'%Y-%m-%dT%H:%M:%S\'', message.FromJsonString, - '10000-01-01T00:00:00.00Z') + ValueError, + "time data '10000-01-01T00:00:00' does not match" + " format '%Y-%m-%dT%H:%M:%S'", + message.FromJsonString, + '10000-01-01T00:00:00.00Z', + ) self.assertRaisesRegex( - ValueError, 'nanos 0123456789012 more than 9 fractional digits.', - message.FromJsonString, '1970-01-01T00:00:00.0123456789012Z') + ValueError, + 'nanos 0123456789012 more than 9 fractional digits.', + message.FromJsonString, + '1970-01-01T00:00:00.0123456789012Z', + ) self.assertRaisesRegex( ValueError, - (r'Invalid timezone offset value: \+08.'), + r'Invalid timezone offset value: \+08.', message.FromJsonString, '1972-01-01T01:00:00.01+08', ) - self.assertRaisesRegex(ValueError, 'year (0 )?is out of range', - message.FromJsonString, '0000-01-01T00:00:00Z') + self.assertRaisesRegex( + ValueError, + 'year (0 )?is out of range', + message.FromJsonString, + '0000-01-01T00:00:00Z', + ) message.seconds = 253402300800 self.assertRaisesRegex(ValueError, 'Timestamp is not valid', message.ToJsonString) @@ -550,36 +552,57 @@ def testInvalidTimestamp(self): def testInvalidDuration(self): message = duration_pb2.Duration() - self.assertRaisesRegex(ValueError, 'Duration must end with letter "s": 1.', - message.FromJsonString, '1') - self.assertRaisesRegex(ValueError, 'Couldn\'t parse duration: 1...2s.', - message.FromJsonString, '1...2s') + self.assertRaisesRegex( + ValueError, + 'Duration must end with letter "s": 1.', + message.FromJsonString, + '1', + ) + self.assertRaisesRegex( + ValueError, + "Couldn't parse duration: 1...2s.", + message.FromJsonString, + '1...2s', + ) text = '-315576000001.000000000s' self.assertRaisesRegex( ValueError, r'Duration is not valid\: Seconds -315576000001 must be in range' - r' \[-315576000000\, 315576000000\].', message.FromJsonString, text) + r' \[-315576000000\, 315576000000\].', + message.FromJsonString, + text, + ) text = '315576000001.000000000s' self.assertRaisesRegex( ValueError, r'Duration is not valid\: Seconds 315576000001 must be in range' - r' \[-315576000000\, 315576000000\].', message.FromJsonString, text) + r' \[-315576000000\, 315576000000\].', + message.FromJsonString, + text, + ) message.seconds = -315576000001 message.nanos = 0 self.assertRaisesRegex( ValueError, r'Duration is not valid\: Seconds -315576000001 must be in range' - r' \[-315576000000\, 315576000000\].', message.ToJsonString) + r' \[-315576000000\, 315576000000\].', + message.ToJsonString, + ) message.seconds = 0 message.nanos = 999999999 + 1 self.assertRaisesRegex( - ValueError, r'Duration is not valid\: Nanos 1000000000 must be in range' - r' \[-999999999\, 999999999\].', message.ToJsonString) + ValueError, + r'Duration is not valid\: Nanos 1000000000 must be in range' + r' \[-999999999\, 999999999\].', + message.ToJsonString, + ) message.seconds = -1 message.nanos = 1 - self.assertRaisesRegex(ValueError, - r'Duration is not valid\: Sign mismatch.', - message.ToJsonString) + self.assertRaisesRegex( + ValueError, + r'Duration is not valid\: Sign mismatch.', + message.ToJsonString, + ) msg = well_known_types_test_pb2.WKTMessage() with self.assertRaises(AttributeError): msg.optional_duration = 1 @@ -708,8 +731,10 @@ def testStruct(self): self.assertEqual(11, struct['key4']['subkey']) inner_struct = struct_class() inner_struct['subkey2'] = 9 - self.assertEqual([6, 'seven', True, False, None, inner_struct], - list(struct['key5'].items())) + self.assertEqual( + [6, 'seven', True, False, None, inner_struct], + list(struct['key5'].items()), + ) self.assertEqual({}, dict(struct['key6']['subkey'].fields)) self.assertEqual([2, False], list(struct['key7'].items())) @@ -738,8 +763,10 @@ def testStruct(self): self.assertEqual('abc', struct2['key2']) self.assertIs(True, struct2['key3']) self.assertEqual(11, struct2['key4']['subkey']) - self.assertEqual([6, 'seven', True, False, None, inner_struct], - list(struct2['key5'].items())) + self.assertEqual( + [6, 'seven', True, False, None, inner_struct], + list(struct2['key5'].items()), + ) struct_list = struct2['key5'] self.assertEqual(6, struct_list[0]) @@ -753,8 +780,9 @@ def testStruct(self): self.assertEqual(7, struct_list[1]) struct_list.add_list().extend([1, 'two', True, False, None]) - self.assertEqual([1, 'two', True, False, None], - list(struct_list[6].items())) + self.assertEqual( + [1, 'two', True, False, None], list(struct_list[6].items()) + ) struct_list.extend([{'nested_struct': 30}, ['nested_list', 99], {}, []]) self.assertEqual(11, len(struct_list.values)) self.assertEqual(30, struct_list[7]['nested_struct']) @@ -799,8 +827,9 @@ def testStruct(self): self.assertEqual(6, len(struct['key5'])) del struct['key5'][1] self.assertEqual(5, len(struct['key5'])) - self.assertEqual([6, True, False, None, inner_struct], - list(struct['key5'].items())) + self.assertEqual( + [6, True, False, None, inner_struct], list(struct['key5'].items()) + ) def testInOperator(self): # in operator for Struct @@ -921,7 +950,7 @@ def testMergeFrom(self): 'key6': [['nested_list', True]], 'empty_struct': {}, 'empty_list': [], - 'tuple': ((3,2), ()) + 'tuple': ((3, 2), ()), } struct.update(dictionary) self.assertEqual(5, struct['key1']) @@ -930,8 +959,10 @@ def testMergeFrom(self): self.assertEqual(11, struct['key4']['subkey']) inner_struct = struct_class() inner_struct['subkey2'] = 9 - self.assertEqual([6, 'seven', True, False, None, inner_struct], - list(struct['key5'].items())) + self.assertEqual( + [6, 'seven', True, False, None, inner_struct], + list(struct['key5'].items()), + ) self.assertEqual(2, len(struct['key6'][0].values)) self.assertEqual('nested_list', struct['key6'][0][0]) self.assertEqual(True, struct['key6'][0][1]) @@ -942,10 +973,7 @@ def testMergeFrom(self): # According to documentation: "When parsing from the wire or when merging, # if there are duplicate map keys the last key seen is used". - duplicate = { - 'key4': {'replace': 20}, - 'key5': [[False, 5]] - } + duplicate = {'key4': {'replace': 20}, 'key5': [[False, 5]]} struct.update(duplicate) self.assertEqual(1, len(struct['key4'].fields)) self.assertEqual(20, struct['key4']['replace']) @@ -962,30 +990,40 @@ def testAnyMessage(self): msg_descriptor = msg.DESCRIPTOR all_types = unittest_pb2.TestAllTypes() all_descriptor = all_types.DESCRIPTOR - all_types.repeated_string.append(u'\u00fc\ua71f') + all_types.repeated_string.append('\u00fc\ua71f') # Packs to Any. msg.value.Pack(all_types) - self.assertEqual(msg.value.type_url, - 'type.googleapis.com/%s' % all_descriptor.full_name) - self.assertEqual(msg.value.value, - all_types.SerializeToString()) + self.assertEqual( + msg.value.type_url, 'type.googleapis.com/%s' % all_descriptor.full_name + ) + self.assertEqual(msg.value.value, all_types.SerializeToString()) # Tests Is() method. self.assertTrue(msg.value.Is(all_descriptor)) self.assertFalse(msg.value.Is(msg_descriptor)) + # Unpacks Any. unpacked_message = unittest_pb2.TestAllTypes() self.assertTrue(msg.value.Unpack(unpacked_message)) self.assertEqual(all_types, unpacked_message) - # Unpacks to different type. + + self.assertEqual(msg.value.UnpackAs(unittest_pb2.TestAllTypes), all_types) + + # Fails to unpack to a different type. self.assertFalse(msg.value.Unpack(msg)) + + with self.assertRaises(TypeError) as catcher: + msg.value.UnpackAs(well_known_types_test_pb2.TestAny) + self.assertIn('Attempted to unpack', catcher.exception.args[0]) + # Only Any messages have Pack method. try: msg.Pack(all_types) except AttributeError: pass else: - raise AttributeError('%s should not have Pack method.' % - msg_descriptor.full_name) + raise AttributeError( + '%s should not have Pack method.' % msg_descriptor.full_name + ) def testUnpackWithNoSlashInTypeUrl(self): msg = well_known_types_test_pb2.TestAny() @@ -1012,16 +1050,17 @@ def testPackWithCustomTypeUrl(self): msg = any_pb2.Any() # Pack with a custom type URL prefix. msg.Pack(submessage, 'type.myservice.com') - self.assertEqual(msg.type_url, - 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name) + self.assertEqual( + msg.type_url, 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name + ) # Pack with a custom type URL prefix ending with '/'. msg.Pack(submessage, 'type.myservice.com/') - self.assertEqual(msg.type_url, - 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name) + self.assertEqual( + msg.type_url, 'type.myservice.com/%s' % submessage.DESCRIPTOR.full_name + ) # Pack with an empty type URL prefix. msg.Pack(submessage, '') - self.assertEqual(msg.type_url, - '/%s' % submessage.DESCRIPTOR.full_name) + self.assertEqual(msg.type_url, '/%s' % submessage.DESCRIPTOR.full_name) # Test unpacking the type. unpacked_message = well_known_types_test_pb2.TestAny() self.assertTrue(msg.Unpack(unpacked_message)) @@ -1034,11 +1073,13 @@ def testPackDeterministic(self): msg = any_pb2.Any() msg.Pack(submessage, deterministic=True) serialized = msg.SerializeToString(deterministic=True) - golden = (b'\n4type.googleapis.com/google.protobuf.internal.TestAny\x12F' - b'\x1a\x05\n\x010\x10\x00\x1a\x05\n\x011\x10\x02\x1a\x05\n\x01' - b'2\x10\x04\x1a\x05\n\x013\x10\x06\x1a\x05\n\x014\x10\x08\x1a' - b'\x05\n\x015\x10\n\x1a\x05\n\x016\x10\x0c\x1a\x05\n\x017\x10' - b'\x0e\x1a\x05\n\x018\x10\x10\x1a\x05\n\x019\x10\x12') + golden = ( + b'\n4type.googleapis.com/google.protobuf.internal.TestAny\x12F' + b'\x1a\x05\n\x010\x10\x00\x1a\x05\n\x011\x10\x02\x1a\x05\n\x01' + b'2\x10\x04\x1a\x05\n\x013\x10\x06\x1a\x05\n\x014\x10\x08\x1a' + b'\x05\n\x015\x10\n\x1a\x05\n\x016\x10\x0c\x1a\x05\n\x017\x10' + b'\x0e\x1a\x05\n\x018\x10\x10\x1a\x05\n\x019\x10\x12' + ) self.assertEqual(golden, serialized) def testJsonStruct(self):