Skip to content

Commit

Permalink
WIP tagged union message type API
Browse files Browse the repository at this point in the history
  • Loading branch information
goodboy committed Jul 6, 2022
1 parent c71785c commit c7288a6
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 22 deletions.
53 changes: 34 additions & 19 deletions tractor/_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import typing
from collections.abc import AsyncGenerator, AsyncIterator
from typing import (
Any, Tuple, Optional,
Any, Optional,
Type, Protocol, TypeVar,
)

Expand All @@ -34,6 +34,7 @@
from async_generator import asynccontextmanager

from .log import get_logger
from .msg import get_msg_codecs
from ._exceptions import TransportClosed
log = get_logger(__name__)

Expand All @@ -42,7 +43,13 @@
log = get_logger(__name__)


def get_stream_addrs(stream: trio.SocketStream) -> Tuple:
def get_stream_addrs(
stream: trio.SocketStream,

) -> tuple[
tuple[str, int],
tuple[str, int],
]:
# should both be IP sockets
lsockname = stream.socket.getsockname()
rsockname = stream.socket.getpeername()
Expand Down Expand Up @@ -87,11 +94,11 @@ def drain(self) -> AsyncIterator[dict]:
...

@property
def laddr(self) -> Tuple[str, int]:
def laddr(self) -> tuple[str, int]:
...

@property
def raddr(self) -> Tuple[str, int]:
def raddr(self) -> tuple[str, int]:
...


Expand Down Expand Up @@ -167,11 +174,11 @@ async def _iter_packets(self) -> AsyncGenerator[dict, None]:
yield packet

@property
def laddr(self) -> Tuple[Any, ...]:
def laddr(self) -> tuple[Any, ...]:
return self._laddr

@property
def raddr(self) -> Tuple[Any, ...]:
def raddr(self) -> tuple[Any, ...]:
return self._raddr

async def send(self, msg: Any) -> None:
Expand Down Expand Up @@ -216,18 +223,21 @@ def __init__(
prefix_size: int = 4,

) -> None:
import msgspec

super().__init__(stream)
self.recv_stream = BufferedReceiveStream(transport_stream=stream)
self.recv_stream = BufferedReceiveStream(
transport_stream=stream
)
self.prefix_size = prefix_size

# TODO: struct aware messaging coders
self.encode = msgspec.msgpack.Encoder().encode
self.decode = msgspec.msgpack.Decoder().decode # dict[str, Any])
enc, dec = get_msg_codecs()
self.encode = enc.encode
self.decode = dec.decode # dict[str, Any])

