Skip to content

Commit

Permalink
Nextgen Proto Pythonic API: “Add-on” proto for length prefixed serial…
Browse files Browse the repository at this point in the history
…ize/parse

Added the following methods:
serialize_length_prefixed(message: Message, output: io.BytesIO) -> None
parse_length_prefixed(message_class: Type[Message], input_bytes: io.BytesIO) -> Message

The output of serialize_length_prefixed should be BytesIO or custom buffered IO that data should be written to. The output stream must be buffered, e.g. using
  https://docs.python.org/3/library/io.html#buffered-streams.

PiperOrigin-RevId: 638375900
  • Loading branch information
anandolee authored and copybara-github committed May 29, 2024
1 parent fdc7f65 commit 3a9f074
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 9 deletions.
5 changes: 5 additions & 0 deletions python/build_targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,11 @@ def build_targets(name):
srcs = ["google/protobuf/internal/well_known_types_test.py"],
)

internal_py_test(
name = "decoder_test",
srcs = ["google/protobuf/internal/decoder_test.py"],
)

internal_py_test(
name = "wire_format_test",
srcs = ["google/protobuf/internal/wire_format_test.py"],
Expand Down
22 changes: 17 additions & 5 deletions python/google/protobuf/internal/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@
import math
import struct

from google.protobuf import message
from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message


# This is not for optimization, but rather to avoid conflicts with local
Expand All @@ -81,20 +81,32 @@ def _VarintDecoder(mask, result_type):
decoder returns a (value, new_pos) pair.
"""

def DecodeVarint(buffer, pos):
def DecodeVarint(buffer, pos: int=None):
result = 0
shift = 0
while 1:
b = buffer[pos]
if pos is None:
# Read from BytesIO
try:
b = buffer.read(1)[0]
except IndexError as e:
if shift == 0:
# End of BytesIO.
return None
else:
raise ValueError('Fail to read varint %s' % str(e))
else:
b = buffer[pos]
pos += 1
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
result &= mask
result = result_type(result)
return (result, pos)
return result if pos is None else (result, pos)
shift += 7
if shift >= 64:
raise _DecodeError('Too many bytes when decoding varint.')

return DecodeVarint


Expand Down
57 changes: 57 additions & 0 deletions python/google/protobuf/internal/decoder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Test decoder."""

import io
import unittest

from google.protobuf.internal import decoder
from google.protobuf.internal import testing_refleaks


_INPUT_BYTES = b'\x84r\x12'
_EXPECTED = (14596, 18)


@testing_refleaks.TestCase
class DecoderTest(unittest.TestCase):

def test_decode_varint_bytes(self):
(size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 0)
self.assertEqual(size, _EXPECTED[0])
self.assertEqual(pos, 2)

(size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 2)
self.assertEqual(size, _EXPECTED[1])
self.assertEqual(pos, 3)

def test_decode_varint_bytes_empty(self):
with self.assertRaises(IndexError) as context:
(size, pos) = decoder._DecodeVarint(b'', 0)
self.assertIn('index out of range', str(context.exception))

def test_decode_varint_bytesio(self):
index = 0
input_io = io.BytesIO(_INPUT_BYTES)
while True:
size = decoder._DecodeVarint(input_io)
if size is None:
break
self.assertEqual(size, _EXPECTED[index])
index += 1
self.assertEqual(index, len(_EXPECTED))

def test_decode_varint_bytesio_empty(self):
input_io = io.BytesIO(b'')
size = decoder._DecodeVarint(input_io)
self.assertEqual(size, None)


if __name__ == '__main__':
unittest.main()
85 changes: 84 additions & 1 deletion python/google/protobuf/internal/proto_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@

"""Tests Nextgen Pythonic protobuf APIs."""

import io
import unittest

from google.protobuf import proto

from google.protobuf.internal import encoder
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks

from google.protobuf.internal import _parameterized
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2


@_parameterized.named_parameters(('_proto2', unittest_pb2),
('_proto3', unittest_proto3_arena_pb2))
@testing_refleaks.TestCase
Expand All @@ -30,6 +33,86 @@ def test_simple_serialize_parse(self, message_module):
parsed_msg = proto.parse(message_module.TestAllTypes, serialized_data)
self.assertEqual(msg, parsed_msg)

def test_serialize_parse_length_prefixed_empty(self, message_module):
empty_alltypes = message_module.TestAllTypes()
out = io.BytesIO()
proto.serialize_length_prefixed(empty_alltypes, out)

input_bytes = io.BytesIO(out.getvalue())
msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes)

self.assertEqual(msg, empty_alltypes)

def test_parse_length_prefixed_truncated(self, message_module):
out = io.BytesIO()
encoder._VarintEncoder()(out.write, 9999)
msg = message_module.TestAllTypes(optional_int32=1)
out.write(proto.serialize(msg))

input_bytes = io.BytesIO(out.getvalue())
with self.assertRaises(ValueError) as context:
proto.parse_length_prefixed(message_module.TestAllTypes, input_bytes)
self.assertEqual(
str(context.exception),
'Truncated message or non-buffered input_bytes: '
'Expected 9999 bytes but only 2 bytes parsed for '
'TestAllTypes.',
)

