Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Binary IDL With MessagePack Bytes #2751

Closed
wants to merge 11 commits into from
43 changes: 32 additions & 11 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from flytekit.models import types as _type_models
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
from flytekit.models.literals import Binary, Literal, Primitive, Scalar
from flytekit.models.task import Resources
from flytekit.models.types import SimpleType

Expand Down Expand Up @@ -138,21 +138,41 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
break

# If the current value is a dataclass, resolve the dataclass with the remaining path
if (
len(p.attr_path) > 0
and type(curr_val.value) is _literals_models.Scalar
and type(curr_val.value.value) is _struct.Struct
):
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
if len(p.attr_path) > 0 and type(curr_val.value) is _literals_models.Scalar:
# We keep it for reference task local execution in the future.
if type(curr_val.value.value) is _struct.Struct:
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
elif type(curr_val.value.value) is Binary:
binary_idl_obj = curr_val.value.value
if binary_idl_obj.tag == "msgpack":
import msgpack

dict_obj = msgpack.loads(binary_idl_obj.value, strict_map_key=False)
v = resolve_attr_path_in_dict(dict_obj, attr_path=p.attr_path[used:])
msgpack_bytes = msgpack.dumps(v)
curr_val = Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

p._val = curr_val
return p


def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any:
    """
    Walk ``attr_path`` through a nested structure and return the value found at its end.

    Each element of ``attr_path`` is applied with ``[]`` to the value reached so far, so it can be a
    mapping key or a sequence index.

    :param d: The object to index into.
    :param attr_path: Keys and/or indices, applied from left to right.
    :return: The value located at the end of ``attr_path``; ``d`` itself when the path is empty.
    :raises FlytePromiseAttributeResolveException: If a step of the path cannot be applied to the
        current value (missing key, index out of range, or a value that cannot be indexed).
    """
    curr_val = d
    for attr in attr_path:
        try:
            curr_val = curr_val[attr]
        except (KeyError, IndexError, TypeError) as e:
            # Chain the original lookup error so the traceback shows the root cause explicitly.
            raise FlytePromiseAttributeResolveException(
                f"Failed to resolve attribute path {attr_path} in dict {curr_val}, attribute {attr} not found.\n"
                f"Error Message: {e}"
            ) from e
    return curr_val


def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct:
curr_val = st
for attr in attr_path:
Expand Down Expand Up @@ -211,6 +231,7 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr
self._op = op
self._lhs = None
self._rhs = None

if isinstance(lhs, Promise):
self._lhs = lhs
if lhs.is_ready:
Expand Down
92 changes: 68 additions & 24 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

import msgpack
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand All @@ -26,6 +27,7 @@
from google.protobuf.message import Message
from google.protobuf.struct_pb2 import Struct
from mashumaro.codecs.json import JSONDecoder, JSONEncoder
from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder
from mashumaro.mixins.json import DataClassJSONMixin
from typing_extensions import Annotated, get_args, get_origin

Expand All @@ -42,22 +44,21 @@
from flytekit.models import types as _type_models
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
from flytekit.models.core import types as _core_types
from flytekit.models.literals import (
Literal,
LiteralCollection,
LiteralMap,
Primitive,
Scalar,
Union,
Void,
)
from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
TITLE = "title"


# In Mashumaro, the default MessagePack encoder accepts non-string map keys, while the default decoder
# uses strict_map_key=True and rejects them. This is relevant for cases like Dict[int, str].
# If strict_map_key=False is not used, the decoder raises an error on map keys that are not str or bytes.
def _default_flytekit_decoder(data: bytes) -> Any:
    """
    Unpack MessagePack ``data`` with ``raw=False`` and ``strict_map_key=False``.

    Used as the ``pre_decoder_func`` of the MessagePackDecoder instances created in this module.
    """
    unpacked = msgpack.unpackb(data, raw=False, strict_map_key=False)
    return unpacked


