From 3a9f0743ea8d82f489a65f7d087fa01d26ac5f56 Mon Sep 17 00:00:00 2001 From: Jie Luo Date: Wed, 29 May 2024 12:33:23 -0700 Subject: [PATCH] =?UTF-8?q?Nextgen=20Proto=20Pythonic=20API:=20=20?= =?UTF-8?q?=E2=80=9CAdd-on=E2=80=9D=20proto=20for=20length=20prefixed=20se?= =?UTF-8?q?rialize/parse?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added the following methods: serialize_length_prefixed(message: Message, output: io.BytesIO) -> None parse_length_prefixed(message_class: Type[Message], input_bytes: io.BytesIO) -> Message The output of serialize_length_prefixed should be BytesIO or custom buffered IO that data should be written to. The output stream must be buffered, e.g. using https://docs.python.org/3/library/io.html#buffered-streams. PiperOrigin-RevId: 638375900 --- python/build_targets.bzl | 5 ++ python/google/protobuf/internal/decoder.py | 22 +++-- .../google/protobuf/internal/decoder_test.py | 57 +++++++++++++ python/google/protobuf/internal/proto_test.py | 85 ++++++++++++++++++- python/google/protobuf/proto.py | 83 +++++++++++++++++- 5 files changed, 243 insertions(+), 9 deletions(-) create mode 100644 python/google/protobuf/internal/decoder_test.py diff --git a/python/build_targets.bzl b/python/build_targets.bzl index 0cc2911ac09ea..ee765ab4c60be 100644 --- a/python/build_targets.bzl +++ b/python/build_targets.bzl @@ -411,6 +411,11 @@ def build_targets(name): srcs = ["google/protobuf/internal/well_known_types_test.py"], ) + internal_py_test( + name = "decoder_test", + srcs = ["google/protobuf/internal/decoder_test.py"], + ) + internal_py_test( name = "wire_format_test", srcs = ["google/protobuf/internal/wire_format_test.py"], diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index ddaced2c6dc03..dcde1d9420c9a 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -60,10 +60,10 @@ import math import struct +from google.protobuf import message from google.protobuf.internal import containers from google.protobuf.internal import encoder from google.protobuf.internal import wire_format -from google.protobuf import message # This is not for optimization, but rather to avoid conflicts with local @@ -81,20 +81,32 @@ def _VarintDecoder(mask, result_type): decoder returns a (value, new_pos) pair. """ - def DecodeVarint(buffer, pos): + def DecodeVarint(buffer, pos: int=None): result = 0 shift = 0 while 1: - b = buffer[pos] + if pos is None: + # Read from BytesIO + try: + b = buffer.read(1)[0] + except IndexError as e: + if shift == 0: + # End of BytesIO. + return None + else: + raise ValueError('Fail to read varint %s' % str(e)) + else: + b = buffer[pos] + pos += 1 result |= ((b & 0x7f) << shift) - pos += 1 if not (b & 0x80): result &= mask result = result_type(result) - return (result, pos) + return result if pos is None else (result, pos) shift += 7 if shift >= 64: raise _DecodeError('Too many bytes when decoding varint.') + return DecodeVarint diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py new file mode 100644 index 0000000000000..f801b6e76fd8b --- /dev/null +++ b/python/google/protobuf/internal/decoder_test.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Test decoder.""" + +import io +import unittest + +from google.protobuf.internal import decoder +from google.protobuf.internal import testing_refleaks + + +_INPUT_BYTES = b'\x84r\x12' +_EXPECTED = (14596, 18) + + +@testing_refleaks.TestCase +class DecoderTest(unittest.TestCase): + + def test_decode_varint_bytes(self): + (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 0) + self.assertEqual(size, _EXPECTED[0]) + self.assertEqual(pos, 2) + + (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 2) + self.assertEqual(size, _EXPECTED[1]) + self.assertEqual(pos, 3) + + def test_decode_varint_bytes_empty(self): + with self.assertRaises(IndexError) as context: + (size, pos) = decoder._DecodeVarint(b'', 0) + self.assertIn('index out of range', str(context.exception)) + + def test_decode_varint_bytesio(self): + index = 0 + input_io = io.BytesIO(_INPUT_BYTES) + while True: + size = decoder._DecodeVarint(input_io) + if size is None: + break + self.assertEqual(size, _EXPECTED[index]) + index += 1 + self.assertEqual(index, len(_EXPECTED)) + + def test_decode_varint_bytesio_empty(self): + input_io = io.BytesIO(b'') + size = decoder._DecodeVarint(input_io) + self.assertEqual(size, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/proto_test.py b/python/google/protobuf/internal/proto_test.py index 71ee406cd5d85..b3351a498fa42 100644 --- a/python/google/protobuf/internal/proto_test.py +++ b/python/google/protobuf/internal/proto_test.py @@ -8,16 +8,19 @@ """Tests Nextgen Pythonic protobuf APIs.""" +import io import unittest from google.protobuf import proto - +from google.protobuf.internal import encoder from google.protobuf.internal import test_util from google.protobuf.internal import testing_refleaks + from google.protobuf.internal import _parameterized from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 + @_parameterized.named_parameters(('_proto2', unittest_pb2), ('_proto3', unittest_proto3_arena_pb2)) @testing_refleaks.TestCase @@ -30,6 +33,86 @@ def test_simple_serialize_parse(self, message_module): parsed_msg = proto.parse(message_module.TestAllTypes, serialized_data) self.assertEqual(msg, parsed_msg) + def test_serialize_parse_length_prefixed_empty(self, message_module): + empty_alltypes = message_module.TestAllTypes() + out = io.BytesIO() + proto.serialize_length_prefixed(empty_alltypes, out) + + input_bytes = io.BytesIO(out.getvalue()) + msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes) + + self.assertEqual(msg, empty_alltypes) + + def test_parse_length_prefixed_truncated(self, message_module): + out = io.BytesIO() + encoder._VarintEncoder()(out.write, 9999) + msg = message_module.TestAllTypes(optional_int32=1) + out.write(proto.serialize(msg)) + + input_bytes = io.BytesIO(out.getvalue()) + with self.assertRaises(ValueError) as context: + proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes) + self.assertEqual( + str(context.exception), + 'Truncated message or non-buffered input_bytes: ' + 'Expected 9999 bytes but only 2 bytes parsed for ' + 'TestAllTypes.', + ) + + def test_serialize_length_prefixed_fake_io(self, message_module): + class FakeBytesIO(io.BytesIO): + + def write(self, b: bytes) -> int: + return 0 + + msg = message_module.TestAllTypes(optional_int32=123) + out = FakeBytesIO() + with self.assertRaises(TypeError) as context: + proto.serialize_length_prefixed(msg, out) + self.assertIn( + 'Failed to write complete message (wrote: 0, expected: 2)', + str(context.exception), + ) + + +_EXPECTED_PROTO3 = b'\x04r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi' +_EXPECTED_PROTO2 = b'\x06\x08\x00r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi' + + +@_parameterized.named_parameters( + ('_proto2', unittest_pb2, _EXPECTED_PROTO2), + ('_proto3', unittest_proto3_arena_pb2, _EXPECTED_PROTO3), +) +@testing_refleaks.TestCase +class LengthPrefixedWithGolden(unittest.TestCase): + + def test_serialize_length_prefixed(self, message_module, expected): + number_of_messages = 3 + + out = io.BytesIO() + for index in range(0, number_of_messages): + msg = message_module.TestAllTypes( + optional_int32=index, optional_string='hi' + ) + proto.serialize_length_prefixed(msg, out) + + self.assertEqual(out.getvalue(), expected) + + def test_parse_length_prefixed(self, message_module, input_bytes): + expected_number_of_messages = 3 + + input_io = io.BytesIO(input_bytes) + index = 0 + while True: + msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_io) + if msg is None: + break + self.assertEqual(msg.optional_int32, index) + self.assertEqual(msg.optional_string, 'hi') + index += 1 + + self.assertEqual(index, expected_number_of_messages) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/proto.py b/python/google/protobuf/proto.py index 722bbb23fc04a..df0a0d02d98c8 100644 --- a/python/google/protobuf/proto.py +++ b/python/google/protobuf/proto.py @@ -7,11 +7,17 @@ """Contains the Nextgen Pythonic protobuf APIs.""" -import typing +import io +from typing import Type, TypeVar +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder from google.protobuf.message import Message -def serialize(message: Message, deterministic: bool=None) -> bytes: +_MESSAGE = TypeVar('_MESSAGE', bound='Message') + + +def serialize(message: _MESSAGE, deterministic: bool = None) -> bytes: """Return the serialized proto. Args: @@ -24,7 +30,8 @@ def serialize(message: Message, deterministic: bool=None) -> bytes: """ return message.SerializeToString(deterministic=deterministic) -def parse(message_class: typing.Type[Message], payload: bytes) -> Message: + +def parse(message_class: Type[_MESSAGE], payload: bytes) -> _MESSAGE: """Given a serialized data in binary form, deserialize it into a Message. Args: @@ -37,3 +44,73 @@ def parse(message_class: typing.Type[Message], payload: bytes) -> Message: new_message = message_class() new_message.ParseFromString(payload) return new_message + + +def serialize_length_prefixed(message: _MESSAGE, output: io.BytesIO) -> None: + """Writes the size of the message as a varint and the serialized message. + + Writes the size of the message as a varint and then the serialized message. + This allows more data to be written to the output after the message. Use + parse_length_prefixed to parse messages written by this method. + + The output stream must be buffered, e.g. using + https://docs.python.org/3/library/io.html#buffered-streams. + + Example usage: + out = io.BytesIO() + for msg in message_list: + proto.serialize_length_prefixed(msg, out) + + Args: + message: The protocol buffer message that should be serialized. + output: BytesIO or custom buffered IO that data should be written to. + """ + size = message.ByteSize() + encoder._VarintEncoder()(output.write, size) + out_size = output.write(serialize(message)) + + if out_size != size: + raise TypeError( + 'Failed to write complete message (wrote: %d, expected: %d)' + '. Ensure output is using buffered IO.' % (out_size, size) + ) + + +def parse_length_prefixed( + message_class: Type[_MESSAGE], input_bytes: io.BytesIO +) -> _MESSAGE: + """Parse a message from input_bytes. + + Args: + message_class: The protocol buffer message class that parser should parse. + input_bytes: A buffered input. + + Example usage: + while True: + msg = proto.parse_length_prefixed(message_class, input_bytes) + if msg is None: + break + ... + + Returns: + A parsed message if successful. None if input_bytes is at EOF. + """ + size = decoder._DecodeVarint(input_bytes) + if size is None: + # It is the end of buffered input. See example usage in the + # API description. + return None + + message = message_class() + + if size == 0: + return message + + parsed_size = message.ParseFromString(input_bytes.read(size)) + if parsed_size != size: + raise ValueError( + 'Truncated message or non-buffered input_bytes: ' + 'Expected {0} bytes but only {1} bytes parsed for ' + '{2}.'.format(size, parsed_size, message.DESCRIPTOR.name) + ) + return message