Skip to content

Commit

Permalink
Add constructing version of Any unpacking to Python Protobuf library.
Browse files Browse the repository at this point in the history
The main benefit here is enabling unpacking in an expression context whereas previously multiple statements were required. The return/raise API is also more idiomatic to Python than the current one using an out parameter.

PiperOrigin-RevId: 728769960
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Feb 20, 2025
1 parent fb9e2c3 commit 59d0645
Show file tree
Hide file tree
Showing 2 changed files with 216 additions and 136 deletions.
119 changes: 79 additions & 40 deletions python/google/protobuf/internal/well_known_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
import calendar
import collections.abc
import datetime
from typing import TypeVar, Union
import warnings
from google.protobuf import message
from google.protobuf.internal import field_mask
from typing import Union

_MessageT = TypeVar('_MessageT', bound=message.Message)

FieldMask = field_mask.FieldMask

Expand All @@ -48,8 +51,9 @@ class Any(object):

__slots__ = ()

def Pack(self, msg, type_url_prefix='type.googleapis.com/',
deterministic=None):
def Pack(
self, msg, type_url_prefix='type.googleapis.com/', deterministic=None
):
"""Packs the specified message into current Any message."""
if len(type_url_prefix) < 1 or type_url_prefix[-1] != '/':
self.type_url = '%s/%s' % (type_url_prefix, msg.DESCRIPTOR.full_name)
Expand All @@ -65,6 +69,28 @@ def Unpack(self, msg):
msg.ParseFromString(self.value)
return True

def UnpackAs(self, message_type: type[_MessageT]) -> _MessageT:
"""Unpacks the Any into a newly created message of the specified type.
Args:
message_type:
Returns:
Unpacked message.
Raises:
TypeError if the specified message type doesn't match the contained
message.
"""
unpacked = message_type()
if self.Unpack(unpacked):
return unpacked
else:
raise TypeError(
f'Attempted to unpack {self.TypeName()} to'
f' {message_type.__qualname__}'
)