class BatchSize:
"""
This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example,
Expand Down Expand Up @@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True):
self._t = t
self._name = name
self._type_assertions_enabled = enable_type_assertions
self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {}
self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {}

@property
def name(self):
Expand Down Expand Up @@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Conversion to python value expected type {expected_python_type} from literal not implemented"
)

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]:
    """
    Decode a Binary IDL object into a python value of ``expected_python_type``.

    Only the "msgpack" tag is handled. One MessagePackDecoder is created per expected type and kept in
    ``self._msgpack_decoder`` for reuse.

    :raises TypeTransformerFailedError: If the tag is anything other than "msgpack".
    """
    if binary_idl_object.tag != "msgpack":
        raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")

    cached = self._msgpack_decoder.get(expected_python_type)
    if cached is None:
        # First request for this type: build the decoder once and remember it.
        cached = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder)
        self._msgpack_decoder[expected_python_type] = cached
    return cached.decode(binary_idl_object.value)

def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str:
"""
Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div
Expand Down Expand Up @@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
f"Cannot convert to type {expected_python_type}, only {self._type} is supported"
)

if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal
res = self._from_literal_transformer(lv)
if type(res) != self._type:
Expand Down Expand Up @@ -526,8 +543,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
if isinstance(python_val, dict):
json_str = json.dumps(python_val)
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))
msgpack_bytes = msgpack.dumps(python_val)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

if not dataclasses.is_dataclass(python_val):
raise TypeTransformerFailedError(
Expand All @@ -542,25 +559,27 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
# We can't use hasattr(python_val, "to_json") here because we rely on mashumaro's API to customize the serialization behavior for Flyte types.
if isinstance(python_val, DataClassJSONMixin):
json_str = python_val.to_json()
dict_obj = json.loads(json_str)
msgpack_bytes = msgpack.dumps(dict_obj)
else:
# Look up (or create and cache) a MessagePackEncoder specifically built for the object's type.
# This encoder is then used to convert the data class into MessagePack bytes.
try:
encoder = self._encoder[python_type]
encoder = self._msgpack_encoder[python_type]
except KeyError:
encoder = JSONEncoder(python_type)
self._encoder[python_type] = encoder
encoder = MessagePackEncoder(python_type)
self._msgpack_encoder[python_type] = encoder

try:
json_str = encoder.encode(python_val)
msgpack_bytes = encoder.encode(python_val)
except NotImplementedError:
# You can refer to FlyteFile, FlyteDirectory and StructuredDataset to see how Flyte types can be implemented.
raise NotImplementedError(
f"{python_type} should inherit from mashumaro.types.SerializableType"
f" and implement _serialize and _deserialize methods."
)

return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))

def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
# dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is
Expand Down Expand Up @@ -699,13 +718,34 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An

return dc

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
    """
    Rebuild a dataclass instance of ``expected_python_type`` from a Binary IDL object.

    Only the "msgpack" tag is handled. Types that subclass ``DataClassJSONMixin`` are restored through
    their own ``from_json``; every other type goes through a MessagePackDecoder cached per type in
    ``self._msgpack_decoder``. The result is passed through ``_fix_structured_dataset_type`` before
    being returned.

    :raises TypeTransformerFailedError: If the tag is anything other than "msgpack".
    """
    if binary_idl_object.tag != "msgpack":
        raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")

    if issubclass(expected_python_type, DataClassJSONMixin):
        # The mixin only offers from_json here, so the msgpack payload is re-encoded as JSON first.
        unpacked = msgpack.loads(binary_idl_object.value, strict_map_key=False)
        dc = expected_python_type.from_json(json.dumps(unpacked))  # type: ignore
    else:
        decoder = self._msgpack_decoder.get(expected_python_type)
        if decoder is None:
            # First request for this type: build the decoder once and remember it.
            decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder)
            self._msgpack_decoder[expected_python_type] = decoder
        dc = decoder.decode(binary_idl_object.value)

    return self._fix_structured_dataset_type(expected_python_type, dc)

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if not dataclasses.is_dataclass(expected_python_type):
raise TypeTransformerFailedError(
f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for "
"user defined datatypes in Flytekit"
)