async def _iter_packets(self) -> AsyncGenerator[dict, None]:
'''Yield packets from the underlying stream.
'''
Yield packets from the underlying stream.
'''
import msgspec # noqa
Expand Down Expand Up @@ -268,6 +278,11 @@ async def _iter_packets(self) -> AsyncGenerator[dict, None]:
msgspec.DecodeError,
UnicodeDecodeError,
):
log.error(
'`msgspec` failed to decode!?\n'
'dumping bytes:\n'
f'{msg_bytes}'
)
if decodes_failed < 4:
# ignore decoding errors for now and assume they have to
# do with a channel drop - hope that receiving from the
Expand Down Expand Up @@ -300,7 +315,7 @@ async def send(self, msg: Any) -> None:

def get_msg_transport(

key: Tuple[str, str],
key: tuple[str, str],

) -> Type[MsgTransport]:

Expand All @@ -322,9 +337,9 @@ class Channel:
def __init__(

self,
destaddr: Optional[Tuple[str, int]],
destaddr: Optional[tuple[str, int]],

msg_transport_type_key: Tuple[str, str] = ('msgpack', 'tcp'),
msg_transport_type_key: tuple[str, str] = ('msgpack', 'tcp'),

# TODO: optional reconnection support?
# auto_reconnect: bool = False,
Expand Down Expand Up @@ -352,7 +367,7 @@ def __init__(
self.msgstream: Optional[MsgTransport] = None

# set after handshake - always uid of far end
self.uid: Optional[Tuple[str, str]] = None
self.uid: Optional[tuple[str, str]] = None

self._agen = self._aiter_recv()
self._exc: Optional[Exception] = None # set if far end actor errors
Expand Down Expand Up @@ -380,7 +395,7 @@ def from_stream(
def set_msg_transport(
self,
stream: trio.SocketStream,
type_key: Optional[Tuple[str, str]] = None,
type_key: Optional[tuple[str, str]] = None,

) -> MsgTransport:
type_key = type_key or self._transport_key
Expand All @@ -395,16 +410,16 @@ def __repr__(self) -> str:
return object.__repr__(self)

@property
def laddr(self) -> Optional[Tuple[str, int]]:
def laddr(self) -> Optional[tuple[str, int]]:
return self.msgstream.laddr if self.msgstream else None

@property
def raddr(self) -> Optional[Tuple[str, int]]:
def raddr(self) -> Optional[tuple[str, int]]:
return self.msgstream.raddr if self.msgstream else None

async def connect(
self,
destaddr: Tuple[Any, ...] = None,
destaddr: tuple[Any, ...] = None,
**kwargs

) -> MsgTransport:
Expand Down
77 changes: 74 additions & 3 deletions tractor/msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
# ``importlib.import_module()`` which can be filtered by inserting
# a ``MetaPathFinder`` into ``sys.meta_path`` (which we could do before
# entering the ``Actor._process_messages()`` loop).
# - https://github.com/python/cpython/blob/main/Lib/pkgutil.py#L645
# - https://stackoverflow.com/questions/1350466/preventing-python-code-from-importing-certain-modules
# https://github.com/python/cpython/blob/main/Lib/pkgutil.py#L645
# https://stackoverflow.com/questions/1350466/preventing-python-code-from-importing-certain-modules
# - https://stackoverflow.com/a/63320902
# - https://docs.python.org/3/library/sys.html#sys.meta_path

Expand All @@ -40,10 +40,19 @@
# - https://jcristharif.com/msgspec/api.html#struct
# - https://jcristharif.com/msgspec/extending.html
# via ``msgpack-python``:
# - https://github.com/msgpack/msgpack-python#packingunpacking-of-custom-data-type
# https://github.com/msgpack/msgpack-python#packingunpacking-of-custom-data-type

from __future__ import annotations
from contextlib import contextmanager as cm
from pkgutil import resolve_name
from typing import Union, Any


from msgspec import Struct
from msgspec.msgpack import (
Encoder,
Decoder,
)


class NamespacePath(str):
Expand Down Expand Up @@ -78,3 +87,65 @@ def from_ref(
(ref.__module__,
getattr(ref, '__name__', ''))
))


# LIFO codec stack that is appended when the user opens the
# ``configure_native_msgs()`` cm below to configure a new codec set
# which will be applied to all new (msgspec relevant) IPC transports
# that are spawned **after** the configure call is made.
_lifo_codecs: list[
tuple[
Encoder,
Decoder,
],
] = [(Encoder(), Decoder())]


def get_msg_codecs() -> tuple[
Encoder,
Decoder,
]:
'''
Return the currently configured ``msgspec`` codec set.
The defaults are defined above.
'''
global _lifo_codecs
return _lifo_codecs[-1]


@cm
def configure_native_msgs(
tagged_structs: list[Struct],
):
'''
Push a codec set that will natively decode
tagged structs provied in ``tagged_structs``
in all IPC transports and pop the codec on exit.
'''
global _lifo_codecs

# See "tagged unions" docs:
# https://jcristharif.com/msgspec/structs.html#tagged-unions

# "The quickest way to enable tagged unions is to set tag=True when
# defining every struct type in the union. In this case tag_field
# defaults to "type", and tag defaults to the struct class name
# (e.g. "Get")."
enc = Encoder()

types_union = Union[tagged_structs[0]] | Any
for struct in tagged_structs[1:]:
types_union |= struct

dec = Decoder(types_union)

_lifo_codecs.append((enc, dec))
try:
print("YOYOYOOYOYOYOY")
yield enc, dec
finally:
print("NONONONONON")
_lifo_codecs.pop()

0 comments on commit c7288a6

Please sign in to comment.