From b690e729eb86d9f505d1c14692100bfc39809603 Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Wed, 5 Jun 2024 13:52:52 -0700 Subject: [PATCH] Nextgen Proto Pythonic API: Timestamp/Duration assignment, creation and calculation Timestamp and Duration are now have more support with datetime and timedelta: - Allows assign python datetime to protobuf DateTime field in addition to current FromDatetime/ToDatetime (Note: will throw exceptions for the differences in supported ranges) - Allows assign python timedelta to protobuf Duration field in addition to current FromTimedelta/ToTimedelta - Calculation between Timestamp, Duration, datetime and timedelta will also be supported. example usage: from datetime import datetime, timedelta from event_pb2 import Event e = Event(start_time=datetime(year=2112, month=2, day=3), duration=timedelta(hours=10)) duration = timedelta(hours=10)) end_time = e.start_time + timedelta(hours=4) e.duration = end_time - e.start_time PiperOrigin-RevId: 640639168 --- .../protobuf/internal/descriptor_pool_test.py | 10 + .../protobuf/internal/more_messages.proto | 8 + .../protobuf/internal/python_message.py | 43 +++- .../protobuf/internal/well_known_types.py | 51 +++- .../internal/well_known_types_test.py | 236 +++++++++++++++++- python/google/protobuf/pyext/message.cc | 71 +++++- python/message.c | 57 ++++- 7 files changed, 445 insertions(+), 31 deletions(-) diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index c903bcc702446..222a1e3a0c017 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -29,6 +29,8 @@ from google.protobuf.internal import no_package_pb2 from google.protobuf.internal import testing_refleaks +from google.protobuf import duration_pb2 +from google.protobuf import timestamp_pb2 from google.protobuf import unittest_features_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_public_pb2 @@ -435,6 +437,8 @@ def testAddSerializedFile(self): self.assertEqual(file2.name, 'google/protobuf/internal/factory_test2.proto') self.testFindMessageTypeByName() + self.pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) + self.pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb) file_json = self.pool.AddSerializedFile( more_messages_pb2.DESCRIPTOR.serialized_pb) field = file_json.message_types_by_name['class'].fields_by_name['int_field'] @@ -542,12 +546,18 @@ def testComplexNesting(self): # that uses a DescriptorDatabase. # TODO: Fix python and cpp extension diff. return + timestamp_desc = descriptor_pb2.FileDescriptorProto.FromString( + timestamp_pb2.DESCRIPTOR.serialized_pb) + duration_desc = descriptor_pb2.FileDescriptorProto.FromString( + duration_pb2.DESCRIPTOR.serialized_pb) more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString( more_messages_pb2.DESCRIPTOR.serialized_pb) test1_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) test2_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(timestamp_desc) + self.pool.Add(duration_desc) self.pool.Add(more_messages_desc) self.pool.Add(test1_desc) self.pool.Add(test2_desc) diff --git a/python/google/protobuf/internal/more_messages.proto b/python/google/protobuf/internal/more_messages.proto index 0c0505f620153..b290685ee9f74 100644 --- a/python/google/protobuf/internal/more_messages.proto +++ b/python/google/protobuf/internal/more_messages.proto @@ -13,6 +13,9 @@ syntax = "proto2"; package google.protobuf.internal; +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + // A message where tag numbers are listed out of order, to allow us to test our // canonicalization of serialized output, which should always be in tag order. // We also mix in some extensions for extra fun. @@ -348,3 +351,8 @@ message ConflictJsonName { optional int32 value = 1 [json_name = "old_value"]; optional int32 new_value = 2 [json_name = "value"]; } + +message WKTMessage { + optional Timestamp optional_timestamp = 1; + optional Duration optional_duration = 2; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 869f2aa73177f..5982a84f78d7e 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -27,6 +27,7 @@ __author__ = 'robinson@google.com (Will Robinson)' +import datetime from io import BytesIO import struct import sys @@ -536,13 +537,30 @@ def init(self, **kwargs): self._fields[field] = copy elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: copy = field._default_constructor(self) - new_val = field_value - if isinstance(field_value, dict): + new_val = None + if isinstance(field_value, message_mod.Message): + new_val = field_value + elif isinstance(field_value, dict): new_val = field.message_type._concrete_class(**field_value) - try: - copy.MergeFrom(new_val) - except TypeError: - _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) + elif field.message_type.full_name == 'google.protobuf.Timestamp': + copy.FromDatetime(field_value) + elif field.message_type.full_name == 'google.protobuf.Duration': + copy.FromTimedelta(field_value) + else: + raise TypeError( + 'Message field {0}.{1} must be initialized with a ' + 'dict or instance of same class, got {2}.'.format( + message_descriptor.name, + field_name, + type(field_value).__name__, + ) + ) + + if new_val: + try: + copy.MergeFrom(new_val) + except TypeError: + _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) self._fields[field] = copy else: if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: @@ -753,8 +771,17 @@ def getter(self): # We define a setter just so we can throw an exception with a more # helpful error message. def setter(self, new_value): - raise AttributeError('Assignment not allowed to composite field ' - '"%s" in protocol message object.' % proto_field_name) + if field.message_type.full_name == 'google.protobuf.Timestamp': + getter(self) + self._fields[field].FromDatetime(new_value) + elif field.message_type.full_name == 'google.protobuf.Duration': + getter(self) + self._fields[field].FromTimedelta(new_value) + else: + raise AttributeError( + 'Assignment not allowed to composite field ' + '"%s" in protocol message object.' % proto_field_name + ) # Add a property to encapsulate the getter. doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index ea90be3be645a..7d036447cfc4f 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -21,8 +21,8 @@ import collections.abc import datetime import warnings - from google.protobuf.internal import field_mask +from typing import Union FieldMask = field_mask.FieldMask @@ -271,12 +271,35 @@ def FromDatetime(self, dt): # manipulated into a long value of seconds. During the conversion from # struct_time to long, the source date in UTC, and so it follows that the # correct transformation is calendar.timegm() - seconds = calendar.timegm(dt.utctimetuple()) - nanos = dt.microsecond * _NANOS_PER_MICROSECOND + try: + seconds = calendar.timegm(dt.utctimetuple()) + nanos = dt.microsecond * _NANOS_PER_MICROSECOND + except AttributeError as e: + raise AttributeError( + 'Fail to convert to Timestamp. Expected a datetime like ' + 'object got {0} : {1}'.format(type(dt).__name__, e) + ) from e _CheckTimestampValid(seconds, nanos) self.seconds = seconds self.nanos = nanos + def __add__(self, value) -> datetime.datetime: + if isinstance(value, Duration): + return self.ToDatetime() + value.ToTimedelta() + return self.ToDatetime() + value + + __radd__ = __add__ + + def __sub__(self, value) -> Union[datetime.datetime, datetime.timedelta]: + if isinstance(value, Timestamp): + return self.ToDatetime() - value.ToDatetime() + elif isinstance(value, Duration): + return self.ToDatetime() - value.ToTimedelta() + return self.ToDatetime() - value + + def __rsub__(self, dt) -> datetime.timedelta: + return dt - self.ToDatetime() + def _CheckTimestampValid(seconds, nanos): if seconds < _TIMESTAMP_SECONDS_MIN or seconds > _TIMESTAMP_SECONDS_MAX: @@ -408,8 +431,16 @@ def ToTimedelta(self) -> datetime.timedelta: def FromTimedelta(self, td): """Converts timedelta to Duration.""" - self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY, - td.microseconds * _NANOS_PER_MICROSECOND) + try: + self._NormalizeDuration( + td.seconds + td.days * _SECONDS_PER_DAY, + td.microseconds * _NANOS_PER_MICROSECOND, + ) + except AttributeError as e: + raise AttributeError( + 'Fail to convert to Duration. Expected a timedelta like ' + 'object got {0}: {1}'.format(type(td).__name__, e) + ) from e def _NormalizeDuration(self, seconds, nanos): """Set Duration by seconds and nanos.""" @@ -420,6 +451,16 @@ def _NormalizeDuration(self, seconds, nanos): self.seconds = seconds self.nanos = nanos + def __add__(self, value) -> Union[datetime.datetime, datetime.timedelta]: + if isinstance(value, Timestamp): + return self.ToTimedelta() + value.ToDatetime() + return self.ToTimedelta() + value + + __radd__ = __add__ + + def __rsub__(self, dt) -> Union[datetime.datetime, datetime.timedelta]: + return dt - self.ToTimedelta() + def _CheckDurationValid(seconds, nanos): if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX: diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 592776634deaf..4f1e39dad3a12 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -13,13 +13,15 @@ import datetime import unittest -from google.protobuf import any_pb2 +from google.protobuf import text_format from google.protobuf.internal import any_test_pb2 +from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import well_known_types + +from google.protobuf import any_pb2 from google.protobuf import duration_pb2 from google.protobuf import struct_pb2 from google.protobuf import timestamp_pb2 -from google.protobuf.internal import well_known_types -from google.protobuf import text_format from google.protobuf.internal import _parameterized from google.protobuf import unittest_pb2 @@ -351,6 +353,123 @@ def testTimezoneAwareMinDatetimeConversion(self): tz_aware_min_datetime, ts.ToDatetime(datetime.timezone.utc) ) + # Two hours after the Unix Epoch, around the world. + @_parameterized.named_parameters( + ('London', [1970, 1, 1, 2], datetime.timezone.utc), + ('Tokyo', [1970, 1, 1, 11], _TZ_JAPAN), + ('LA', [1969, 12, 31, 18], _TZ_PACIFIC), + ) + def testTimestampAssignment(self, date_parts, tzinfo): + original_datetime = datetime.datetime(*date_parts, tzinfo=tzinfo) # pylint:disable=g-tzinfo-datetime + msg = more_messages_pb2.WKTMessage() + msg.optional_timestamp = original_datetime + self.assertEqual(7200, msg.optional_timestamp.seconds) + self.assertEqual(0, msg.optional_timestamp.nanos) + + # Two hours after the Unix Epoch, around the world. + @_parameterized.named_parameters( + ('London', [1970, 1, 1, 2], datetime.timezone.utc), + ('Tokyo', [1970, 1, 1, 11], _TZ_JAPAN), + ('LA', [1969, 12, 31, 18], _TZ_PACIFIC), + ) + def testTimestampCreation(self, date_parts, tzinfo): + original_datetime = datetime.datetime(*date_parts, tzinfo=tzinfo) # pylint:disable=g-tzinfo-datetime + msg = more_messages_pb2.WKTMessage(optional_timestamp=original_datetime) + self.assertEqual(7200, msg.optional_timestamp.seconds) + self.assertEqual(0, msg.optional_timestamp.nanos) + + msg2 = more_messages_pb2.WKTMessage( + optional_timestamp=msg.optional_timestamp + ) + self.assertEqual(7200, msg2.optional_timestamp.seconds) + self.assertEqual(0, msg2.optional_timestamp.nanos) + + @_parameterized.named_parameters( + ( + 'tz_aware_min_dt', + datetime.datetime(1, 1, 1, tzinfo=datetime.timezone.utc), + datetime.timedelta(hours=9), + -62135564400, + 0, + ), + ( + 'no_change', + datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN), + datetime.timedelta(hours=0), + 7200, + 0, + ), + ) + def testTimestampAdd(self, old_time, time_delta, expected_sec, expected_nano): + msg = more_messages_pb2.WKTMessage() + msg.optional_timestamp = old_time + + # Timestamp + timedelta + new_msg1 = more_messages_pb2.WKTMessage() + new_msg1.optional_timestamp = msg.optional_timestamp + time_delta + self.assertEqual(expected_sec, new_msg1.optional_timestamp.seconds) + self.assertEqual(expected_nano, new_msg1.optional_timestamp.nanos) + + # timedelta + Timestamp + new_msg2 = more_messages_pb2.WKTMessage() + new_msg2.optional_timestamp = time_delta + msg.optional_timestamp + self.assertEqual(expected_sec, new_msg2.optional_timestamp.seconds) + self.assertEqual(expected_nano, new_msg2.optional_timestamp.nanos) + + # Timestamp + Duration + msg.optional_duration.FromTimedelta(time_delta) + new_msg3 = more_messages_pb2.WKTMessage() + new_msg3.optional_timestamp = msg.optional_timestamp + msg.optional_duration + self.assertEqual(expected_sec, new_msg3.optional_timestamp.seconds) + self.assertEqual(expected_nano, new_msg3.optional_timestamp.nanos) + + @_parameterized.named_parameters( + ( + 'test1', + datetime.datetime(999, 1, 1, tzinfo=datetime.timezone.utc), + datetime.timedelta(hours=9), + -30641792400, + 0, + ), + ( + 'no_change', + datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN), + datetime.timedelta(hours=0), + 7200, + 0, + ), + ) + def testTimestampSub(self, old_time, time_delta, expected_sec, expected_nano): + msg = more_messages_pb2.WKTMessage() + msg.optional_timestamp = old_time + + # Timestamp - timedelta + new_msg1 = more_messages_pb2.WKTMessage() + new_msg1.optional_timestamp = msg.optional_timestamp - time_delta + self.assertEqual(expected_sec, new_msg1.optional_timestamp.seconds) + self.assertEqual(expected_nano, new_msg1.optional_timestamp.nanos) + + # Timestamp - Duration + msg.optional_duration = time_delta + new_msg2 = more_messages_pb2.WKTMessage() + new_msg2.optional_timestamp = msg.optional_timestamp - msg.optional_duration + self.assertEqual(expected_sec, new_msg2.optional_timestamp.seconds) + self.assertEqual(expected_nano, new_msg2.optional_timestamp.nanos) + + result_msg = more_messages_pb2.WKTMessage() + result_msg.optional_timestamp = old_time - time_delta + # Timestamp - Timestamp + td = msg.optional_timestamp - result_msg.optional_timestamp + self.assertEqual(time_delta, td) + + # Timestamp - datetime + td1 = msg.optional_timestamp - result_msg.optional_timestamp.ToDatetime() + self.assertEqual(time_delta, td1) + + # datetime - Timestamp + td2 = msg.optional_timestamp.ToDatetime() - result_msg.optional_timestamp + self.assertEqual(time_delta, td2) + def testNanosOneSecond(self): tz = _TZ_PACIFIC ts = timestamp_pb2.Timestamp(nanos=1_000_000_000) @@ -413,6 +532,18 @@ def testInvalidTimestamp(self): message.ToJsonString) self.assertRaisesRegex(ValueError, 'Timestamp is not valid', message.FromSeconds, -62135596801) + msg = more_messages_pb2.WKTMessage() + with self.assertRaises(AttributeError): + msg.optional_timestamp = 1 + + with self.assertRaises(AttributeError): + msg2 = more_messages_pb2.WKTMessage(optional_timestamp=1) + + with self.assertRaises(TypeError): + msg.optional_timestamp + '' + + with self.assertRaises(TypeError): + msg.optional_timestamp - 123 def testInvalidDuration(self): message = duration_pb2.Duration() @@ -446,6 +577,105 @@ def testInvalidDuration(self): self.assertRaisesRegex(ValueError, r'Duration is not valid\: Sign mismatch.', message.ToJsonString) + msg = more_messages_pb2.WKTMessage() + with self.assertRaises(AttributeError): + msg.optional_duration = 1 + + with self.assertRaises(AttributeError): + msg2 = more_messages_pb2.WKTMessage(optional_duration=1) + + with self.assertRaises(TypeError): + msg.optional_duration + '' + + with self.assertRaises(TypeError): + 123 - msg.optional_duration + + @_parameterized.named_parameters( + ('test1', -1999999, -1, -999999000), ('test2', 1999999, 1, 999999000) + ) + def testDurationAssignment(self, microseconds, expected_sec, expected_nano): + message = more_messages_pb2.WKTMessage() + expected_td = datetime.timedelta(microseconds=microseconds) + message.optional_duration = expected_td + self.assertEqual(expected_td, message.optional_duration.ToTimedelta()) + self.assertEqual(expected_sec, message.optional_duration.seconds) + self.assertEqual(expected_nano, message.optional_duration.nanos) + + @_parameterized.named_parameters( + ('test1', -1999999, -1, -999999000), ('test2', 1999999, 1, 999999000) + ) + def testDurationCreation(self, microseconds, expected_sec, expected_nano): + message = more_messages_pb2.WKTMessage( + optional_duration=datetime.timedelta(microseconds=microseconds) + ) + expected_td = datetime.timedelta(microseconds=microseconds) + self.assertEqual(expected_td, message.optional_duration.ToTimedelta()) + self.assertEqual(expected_sec, message.optional_duration.seconds) + self.assertEqual(expected_nano, message.optional_duration.nanos) + + @_parameterized.named_parameters( + ( + 'tz_aware_min_dt', + datetime.datetime(1, 1, 1, tzinfo=datetime.timezone.utc), + datetime.timedelta(hours=9), + -62135564400, + 0, + ), + ( + 'no_change', + datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN), + datetime.timedelta(hours=0), + 7200, + 0, + ), + ) + def testDurationAdd(self, old_time, time_delta, expected_sec, expected_nano): + msg = more_messages_pb2.WKTMessage() + msg.optional_duration = time_delta + msg.optional_timestamp = old_time + + # Duration + datetime + msg1 = more_messages_pb2.WKTMessage() + msg1.optional_timestamp = msg.optional_duration + old_time + self.assertEqual(expected_sec, msg1.optional_timestamp.seconds) + self.assertEqual(expected_nano, msg1.optional_timestamp.nanos) + + # datetime + Duration + msg2 = more_messages_pb2.WKTMessage() + msg2.optional_timestamp = old_time + msg.optional_duration + self.assertEqual(expected_sec, msg2.optional_timestamp.seconds) + self.assertEqual(expected_nano, msg2.optional_timestamp.nanos) + + # Duration + Timestamp + msg3 = more_messages_pb2.WKTMessage() + msg3.optional_timestamp = msg.optional_duration + msg.optional_timestamp + self.assertEqual(expected_sec, msg3.optional_timestamp.seconds) + self.assertEqual(expected_nano, msg3.optional_timestamp.nanos) + + @_parameterized.named_parameters( + ( + 'test1', + datetime.datetime(999, 1, 1, tzinfo=datetime.timezone.utc), + datetime.timedelta(hours=9), + -30641792400, + 0, + ), + ( + 'no_change', + datetime.datetime(1970, 1, 1, 11, tzinfo=_TZ_JAPAN), + datetime.timedelta(hours=0), + 7200, + 0, + ), + ) + def testDurationSub(self, old_time, time_delta, expected_sec, expected_nano): + msg = more_messages_pb2.WKTMessage() + msg.optional_duration = time_delta + + # datetime - Duration + msg.optional_timestamp = old_time - msg.optional_duration + self.assertEqual(expected_sec, msg.optional_timestamp.seconds) + self.assertEqual(expected_nano, msg.optional_timestamp.nanos) class StructTest(unittest.TestCase): diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 39fe35a9412aa..5966f29e4a211 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -1092,9 +1092,41 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { return -1; } } else { - ScopedPyObjectPtr merged(MergeFrom(cmessage, value)); - if (merged == nullptr) { - return -1; + if (PyObject_TypeCheck(value, CMessage_Type)) { + ScopedPyObjectPtr merged(MergeFrom(cmessage, value)); + if (merged == nullptr) { + return -1; + } + } else { + switch (descriptor->message_type()->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + AssureWritable(cmessage); + ScopedPyObjectPtr ok( + PyObject_CallMethod(reinterpret_cast(cmessage), + "FromDatetime", "O", value)); + if (ok.get() == nullptr) { + return -1; + } + break; + } + case Descriptor::WELLKNOWNTYPE_DURATION: { + AssureWritable(cmessage); + ScopedPyObjectPtr ok( + PyObject_CallMethod(reinterpret_cast(cmessage), + "FromTimedelta", "O", value)); + if (ok.get() == nullptr) { + return -1; + } + break; + } + default: + PyErr_Format( + PyExc_TypeError, + "Parameter to initialize message field must be " + "dict or instance of same class: expected %s got %s.", + descriptor->full_name().c_str(), Py_TYPE(value)->tp_name); + return -1; + } } } } else { @@ -2561,11 +2593,34 @@ int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor, field_descriptor->name().c_str()); return -1; } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyErr_Format(PyExc_AttributeError, - "Assignment not allowed to " - "field \"%s\" in protocol message object.", - field_descriptor->name().c_str()); - return -1; + switch (field_descriptor->message_type()->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: { + AssureWritable(self); + PyObject* sub_message = GetFieldValue(self, field_descriptor); + ScopedPyObjectPtr ok( + PyObject_CallMethod(sub_message, "FromDatetime", "O", value)); + if (ok.get() == nullptr) { + return -1; + } + return 0; + } + case Descriptor::WELLKNOWNTYPE_DURATION: { + AssureWritable(self); + PyObject* sub_message = GetFieldValue(self, field_descriptor); + ScopedPyObjectPtr ok( + PyObject_CallMethod(sub_message, "FromTimedelta", "O", value)); + if (ok.get() == nullptr) { + return -1; + } + return 0; + } + default: + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed to " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } } else { AssureWritable(self); return InternalSetScalar(self, field_descriptor, value); diff --git a/python/message.c b/python/message.c index 42bf90236d035..f31aa5e7342f5 100644 --- a/python/message.c +++ b/python/message.c @@ -433,6 +433,7 @@ static bool PyUpb_Message_InitRepeatedAttribute(PyObject* _self, PyObject* name, } static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name, + const upb_FieldDef* field, PyObject* value) { PyObject* submsg = PyUpb_Message_GetAttr(_self, name); if (!submsg) return -1; @@ -446,10 +447,24 @@ static bool PyUpb_Message_InitMessageAttribute(PyObject* _self, PyObject* name, assert(!PyErr_Occurred()); ok = PyUpb_Message_InitAttributes(submsg, NULL, value) >= 0; } else { - const upb_MessageDef* m = PyUpb_Message_GetMsgdef(_self); - PyErr_Format(PyExc_TypeError, "Message must be initialized with a dict: %s", - upb_MessageDef_FullName(m)); - ok = false; + const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field); + switch (upb_MessageDef_WellKnownType(msgdef)) { + case kUpb_WellKnown_Timestamp: { + ok = PyObject_CallMethod(submsg, "FromDatetime", "O", value); + break; + } + case kUpb_WellKnown_Duration: { + ok = PyObject_CallMethod(submsg, "FromTimedelta", "O", value); + break; + } + default: { + const upb_MessageDef* m = PyUpb_Message_GetMsgdef(_self); + PyErr_Format(PyExc_TypeError, + "Message must be initialized with a dict: %s", + upb_MessageDef_FullName(m)); + ok = false; + } + } } Py_DECREF(submsg); return ok; @@ -502,7 +517,7 @@ int PyUpb_Message_InitAttributes(PyObject* _self, PyObject* args, } else if (upb_FieldDef_IsRepeated(f)) { if (!PyUpb_Message_InitRepeatedAttribute(_self, name, value)) return -1; } else if (upb_FieldDef_IsSubMessage(f)) { - if (!PyUpb_Message_InitMessageAttribute(_self, name, value)) return -1; + if (!PyUpb_Message_InitMessageAttribute(_self, name, f, value)) return -1; } else { if (!PyUpb_Message_InitScalarAttribute(msg, f, value, arena)) return -1; } @@ -935,9 +950,9 @@ int PyUpb_Message_SetFieldValue(PyObject* _self, const upb_FieldDef* field, PyUpb_Message* self = (void*)_self; assert(value); - if (upb_FieldDef_IsSubMessage(field) || upb_FieldDef_IsRepeated(field)) { + if (upb_FieldDef_IsRepeated(field)) { PyErr_Format(exc, - "Assignment not allowed to message, map, or repeated " + "Assignment not allowed to map, or repeated " "field \"%s\" in protocol message object.", upb_FieldDef_Name(field)); return -1; @@ -945,6 +960,34 @@ int PyUpb_Message_SetFieldValue(PyObject* _self, const upb_FieldDef* field, PyUpb_Message_EnsureReified(self); + if (upb_FieldDef_IsSubMessage(field)) { + const upb_MessageDef* msgdef = upb_FieldDef_MessageSubDef(field); + switch (upb_MessageDef_WellKnownType(msgdef)) { + case kUpb_WellKnown_Timestamp: { + PyObject* sub_message = PyUpb_Message_GetFieldValue(_self, field); + PyObject* ok = + PyObject_CallMethod(sub_message, "FromDatetime", "O", value); + if (!ok) return -1; + Py_DECREF(ok); + return 0; + } + case kUpb_WellKnown_Duration: { + PyObject* sub_message = PyUpb_Message_GetFieldValue(_self, field); + PyObject* ok = + PyObject_CallMethod(sub_message, "FromTimedelta", "O", value); + if (!ok) return -1; + Py_DECREF(ok); + return 0; + } + default: + PyErr_Format(exc, + "Assignment not allowed to message " + "field \"%s\" in protocol message object.", + upb_FieldDef_Name(field)); + return -1; + } + } + upb_MessageValue val; upb_Arena* arena = PyUpb_Arena_Get(self->arena); if (!PyUpb_PyToUpb(value, field, &val, arena)) {