[flytekit][1][SimpleTransformer] Binary IDL With MessagePack #2756

Closed
37 changes: 27 additions & 10 deletions flytekit/core/type_engine.py
@@ -15,8 +15,9 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from functools import lru_cache
-from typing import Dict, List, NamedTuple, Optional, Type, cast
+from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

+import msgpack
 from dataclasses_json import DataClassJsonMixin, dataclass_json
 from flyteidl.core import literals_pb2
 from google.protobuf import json_format as _json_format
@@ -26,6 +27,7 @@
 from google.protobuf.message import Message
 from google.protobuf.struct_pb2 import Struct
 from mashumaro.codecs.json import JSONDecoder, JSONEncoder
+from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
 from mashumaro.mixins.json import DataClassJSONMixin
 from typing_extensions import Annotated, get_args, get_origin

@@ -42,22 +44,21 @@
 from flytekit.models import types as _type_models
 from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
 from flytekit.models.core import types as _core_types
-from flytekit.models.literals import (
-    Literal,
-    LiteralCollection,
-    LiteralMap,
-    Primitive,
-    Scalar,
-    Union,
-    Void,
-)
+from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
 from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType

 T = typing.TypeVar("T")
 DEFINITIONS = "definitions"
 TITLE = "title"


+# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True.
Contributor

This function doesn't have anything to do with Mashumaro, right? Why mention Mashumaro in the comments?

Member Author

Because we want to use it in our Mashumaro decoder.

Member Author

Future-Outlier, Sep 19, 2024

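This logic lives in _default_flytekit_decoder, which from_binary_idl passes to the Mashumaro MessagePackDecoder as pre_decoder_func.
For example, if we access dataclass.dict_int_str in a workflow, from_binary_idl turns the Binary IDL object into a Dict[int, str].

Note:

  1. dict_int_str is a Dict[int, str].
  2. Dict[int, str] has int keys, and msgpack's strict mode (strict_map_key=True) only accepts str or bytes map keys, so decoding it requires strict_map_key=False.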
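
The strict_map_key behavior described in the note can be reproduced with msgpack alone. The following is a minimal sketch, assuming msgpack>=1.0 (where strict_map_key defaults to True); the variable names are illustrative and not part of this PR.

```python
import msgpack

# A map with int keys, like the Dict[int, str] case discussed above.
payload = msgpack.packb({1: "a", 2: "b"})

# With the default strict_map_key=True, int keys are rejected on decode.
try:
    msgpack.unpackb(payload, raw=False)
    strict_error = None
except ValueError as exc:
    strict_error = exc

# With strict_map_key=False, the same bytes decode back to the original dict.
decoded = msgpack.unpackb(payload, raw=False, strict_map_key=False)
```

Here strict_error holds the ValueError raised by the strict decode, and decoded is the original dictionary with its int keys intact.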

+# This is relevant for cases like Dict[int, str].
+# If strict_map_key=False is not used, the decoder will raise an error when a map key is not str or bytes.
+def _default_flytekit_decoder(data: bytes) -> Any:
+    return msgpack.unpackb(data, raw=False, strict_map_key=False)
+
+
 class BatchSize:
     """
     This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example,
@@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
         self._t = t
         self._name = name
         self._type_assertions_enabled = enable_type_assertions
+        self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {}
+        self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {}

     @property
     def name(self):
@@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
             f"Conversion to python value expected type {expected_python_type} from literal not implemented"
         )

+    def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
+        if binary_idl_object.tag == "msgpack":
+            try:
+                decoder = self._msgpack_decoder[expected_python_type]
+            except KeyError:
+                decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder)
Contributor

This part isn't clear to me. Why are we using the MessagePackDecoder from Mashumaro? Isn't that just for dataclasses?

Member Author

> This part isn't clear to me... why are we using the MessagePackDecoder from Mashumaro? isn't that just for dataclasses?

We use MessagePackDecoder because, when an attribute is accessed from a dataclass, we receive a Binary IDL object from propeller, and its value can be of any type (int, float, bool, str, list, dict, dataclass, Pydantic BaseModel, or Flyte types).

The expected flow is: Binary IDL -> msgpack bytes -> Python value.

MessagePackDecoder(expected_python_type).decode is more reliable than msgpack.loads because it uses expected_python_type as a hint to return a value of the expected type.
For example, when the expected type is float but the encoded value is an int, it converts the value back to float.

Collaborator

From a purist point of view, what you're saying is true, @wild-endeavor. But mashumaro is already a dependency, and plain msgpack.dumps/msgpack.loads cannot round-trip some values, so I'm in favor of keeping this more complex implementation of the top-level from_binary_idl.

Simple cases like this fail with plain msgpack (here, msgpack.dumps):

Python 3.12.5 (main, Aug 14 2024, 04:32:18) [Clang 18.1.8 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> from datetime import datetime
>>> from mashumaro.codecs.msgpack import MessagePackEncoder
>>> encoder = MessagePackEncoder(type(datetime.now()))
>>> encoder.encode(datetime.now())
b'\xba2024-09-24T21:52:43.704551'
>>> msgpack.dumps(datetime.now())
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/eduardo/repos/flyte-examples/.venv/lib/python3.12/site-packages/msgpack/__init__.py", line 36, in packb
    return Packer(**kwargs).pack(o)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "msgpack/_packer.pyx", line 279, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 276, in msgpack._cmsgpack.Packer.pack
  File "msgpack/_packer.pyx", line 270, in msgpack._cmsgpack.Packer._pack
  File "msgpack/_packer.pyx", line 257, in msgpack._cmsgpack.Packer._pack_inner
TypeError: can not serialize 'datetime.datetime' object
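
To make the contrast in the transcript concrete: the bytes printed by the typed encoder are a msgpack string holding an ISO 8601 timestamp, whereas plain msgpack needs to be told how to handle datetime. The sketch below shows one way to get the same bytes from msgpack by hand, using its default hook. The helper encode_datetime is a hypothetical name introduced only for illustration, and nothing here is claimed to be how mashumaro works internally.

```python
from datetime import datetime

import msgpack


def encode_datetime(obj):
    # Hypothetical fallback hook: msgpack calls it for types it cannot pack.
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError(f"can not serialize {type(obj).__name__!r} object")


moment = datetime(2024, 9, 24, 21, 52, 43, 704551)

# Without a hook, plain msgpack rejects datetime, as in the transcript above.
try:
    msgpack.packb(moment)
    plain_error = None
except TypeError as exc:
    plain_error = exc

# With the hook, the value is packed as an ISO 8601 string.
packed = msgpack.packb(moment, default=encode_datetime)

# Decoding yields a str; the caller must know it should become a datetime.
restored = datetime.fromisoformat(msgpack.unpackb(packed))
```

The last step is the part a typed decoder takes care of: the expected type tells it that the decoded string should be turned back into a datetime.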

+                self._msgpack_decoder[expected_python_type] = decoder
+            return decoder.decode(binary_idl_object.value)
+        else:
+            raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")
+
     def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str:
         """
         Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
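
The decoder lookup in from_binary_idl above follows a build-once, reuse-per-type pattern. Below is a standalone sketch of that pattern using only the standard library. DecoderCache and its factory argument are hypothetical names, and the built-in type constructors stand in for MessagePackDecoder instances, so this illustrates the caching only, not the msgpack decoding.

```python
from typing import Any, Callable, Dict


class DecoderCache:
    """Caches one decoder per expected type (hypothetical helper, not flytekit API)."""

    def __init__(self, factory: Callable[[type], Callable[[Any], Any]]):
        self._factory = factory
        self._decoders: Dict[type, Callable[[Any], Any]] = {}
        self.builds = 0  # counts how many decoders were actually constructed

    def get(self, expected_type: type) -> Callable[[Any], Any]:
        try:
            return self._decoders[expected_type]
        except KeyError:
            # First request for this type: build the decoder and remember it.
            decoder = self._factory(expected_type)
            self._decoders[expected_type] = decoder
            self.builds += 1
            return decoder


# Stand-in factory: the type's own constructor acts as the "decoder".
cache = DecoderCache(lambda expected_type: expected_type)

as_float = cache.get(float)(5)   # int input coerced to the expected float type
again = cache.get(float)("2.5")  # second lookup reuses the cached decoder
```

Using try/except KeyError keeps the common path (a cache hit) to a single dictionary lookup, which matches the structure of the PR's code.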
@@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
                 f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
             )

+        if lv.scalar and lv.scalar.binary:
+            return self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore
+
         try:  # todo(maximsmol): this is quite ugly and each transformer should really check their Literal
             res = self._from_literal_transformer(lv)
             if type(res) != self._type:
1 change: 1 addition & 0 deletions pyproject.toml
@@ -35,6 +35,7 @@ dependencies = [
     "marshmallow-enum",
     "marshmallow-jsonschema>=0.12.0",
     "mashumaro>=3.11",
+    "msgpack>=1.1.0",
     "protobuf!=4.25.0",
     "pygments",
     "python-json-logger>=2.0.0",
78 changes: 78 additions & 0 deletions tests/flytekit/unit/core/test_type_engine_binary_idl.py
@@ -0,0 +1,78 @@
+from datetime import datetime, date, timedelta
+from mashumaro.codecs.msgpack import MessagePackEncoder
+
+from flytekit.models.literals import Binary, Literal, Scalar
+from flytekit.core.context_manager import FlyteContextManager
+from flytekit.core.type_engine import TypeEngine
+
+def test_simple_type_transformer():
+    ctx = FlyteContextManager.current_context()
+
+    int_inputs = [1, 2, 20240918, -1, -2, -20240918]
+    encoder = MessagePackEncoder(int)
+    for int_input in int_inputs:
+        int_msgpack_bytes = encoder.encode(int_input)
+        lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack")))
+        int_output = TypeEngine.to_python_value(ctx, lv, int)
+        assert int_input == int_output
+
+    float_inputs = [2024.0918, 5.0, -2024.0918, -5.0]
+    encoder = MessagePackEncoder(float)
+    for float_input in float_inputs:
+        float_msgpack_bytes = encoder.encode(float_input)
+        lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack")))
+        float_output = TypeEngine.to_python_value(ctx, lv, float)
+        assert float_input == float_output
+
+    bool_inputs = [True, False]
+    encoder = MessagePackEncoder(bool)
+    for bool_input in bool_inputs:
+        bool_msgpack_bytes = encoder.encode(bool_input)
+        lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack")))
+        bool_output = TypeEngine.to_python_value(ctx, lv, bool)
+        assert bool_input == bool_output
+
+    str_inputs = ["hello", "world", "flyte", "kit", "is", "awesome"]
+    encoder = MessagePackEncoder(str)
+    for str_input in str_inputs:
+        str_msgpack_bytes = encoder.encode(str_input)
+        lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack")))
+        str_output = TypeEngine.to_python_value(ctx, lv, str)
+        assert str_input == str_output
+
+    datetime_inputs = [datetime.now(),
+                       datetime(2024, 9, 18),
+                       datetime(2024, 9, 18, 1),
+                       datetime(2024, 9, 18, 1, 1),
+                       datetime(2024, 9, 18, 1, 1, 1),
+                       datetime(2024, 9, 18, 1, 1, 1, 1)]
+    encoder = MessagePackEncoder(datetime)
+    for datetime_input in datetime_inputs:
+        datetime_msgpack_bytes = encoder.encode(datetime_input)
+        lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack")))
+        datetime_output = TypeEngine.to_python_value(ctx, lv, datetime)
+        assert datetime_input == datetime_output
+
+    date_inputs = [date.today(),
+                   date(2024, 9, 18)]
+    encoder = MessagePackEncoder(date)
+    for date_input in date_inputs:
+        date_msgpack_bytes = encoder.encode(date_input)
+        lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack")))
+        date_output = TypeEngine.to_python_value(ctx, lv, date)
+        assert date_input == date_output
+
+    timedelta_inputs = [timedelta(days=1),
+                        timedelta(days=1, seconds=1),
+                        timedelta(days=1, seconds=1, microseconds=1),
+                        timedelta(days=1, seconds=1, microseconds=1, milliseconds=1),
+                        timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1),
+                        timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1),
+                        timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1),
+                        timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1)]
+    encoder = MessagePackEncoder(timedelta)
+    for timedelta_input in timedelta_inputs:
+        timedelta_msgpack_bytes = encoder.encode(timedelta_input)
+        lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack")))
+        timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta)
+        assert timedelta_input == timedelta_output