def test_serialize_length_prefixed_fake_io(self, message_module):
class FakeBytesIO(io.BytesIO):

def write(self, b: bytes) -> int:
return 0

msg = message_module.TestAllTypes(optional_int32=123)
out = FakeBytesIO()
with self.assertRaises(TypeError) as context:
proto.serialize_length_prefixed(msg, out)
self.assertIn(
'Failed to write complete message (wrote: 0, expected: 2)',
str(context.exception),
)


_EXPECTED_PROTO3 = b'\x04r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'
_EXPECTED_PROTO2 = b'\x06\x08\x00r\x02hi\x06\x08\x01r\x02hi\x06\x08\x02r\x02hi'


@_parameterized.named_parameters(
('_proto2', unittest_pb2, _EXPECTED_PROTO2),
('_proto3', unittest_proto3_arena_pb2, _EXPECTED_PROTO3),
)
@testing_refleaks.TestCase
class LengthPrefixedWithGolden(unittest.TestCase):

def test_serialize_length_prefixed(self, message_module, expected):
number_of_messages = 3

out = io.BytesIO()
for index in range(0, number_of_messages):
msg = message_module.TestAllTypes(
optional_int32=index, optional_string='hi'
)
proto.serialize_length_prefixed(msg, out)

self.assertEqual(out.getvalue(), expected)

def test_parse_length_prefixed(self, message_module, input_bytes):
expected_number_of_messages = 3

input_io = io.BytesIO(input_bytes)
index = 0
while True:
msg = proto.parse_length_prefixed(message_module.TestAllTypes, input_io)
if msg is None:
break
self.assertEqual(msg.optional_int32, index)
self.assertEqual(msg.optional_string, 'hi')
index += 1

self.assertEqual(index, expected_number_of_messages)


if __name__ == '__main__':
unittest.main()
83 changes: 80 additions & 3 deletions python/google/protobuf/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@

"""Contains the Nextgen Pythonic protobuf APIs."""

import typing
import io
from typing import Type, TypeVar

from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.message import Message

def serialize(message: Message, deterministic: bool=None) -> bytes:
_MESSAGE = TypeVar('_MESSAGE', bound='Message')


def serialize(message: _MESSAGE, deterministic: bool = None) -> bytes:
"""Return the serialized proto.
Args:
Expand All @@ -24,7 +30,8 @@ def serialize(message: Message, deterministic: bool=None) -> bytes:
"""
return message.SerializeToString(deterministic=deterministic)

def parse(message_class: typing.Type[Message], payload: bytes) -> Message:

def parse(message_class: Type[_MESSAGE], payload: bytes) -> _MESSAGE:
"""Given a serialized data in binary form, deserialize it into a Message.
Args:
Expand All @@ -37,3 +44,73 @@ def parse(message_class: typing.Type[Message], payload: bytes) -> Message:
new_message = message_class()
new_message.ParseFromString(payload)
return new_message


def serialize_length_prefixed(message: _MESSAGE, output: io.BytesIO) -> None:
"""Writes the size of the message as a varint and the serialized message.
Writes the size of the message as a varint and then the serialized message.
This allows more data to be written to the output after the message. Use
parse_length_prefixed to parse messages written by this method.
The output stream must be buffered, e.g. using
https://docs.python.org/3/library/io.html#buffered-streams.
Example usage:
out = io.BytesIO()
for msg in message_list:
proto.serialize_length_prefixed(msg, out)
Args:
message: The protocol buffer message that should be serialized.
output: BytesIO or custom buffered IO that data should be written to.
"""
size = message.ByteSize()
encoder._VarintEncoder()(output.write, size)
out_size = output.write(serialize(message))

if out_size != size:
raise TypeError(
'Failed to write complete message (wrote: %d, expected: %d)'
'. Ensure output is using buffered IO.' % (out_size, size)
)


def parse_length_prefixed(
message_class: Type[_MESSAGE], input_bytes: io.BytesIO
) -> _MESSAGE:
"""Parse a message from input_bytes.
Args:
message_class: The protocol buffer message class that parser should parse.
input_bytes: A buffered input.
Example usage:
while True:
msg = proto.parse_length_prefixed(message_class, input_bytes)
if msg is None:
break
...
Returns:
A parsed message if successful. None if input_bytes is at EOF.
"""
size = decoder._DecodeVarint(input_bytes)
if size is None:
# It is the end of buffered input. See example usage in the
# API description.
return None

message = message_class()

if size == 0:
return message

parsed_size = message.ParseFromString(input_bytes.read(size))
if parsed_size != size:
raise ValueError(
'Truncated message or non-buffered input_bytes: '
'Expected {0} bytes but only {1} bytes parsed for '
'{2}.'.format(size, parsed_size, message.DESCRIPTOR.name)
)
return message

0 comments on commit 3a9f074

Please sign in to comment.