Skip to content

Commit

Permalink
Add streaming decode support for msgspec
Browse files Browse the repository at this point in the history
Add a `tractor._ipc.MsgspecStream` type which can be swapped in for
`msgspec` serialization transparently. A small msg-length-prefix framing
is implemented as part of the type and we use
`tricycle.BufferedReceiveStream` to handle buffering logic for the
underlying transport.

Notes:
- had to force cast a few more list -> tuple spots due to no native
  `tuple` decode-by-default in `msgspec`: jcrist/msgspec#30
- the framing can be understood by this protobuf walkthrough:
  https://eli.thegreenplace.net/2011/08/02/length-prefix-framing-for-protocol-buffers
- `tricycle` becomes a new dependency
  • Loading branch information
goodboy committed Jun 30, 2021
1 parent 240351a commit 97f44e2
Showing 1 changed file with 78 additions and 19 deletions.
97 changes: 78 additions & 19 deletions tractor/_ipc.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""
Inter-process comms abstractions
"""
import typing
from typing import Any, Tuple, Optional, Callable
from functools import partial
import struct
import typing
from typing import Any, Tuple, Optional

from tricycle import BufferedReceiveStream
import msgpack
import msgspec
import trio
from async_generator import asynccontextmanager

from .log import get_logger
log = get_logger('ipc')
log = get_logger(__name__)

# :eyeroll:
try:
Expand All @@ -22,21 +24,14 @@
Unpacker = partial(msgpack.Unpacker, strict_map_key=False)


ms_decode = msgspec.Encoder().encode


class MsgpackStream:
"""A ``trio.SocketStream`` delivering ``msgpack`` formatted data.
'''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
using ``msgpack-python``.
"""
'''
def __init__(
self,
stream: trio.SocketStream,
serialize: Callable = Unpacker(
raw=False,
use_list=False,
).feed,
deserialize: Callable = msgpack.dumps,

) -> None:

Expand All @@ -62,7 +57,6 @@ async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
raw=False,
use_list=False,
)
# decoder = msgspec.Decoder() #dict[str, Any])
while True:
try:
data = await self.stream.receive_some(2**10)
Expand All @@ -75,7 +69,6 @@ async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
log.debug(f"Stream connection {self.raddr} was closed")
return

# yield decoder.decode(data)
unpacker.feed(data)
for packet in unpacker:
yield packet
Expand All @@ -92,8 +85,7 @@ def raddr(self) -> Tuple[Any, ...]:
async def send(self, data: Any) -> None:
async with self._send_lock:
return await self.stream.send_all(
# msgpack.dumps(data, use_bin_type=True))
ms_decode(data)
msgpack.dumps(data, use_bin_type=True)
)

async def recv(self) -> Any:
Expand All @@ -106,26 +98,93 @@ def connected(self) -> bool:
return self.stream.socket.fileno() != -1


class MsgspecStream(MsgpackStream):
    '''A ``trio.SocketStream`` delivering ``msgpack`` formatted data
    using ``msgspec``.

    Each message on the wire is framed as a little-endian unsigned
    length prefix of ``prefix_size`` bytes followed by that many bytes
    of ``msgspec``-encoded payload.

    '''
    ms_encode = msgspec.Encoder().encode

    def __init__(
        self,
        stream: trio.SocketStream,
        prefix_size: int = 4,

    ) -> None:
        '''Wrap ``stream`` for length-prefixed ``msgspec`` messaging.

        ``prefix_size`` is the width in bytes of the length header
        written before every message; both ends of a connection must
        use the same value.

        Raises ``ValueError`` if ``prefix_size`` is not positive since
        a zero-width header cannot carry a message length.

        '''
        if prefix_size < 1:
            raise ValueError(
                f'prefix_size must be a positive int, got {prefix_size!r}'
            )

        super().__init__(stream)
        # buffers partial reads so a header or body can be requested by
        # exact byte count
        self.recv_stream = BufferedReceiveStream(transport_stream=stream)
        self.prefix_size = prefix_size

    async def _iter_packets(self) -> typing.AsyncGenerator[dict, None]:
        '''Yield decoded packets from the underlying stream.

        Reads a ``prefix_size`` byte header, interprets it as the
        little-endian payload length, reads exactly that many bytes and
        yields the decoded message. The generator ends when the header
        comes back empty or when the transport raises ``ValueError`` or
        ``trio.BrokenResourceError``.

        '''
        decoder = msgspec.Decoder()  # dict[str, Any])

        while True:
            try:
                header = await self.recv_stream.receive_exactly(
                    self.prefix_size
                )

                # NOTE(review): these two guards are kept as-is from the
                # original; ``receive_exactly()`` is documented to raise
                # ``ValueError`` on EOF so they are presumably defensive
                # only -- confirm before removing.
                if header is None:
                    continue

                if header == b'':
                    log.debug(f"Stream connection {self.raddr} was closed")
                    return

                # little-endian unsigned int, same layout ``send()``
                # writes (identical to ``struct`` "<I" when the prefix
                # is 4 bytes wide)
                size = int.from_bytes(header, 'little')

                log.trace(f'received header {size}')

                msg_bytes = await self.recv_stream.receive_exactly(size)

            # the value error here is to catch a connect with immediate
            # disconnect that will cause an EOF error inside `tricycle`.
            except (ValueError, trio.BrokenResourceError):
                log.warning(f"Stream connection {self.raddr} broke")
                return

            log.trace(f"received {msg_bytes}")  # type: ignore
            yield decoder.decode(msg_bytes)

    async def send(self, data: Any) -> None:
        '''Encode ``data`` and send it as one length-prefixed frame.

        The header and payload are concatenated and written with a
        single ``send_all()`` call while holding the send lock.

        Raises ``OverflowError`` if the encoded message is too long for
        its length to fit in ``prefix_size`` bytes.

        '''
        async with self._send_lock:

            bytes_data = self.ms_encode(data)

            # length header: little-endian unsigned int that is
            # ``prefix_size`` bytes wide
            prefix: bytes = len(bytes_data).to_bytes(
                self.prefix_size,
                'little',
            )

            return await self.stream.send_all(prefix + bytes_data)


class Channel:
"""An inter-process channel for communication between (remote) actors.
Currently the only supported transport is a ``trio.SocketStream``.
"""
def __init__(

self,
destaddr: Optional[Tuple[str, int]] = None,
on_reconnect: typing.Callable[..., typing.Awaitable] = None,
auto_reconnect: bool = False,
stream: trio.SocketStream = None, # expected to be active
# stream_serializer: type = MsgpackStream,
stream_serializer_type: type = MsgspecStream,

) -> None:

self._recon_seq = on_reconnect
self._autorecon = auto_reconnect
self.msgstream: Optional[MsgpackStream] = MsgpackStream(
self.stream_serializer_type = stream_serializer_type
self.msgstream: Optional[type] = stream_serializer_type(
stream) if stream else None

if self.msgstream and destaddr:
raise ValueError(
f"A stream was provided with local addr {self.laddr}"
)

self._destaddr = self.msgstream.raddr if self.msgstream else destaddr
# set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None
Expand Down Expand Up @@ -157,7 +216,7 @@ async def connect(
destaddr = destaddr or self._destaddr
assert isinstance(destaddr, tuple)
stream = await trio.open_tcp_stream(*destaddr, **kwargs)
self.msgstream = MsgpackStream(stream)
self.msgstream = self.stream_serializer_type(stream)
return stream

async def send(self, item: Any) -> None:
Expand Down

0 comments on commit 97f44e2

Please sign in to comment.