if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

json_str = _json_format.MessageToJson(lv.scalar.generic)

# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
Expand Down Expand Up @@ -1382,6 +1422,9 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
return Literal(collection=LiteralCollection(literals=lit_list))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

try:
lits = lv.collection.literals
except AttributeError:
Expand Down Expand Up @@ -1689,17 +1732,15 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
return None, None

@staticmethod
def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal:
"""
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
from flytekit.types.pickle import FlytePickle

try:
return Literal(
scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())),
metadata={"format": "json"},
)
msgpack_bytes = msgpack.dumps(v)
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack")))
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(ctx, v)
Expand All @@ -1709,7 +1750,7 @@ def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> L
),
metadata={"format": "pickle"},
)
raise e
raise TypeTransformerFailedError(f"Cannot convert from {v} to Flyte Literal.\n" f"Error Message: {e}")

@staticmethod
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]:
Expand Down Expand Up @@ -1760,7 +1801,7 @@ def to_literal(
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_generic_literal(ctx, python_val, allow_pickle)
return self.dict_to_binary_literal(ctx, python_val, allow_pickle)

lit_map = {}
for k, v in python_val.items():
Expand All @@ -1777,6 +1818,9 @@ def to_literal(
return Literal(map=LiteralMap(literals=lit_map))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict:
if lv and lv.scalar and lv.scalar.binary is not None:
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore

if lv and lv.map and lv.map.literals is not None:
tp = self.dict_types(expected_python_type)

Expand Down
31 changes: 30 additions & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import cast
from urllib.parse import unquote

import msgpack
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
Expand All @@ -20,7 +21,7 @@
from flytekit.loggers import logger
from flytekit.models.core import types as _core_types
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Binary
from flytekit.models.types import LiteralType
from flytekit.types.pickle.pickle import FlytePickleTransformer

Expand Down Expand Up @@ -518,9 +519,37 @@ def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, s
return {"ContentEncoding": "gzip"}
return {}

def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: typing.Type[T]) -> typing.Optional[T]:
    """
    Rebuild a FlyteFile from a Binary IDL object.

    The msgpack payload must decode to a mapping with a "path" entry. That path is wrapped in a Blob
    literal with SINGLE dimensionality and an empty format, then handed to ``to_python_value`` so the
    regular blob handling produces the result.

    :param binary_idl_object: Binary scalar whose ``tag`` names the serialization format.
    :param expected_python_type: The python type the caller wants back.
    :raises TypeTransformerFailedError: If the tag is not "msgpack" or the payload is not a mapping.
    :raises ValueError: If the payload has no "path" entry (or it is None).
    """
    if binary_idl_object.tag != "msgpack":
        raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}")

    python_val = msgpack.loads(binary_idl_object.value)
    if not isinstance(python_val, dict):
        # Without this check a non-mapping payload fails with an opaque AttributeError on `.get`.
        raise TypeTransformerFailedError(
            f"Expected a msgpack map for FlyteFile, but decoded a value of type {type(python_val)}"
        )

    path = python_val.get("path", None)
    if path is None:
        raise ValueError("FlyteFile's path should not be None")

    blob_type = _core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE)
    blob_literal = Literal(scalar=Scalar(blob=Blob(metadata=BlobMetadata(type=blob_type), uri=path)))

    # Dispatch through self instead of a freshly constructed transformer, so subclass overrides and
    # instance state are respected. The literal carries only a blob, so this does not re-enter here.
    return self.to_python_value(FlyteContextManager.current_context(), blob_literal, expected_python_type)

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> FlyteFile:
if lv.scalar and lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)

try:
uri = lv.scalar.blob.uri
except AttributeError:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.11",
"msgpack>=1.0.8",
"protobuf!=4.25.0",
"pygments",
"python-json-logger>=2.0.0",
Expand Down
Loading