-
Notifications
You must be signed in to change notification settings - Fork 300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flytekit][1][SimpleTransformer] Binary IDL With MessagePack #2756
Changes from 3 commits
e3a258a
3562f0c
f93b441
5539c1e
bcaf573
1bac074
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,8 +15,9 @@ | |
from abc import ABC, abstractmethod | ||
from collections import OrderedDict | ||
from functools import lru_cache | ||
from typing import Dict, List, NamedTuple, Optional, Type, cast | ||
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast | ||
|
||
import msgpack | ||
from dataclasses_json import DataClassJsonMixin, dataclass_json | ||
from flyteidl.core import literals_pb2 | ||
from google.protobuf import json_format as _json_format | ||
|
@@ -26,6 +27,7 @@ | |
from google.protobuf.message import Message | ||
from google.protobuf.struct_pb2 import Struct | ||
from mashumaro.codecs.json import JSONDecoder, JSONEncoder | ||
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder | ||
from mashumaro.mixins.json import DataClassJSONMixin | ||
from typing_extensions import Annotated, get_args, get_origin | ||
|
||
|
@@ -42,22 +44,21 @@ | |
from flytekit.models import types as _type_models | ||
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models.literals import ( | ||
Literal, | ||
LiteralCollection, | ||
LiteralMap, | ||
Primitive, | ||
Scalar, | ||
Union, | ||
Void, | ||
) | ||
from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void | ||
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType | ||
|
||
T = typing.TypeVar("T") | ||
DEFINITIONS = "definitions" | ||
TITLE = "title" | ||
|
||
|
||
# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. | ||
# This is relevant for cases like Dict[int, str]. | ||
# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed.` | ||
def _default_flytekit_decoder(data: bytes) -> Any: | ||
return msgpack.unpackb(data, raw=False, strict_map_key=False) | ||
|
||
|
||
class BatchSize: | ||
""" | ||
This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example, | ||
|
@@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True): | |
self._t = t | ||
self._name = name | ||
self._type_assertions_enabled = enable_type_assertions | ||
self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {} | ||
Future-Outlier marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {} | ||
|
||
@property | ||
def name(self): | ||
|
@@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |
f"Conversion to python value expected type {expected_python_type} from literal not implemented" | ||
) | ||
|
||
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: | ||
if binary_idl_object.tag == "msgpack": | ||
try: | ||
decoder = self._msgpack_decoder[expected_python_type] | ||
except KeyError: | ||
decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This part isn't clear to me... why are we using the MessagePackDecoder from Mashumaro? isn't that just for dataclasses? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We use The expected flow is:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From a puristic pov what you're saying is true, @wild-endeavor, but given that mashumaro is already a dependency and also given the fact that Simple cases like this fail with using
|
||
self._msgpack_decoder[expected_python_type] = decoder | ||
return decoder.decode(binary_idl_object.value) | ||
else: | ||
raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") | ||
|
||
def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: | ||
""" | ||
Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div | ||
|
@@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |
f"Cannot convert to type {expected_python_type}, only {self._type} is supported" | ||
) | ||
|
||
if lv.scalar and lv.scalar.binary: | ||
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore | ||
|
||
try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal | ||
res = self._from_literal_transformer(lv) | ||
if type(res) != self._type: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from datetime import datetime, date, timedelta | ||
from mashumaro.codecs.msgpack import MessagePackEncoder | ||
|
||
from flytekit.models.literals import Binary, Literal, Scalar | ||
from flytekit.core.context_manager import FlyteContextManager | ||
from flytekit.core.type_engine import TypeEngine | ||
|
||
def test_simple_type_transformer(): | ||
ctx = FlyteContextManager.current_context() | ||
|
||
int_input = 20240918 | ||
encoder = MessagePackEncoder(int) | ||
int_msgpack_bytes = encoder.encode(int_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack"))) | ||
int_output = TypeEngine.to_python_value(ctx, lv, int) | ||
assert int_input == int_output | ||
|
||
float_input = 2024.0918 | ||
encoder = MessagePackEncoder(float) | ||
float_msgpack_bytes = encoder.encode(float_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack"))) | ||
float_output = TypeEngine.to_python_value(ctx, lv, float) | ||
assert float_input == float_output | ||
|
||
bool_input = True | ||
encoder = MessagePackEncoder(bool) | ||
bool_msgpack_bytes = encoder.encode(bool_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) | ||
bool_output = TypeEngine.to_python_value(ctx, lv, bool) | ||
assert bool_input == bool_output | ||
|
||
bool_input = False | ||
bool_msgpack_bytes = encoder.encode(bool_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) | ||
bool_output = TypeEngine.to_python_value(ctx, lv, bool) | ||
assert bool_input == bool_output | ||
|
||
str_input = "hello" | ||
encoder = MessagePackEncoder(str) | ||
str_msgpack_bytes = encoder.encode(str_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack"))) | ||
str_output = TypeEngine.to_python_value(ctx, lv, str) | ||
assert str_input == str_output | ||
|
||
datetime_input = datetime.now() | ||
encoder = MessagePackEncoder(datetime) | ||
datetime_msgpack_bytes = encoder.encode(datetime_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack"))) | ||
datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) | ||
assert datetime_input == datetime_output | ||
|
||
date_input = date.today() | ||
encoder = MessagePackEncoder(date) | ||
date_msgpack_bytes = encoder.encode(date_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack"))) | ||
date_output = TypeEngine.to_python_value(ctx, lv, date) | ||
assert date_input == date_output | ||
|
||
timedelta_input = timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1) | ||
encoder = MessagePackEncoder(timedelta) | ||
timedelta_msgpack_bytes = encoder.encode(timedelta_input) | ||
lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) | ||
timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) | ||
assert timedelta_input == timedelta_output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this function doesn't have anything to do with Mashumaro right? why mention mashumaro in the comments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because we want to use it in our
mashumaro's
decoder.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function will be put into
_default_flytekit_decoder
.For example, if we access
dataclasss.dict_int_str
in a workflow, then we will usefrom_binary_idl
here to turn the Binary IDL object toDict[int, str]
.Note:
dict_int_str
isDict[int, str]
.Dict[int, str]
is a non-strict type