def TypeName(self):
"""Returns the protobuf type name of the inner message."""
# Only last part is to be used: b/25630112
Expand Down Expand Up @@ -113,8 +139,8 @@ def FromJsonString(self, value):
Args:
value: A date string. Any fractional digits (or none) and any offset are
accepted as long as they fit into nano-seconds precision.
Example of accepted format: '1972-01-01T10:00:20.021-05:00'
accepted as long as they fit into nano-seconds precision. Example of
accepted format: '1972-01-01T10:00:20.021-05:00'
Raises:
ValueError: On parsing problems.
Expand All @@ -128,7 +154,8 @@ def FromJsonString(self, value):
timezone_offset = value.rfind('-')
if timezone_offset == -1:
raise ValueError(
'Failed to parse timestamp: missing valid timezone offset.')
'Failed to parse timestamp: missing valid timezone offset.'
)
time_value = value[0:timezone_offset]
# Parse datetime and nanos.
point_position = time_value.find('.')
Expand All @@ -137,37 +164,41 @@ def FromJsonString(self, value):
nano_value = ''
else:
second_value = time_value[:point_position]
nano_value = time_value[point_position + 1:]
nano_value = time_value[point_position + 1 :]
if 't' in second_value:
raise ValueError(
'time data \'{0}\' does not match format \'%Y-%m-%dT%H:%M:%S\', '
'lowercase \'t\' is not accepted'.format(second_value))
"time data '{0}' does not match format '%Y-%m-%dT%H:%M:%S', "
"lowercase 't' is not accepted".format(second_value)
)
date_object = datetime.datetime.strptime(second_value, _TIMESTAMPFORMAT)
td = date_object - datetime.datetime(1970, 1, 1)
seconds = td.seconds + td.days * _SECONDS_PER_DAY
if len(nano_value) > 9:
raise ValueError(
'Failed to parse Timestamp: nanos {0} more than '
'9 fractional digits.'.format(nano_value))
'9 fractional digits.'.format(nano_value)
)
if nano_value:
nanos = round(float('0.' + nano_value) * 1e9)
else:
nanos = 0
# Parse timezone offsets.
if value[timezone_offset] == 'Z':
if len(value) != timezone_offset + 1:
raise ValueError('Failed to parse timestamp: invalid trailing'
' data {0}.'.format(value))
raise ValueError(
'Failed to parse timestamp: invalid trailing data {0}.'.format(
value
)
)
else:
timezone = value[timezone_offset:]
pos = timezone.find(':')
if pos == -1:
raise ValueError(
'Invalid timezone offset value: {0}.'.format(timezone))
raise ValueError('Invalid timezone offset value: {0}.'.format(timezone))
if timezone[0] == '+':
seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
seconds -= (int(timezone[1:pos]) * 60 + int(timezone[pos + 1 :])) * 60
else:
seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60
seconds += (int(timezone[1:pos]) * 60 + int(timezone[pos + 1 :])) * 60
# Set seconds and nanos
_CheckTimestampValid(seconds, nanos)
self.seconds = int(seconds)
Expand All @@ -185,14 +216,16 @@ def ToNanoseconds(self):
def ToMicroseconds(self):
"""Converts Timestamp to microseconds since epoch."""
_CheckTimestampValid(self.seconds, self.nanos)
return (self.seconds * _MICROS_PER_SECOND +
self.nanos // _NANOS_PER_MICROSECOND)
return (
self.seconds * _MICROS_PER_SECOND + self.nanos // _NANOS_PER_MICROSECOND
)

def ToMilliseconds(self):
"""Converts Timestamp to milliseconds since epoch."""
_CheckTimestampValid(self.seconds, self.nanos)
return (self.seconds * _MILLIS_PER_SECOND +
self.nanos // _NANOS_PER_MILLISECOND)
return (
self.seconds * _MILLIS_PER_SECOND + self.nanos // _NANOS_PER_MILLISECOND
)

def ToSeconds(self):
"""Converts Timestamp to seconds since epoch."""
Expand Down Expand Up @@ -312,7 +345,8 @@ def _CheckTimestampValid(seconds, nanos):
if nanos < 0 or nanos >= _NANOS_PER_SECOND:
raise ValueError(
'Timestamp is not valid: Nanos {} must be in a range '
'[0, 999999].'.format(nanos))
'[0, 999999].'.format(nanos)
)


class Duration(object):
Expand All @@ -332,7 +366,7 @@ def ToJsonString(self):
_CheckDurationValid(self.seconds, self.nanos)
if self.seconds < 0 or self.nanos < 0:
result = '-'
seconds = - self.seconds + int((0 - self.nanos) // 1e9)
seconds = -self.seconds + int((0 - self.nanos) // 1e9)
nanos = (0 - self.nanos) % 1e9
else:
result = ''
Expand All @@ -357,17 +391,16 @@ def FromJsonString(self, value):
Args:
value: A string to be converted. The string must end with 's'. Any
fractional digits (or none) are accepted as long as they fit into
precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
fractional digits (or none) are accepted as long as they fit into
precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s
Raises:
ValueError: On parsing problems.
"""
if not isinstance(value, str):
raise ValueError('Duration JSON value not a string: {!r}'.format(value))
if len(value) < 1 or value[-1] != 's':
raise ValueError(
'Duration must end with letter "s": {0}.'.format(value))
raise ValueError('Duration must end with letter "s": {0}.'.format(value))
try:
pos = value.find('.')
if pos == -1:
Expand All @@ -376,15 +409,14 @@ def FromJsonString(self, value):
else:
seconds = int(value[:pos])
if value[0] == '-':
nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9))
nanos = int(round(float('-0{0}'.format(value[pos:-1])) * 1e9))
else:
nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9))
nanos = int(round(float('0{0}'.format(value[pos:-1])) * 1e9))
_CheckDurationValid(seconds, nanos)
self.seconds = seconds
self.nanos = nanos
except ValueError as e:
raise ValueError(
'Couldn\'t parse duration: {0} : {1}.'.format(value, e))
raise ValueError("Couldn't parse duration: {0} : {1}.".format(value, e))

def ToNanoseconds(self):
"""Converts a Duration to nanoseconds."""
Expand All @@ -406,20 +438,23 @@ def ToSeconds(self):

def FromNanoseconds(self, nanos):
"""Converts nanoseconds to Duration."""
self._NormalizeDuration(nanos // _NANOS_PER_SECOND,
nanos % _NANOS_PER_SECOND)
self._NormalizeDuration(
nanos // _NANOS_PER_SECOND, nanos % _NANOS_PER_SECOND
)

def FromMicroseconds(self, micros):
"""Converts microseconds to Duration."""
self._NormalizeDuration(
micros // _MICROS_PER_SECOND,
(micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND)
(micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND,
)

def FromMilliseconds(self, millis):
"""Converts milliseconds to Duration."""
self._NormalizeDuration(
millis // _MILLIS_PER_SECOND,
(millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND)
(millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND,
)

def FromSeconds(self, seconds):
"""Converts seconds to Duration."""
Expand All @@ -429,8 +464,9 @@ def FromSeconds(self, seconds):
def ToTimedelta(self) -> datetime.timedelta:
"""Converts Duration to timedelta."""
return datetime.timedelta(
seconds=self.seconds, microseconds=_RoundTowardZero(
self.nanos, _NANOS_PER_MICROSECOND))
seconds=self.seconds,
microseconds=_RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND),
)

def FromTimedelta(self, td):
"""Converts timedelta to Duration."""
Expand Down Expand Up @@ -472,14 +508,15 @@ def _CheckDurationValid(seconds, nanos):
if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:
raise ValueError(
'Duration is not valid: Seconds {0} must be in range '
'[-315576000000, 315576000000].'.format(seconds))
'[-315576000000, 315576000000].'.format(seconds)
)
if nanos <= -_NANOS_PER_SECOND or nanos >= _NANOS_PER_SECOND:
raise ValueError(
'Duration is not valid: Nanos {0} must be in range '
'[-999999999, 999999999].'.format(nanos))
'[-999999999, 999999999].'.format(nanos)
)
if (nanos < 0 and seconds > 0) or (nanos > 0 and seconds < 0):
raise ValueError(
'Duration is not valid: Sign mismatch.')
raise ValueError('Duration is not valid: Sign mismatch.')


def _RoundTowardZero(value, divider):
Expand Down Expand Up @@ -601,6 +638,7 @@ def update(self, dictionary): # pylint: disable=invalid-name
for key, value in dictionary.items():
_SetStructValue(self.fields[key], value)


collections.abc.MutableMapping.register(Struct)


Expand Down Expand Up @@ -663,6 +701,7 @@ def add_list(self):
list_value.Clear()
return list_value


collections.abc.MutableSequence.register(ListValue)


Expand Down
Loading

0 comments on commit 59d0645

Please sign in to comment.