From 4422d1cb90c71514716ad33d6b526a794b8e2eed Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 13 Sep 2024 20:21:04 +0800 Subject: [PATCH 01/10] Binary IDL With MessagePack Bytes Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f5d81b0636..4566dd261e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1,5 +1,6 @@ from __future__ import annotations +import msgpack import collections import copy import dataclasses @@ -27,6 +28,8 @@ from google.protobuf.struct_pb2 import Struct from mashumaro.codecs.json import JSONDecoder, JSONEncoder from mashumaro.mixins.json import DataClassJSONMixin +from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder + from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation @@ -50,6 +53,7 @@ Scalar, Union, Void, +Binary ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType @@ -191,6 +195,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Conversion to python value expected type {expected_python_type} from literal not implemented" ) + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + if binary_idl_object.tag == "msgpack": + decoder = MessagePackDecoder(expected_python_type) + return decoder.decode(binary_idl_object.value) + else: + raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -1643,17 +1653,15 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: return None, None @staticmethod 
- def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: + def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: """ Creates a flyte-specific ``Literal`` value from a native python dictionary. """ from flytekit.types.pickle import FlytePickle try: - return Literal( - scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())), - metadata={"format": "json"}, - ) + msgpack_bytes = msgpack.dumps(v) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) except TypeError as e: if allow_pickle: remote_path = FlytePickle.to_pickle(ctx, v) @@ -1663,7 +1671,8 @@ def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> L ), metadata={"format": "pickle"}, ) - raise e + raise TypeTransformerFailedError(f"Cannot convert from {v} to Flyte Literal.\n" + f"Error Message: {e}") @staticmethod def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]: @@ -1714,7 +1723,7 @@ def to_literal( allow_pickle, base_type = DictTransformer.is_pickle(python_type) if expected and expected.simple and expected.simple == SimpleType.STRUCT: - return self.dict_to_generic_literal(ctx, python_val, allow_pickle) + return self.dict_to_binary_literal(ctx, python_val, allow_pickle) lit_map = {} for k, v in python_val.items(): @@ -1731,6 +1740,9 @@ def to_literal( return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: + if lv and lv.scalar and lv.scalar.binary is not None: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + if lv and lv.map and lv.map.literals is not None: tp = self.dict_types(expected_python_type) From 0c756d9f86890cdc1712064569f9423aeebafa7c Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 13 Sep 2024 20:25:15 +0800 Subject: [PATCH 02/10] support simpleTransformer from_binary_idl Signed-off-by: Future-Outlier --- 
flytekit/core/type_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 4566dd261e..f1490f123b 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -251,6 +251,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Cannot convert to type {expected_python_type}, only {self._type} is supported" ) + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal res = self._from_literal_transformer(lv) if type(res) != self._type: From c0830717449c36cab750ae8c0ac0b4c19549ef63 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 13 Sep 2024 21:50:53 +0800 Subject: [PATCH 03/10] Binary IDL With MessagePack Bytes Signed-off-by: Future-Outlier --- flytekit/core/promise.py | 42 +++++++++++++++------ flytekit/core/type_engine.py | 71 ++++++++++++++++++++++++------------ 2 files changed, 78 insertions(+), 35 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 9a8a853981..48dd39cccf 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -41,7 +41,7 @@ from flytekit.models import types as _type_models from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model -from flytekit.models.literals import Primitive +from flytekit.models.literals import Binary, Literal, Primitive, Scalar from flytekit.models.task import Resources from flytekit.models.types import SimpleType @@ -138,21 +138,41 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise: break # If the current value is a dataclass, resolve the dataclass with the remaining path - if ( - len(p.attr_path) > 0 - and type(curr_val.value) is _literals_models.Scalar - and type(curr_val.value.value) is _struct.Struct - ): - st = curr_val.value.value - new_st 
= resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) - literal_type = TypeEngine.to_literal_type(type(new_st)) - # Reconstruct the resolved result to flyte literal (because the resolved result might not be struct) - curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type) + if len(p.attr_path) > 0 and type(curr_val.value) is _literals_models.Scalar: + # We keep it for reference task local execution in the future. + if type(curr_val.value.value) is _struct.Struct: + st = curr_val.value.value + new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) + literal_type = TypeEngine.to_literal_type(type(new_st)) + # Reconstruct the resolved result to flyte literal (because the resolved result might not be struct) + curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type) + elif type(curr_val.value.value) is Binary: + binary_idl_obj = curr_val.value.value + if binary_idl_obj.tag == "msgpack": + import msgpack + + dict_obj = msgpack.loads(binary_idl_obj.value) + v = resolve_attr_path_in_dict(dict_obj, attr_path=p.attr_path[used:]) + msgpack_bytes = msgpack.dumps(v) + curr_val = Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) p._val = curr_val return p +def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any: + curr_val = d + for attr in attr_path: + try: + curr_val = curr_val[attr] + except (KeyError, IndexError, TypeError) as e: + raise FlytePromiseAttributeResolveException( + f"Failed to resolve attribute path {attr_path} in dict {curr_val}, attribute {attr} not found.\n" + f"Error Message: {e}" + ) + return curr_val + + def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct: curr_val = st for attr in attr_path: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f1490f123b..c7ce527245 100644 --- 
a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1,6 +1,5 @@ from __future__ import annotations -import msgpack import collections import copy import dataclasses @@ -18,6 +17,7 @@ from functools import lru_cache from typing import Dict, List, NamedTuple, Optional, Type, cast +import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format @@ -27,9 +27,8 @@ from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from mashumaro.codecs.json import JSONDecoder, JSONEncoder -from mashumaro.mixins.json import DataClassJSONMixin from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder - +from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation @@ -45,16 +44,7 @@ from flytekit.models import types as _type_models from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel from flytekit.models.core import types as _core_types -from flytekit.models.literals import ( - Literal, - LiteralCollection, - LiteralMap, - Primitive, - Scalar, - Union, - Void, -Binary -) +from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType T = typing.TypeVar("T") @@ -133,6 +123,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True): self._t = t self._name = name self._type_assertions_enabled = enable_type_assertions + self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {} + self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {} @property def name(self): @@ -197,10 +189,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def from_binary_idl(self, binary_idl_object: Binary, 
expected_python_type: Type[T]) -> Optional[T]: if binary_idl_object.tag == "msgpack": - decoder = MessagePackDecoder(expected_python_type) + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type) + self._msgpack_decoder[expected_python_type] = decoder return decoder.decode(binary_idl_object.value) else: raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") + def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -252,7 +249,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: ) if lv.scalar and lv.scalar.binary: - return self.from_binary_idl(lv.scalar.binary, expected_python_type) + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal res = self._from_literal_transformer(lv) @@ -509,8 +506,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if isinstance(python_val, dict): - json_str = json.dumps(python_val) - return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + msgpack_bytes = msgpack.dumps(python_val) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) if not dataclasses.is_dataclass(python_val): raise TypeTransformerFailedError( @@ -525,17 +522,19 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp # We can't use hasattr(python_val, "to_json") here because we rely on mashumaro's API to customize the serialization behavior for Flyte types. 
if isinstance(python_val, DataClassJSONMixin): json_str = python_val.to_json() + dict_obj = json.loads(json_str) + msgpack_bytes = msgpack.dumps(dict_obj) else: # The function looks up or creates a JSONEncoder specifically designed for the object's type. # This encoder is then used to convert a data class into a JSON string. try: - encoder = self._encoder[python_type] + encoder = self._msgpack_encoder[python_type] except KeyError: - encoder = JSONEncoder(python_type) + encoder = MessagePackEncoder(python_type) self._encoder[python_type] = encoder try: - json_str = encoder.encode(python_val) + msgpack_bytes = encoder.encode(python_val) except NotImplementedError: # you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented. raise NotImplementedError( @@ -543,7 +542,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f" and implement _serialize and _deserialize methods." ) - return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is @@ -682,6 +681,25 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An return dc + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T: + if binary_idl_object.tag == "msgpack": + if issubclass(expected_python_type, DataClassJSONMixin): + dict_obj = msgpack.loads(binary_idl_object.value) + json_str = json.dumps(dict_obj) + dc = expected_python_type.from_json(json_str) # type: ignore + else: + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type) + self._msgpack_decoder[expected_python_type] = decoder + dc = 
decoder.decode(binary_idl_object.value) + + dc = self._fix_structured_dataset_type(expected_python_type, dc) + return self._fix_dataclass_int(expected_python_type, dc) + else: + raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: if not dataclasses.is_dataclass(expected_python_type): raise TypeTransformerFailedError( @@ -689,6 +707,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: "user defined datatypes in Flytekit" ) + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + json_str = _json_format.MessageToJson(lv.scalar.generic) # The `from_json` function is provided from mashumaro's `DataClassJSONMixin`. @@ -1365,6 +1386,9 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return Literal(collection=LiteralCollection(literals=lit_list)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + try: lits = lv.collection.literals except AttributeError: @@ -1674,8 +1698,7 @@ def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Li ), metadata={"format": "pickle"}, ) - raise TypeTransformerFailedError(f"Cannot convert from {v} to Flyte Literal.\n" - f"Error Message: {e}") + raise TypeTransformerFailedError(f"Cannot convert from {v} to Flyte Literal.\n" f"Error Message: {e}") @staticmethod def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]: @@ -1744,7 +1767,7 @@ def to_literal( def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: if lv and lv.scalar and lv.scalar.binary is not None: - return 
self.from_binary_idl(lv.scalar.binary, expected_python_type) + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore if lv and lv.map and lv.map.literals is not None: tp = self.dict_types(expected_python_type) From 4ff299a5e0fad80621a2dd31042e8eee7091bb3b Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 13 Sep 2024 22:44:27 +0800 Subject: [PATCH 04/10] Solve Strict type problem by assign default decoder to mashumaro API Signed-off-by: Future-Outlier --- flytekit/core/promise.py | 2 +- flytekit/core/type_engine.py | 17 ++++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 48dd39cccf..568f7e5f68 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -151,7 +151,7 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise: if binary_idl_obj.tag == "msgpack": import msgpack - dict_obj = msgpack.loads(binary_idl_obj.value) + dict_obj = msgpack.loads(binary_idl_obj.value, strict_map_key=False) v = resolve_attr_path_in_dict(dict_obj, attr_path=p.attr_path[used:]) msgpack_bytes = msgpack.dumps(v) curr_val = Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index c7ce527245..451ac3e255 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from functools import lru_cache -from typing import Dict, List, NamedTuple, Optional, Type, cast +from typing import Any, Dict, List, NamedTuple, Optional, Type, cast import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json @@ -52,6 +52,13 @@ TITLE = "title" +# In mashumaro, default encoder use strict type = False, but default decoder use strict type = True. +# This is for case like Dict[int, str], where the key is int, but it's serialized as string. 
+# If we don't use strict_map_key=False, the decoder will raise error for strict types. +def _default_flytekit_decoder(data: bytes) -> Any: + return msgpack.unpackb(data, raw=False, strict_map_key=False) + + class BatchSize: """ This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example, @@ -192,7 +199,7 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[ try: decoder = self._msgpack_decoder[expected_python_type] except KeyError: - decoder = MessagePackDecoder(expected_python_type) + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) self._msgpack_decoder[expected_python_type] = decoder return decoder.decode(binary_idl_object.value) else: @@ -531,7 +538,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp encoder = self._msgpack_encoder[python_type] except KeyError: encoder = MessagePackEncoder(python_type) - self._encoder[python_type] = encoder + self._msgpack_encoder[python_type] = encoder try: msgpack_bytes = encoder.encode(python_val) @@ -684,14 +691,14 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T: if binary_idl_object.tag == "msgpack": if issubclass(expected_python_type, DataClassJSONMixin): - dict_obj = msgpack.loads(binary_idl_object.value) + dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False) json_str = json.dumps(dict_obj) dc = expected_python_type.from_json(json_str) # type: ignore else: try: decoder = self._msgpack_decoder[expected_python_type] except KeyError: - decoder = MessagePackDecoder(expected_python_type) + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) self._msgpack_decoder[expected_python_type] = decoder dc = decoder.decode(binary_idl_object.value) From 
6a991a455f6c65cfa409d4009226326b80125bf7 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Mon, 16 Sep 2024 10:33:57 +0800 Subject: [PATCH 05/10] add msgpack to pyproject Signed-off-by: Future-Outlier --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 8e8fcef90f..9dec5d73fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "marshmallow-enum", "marshmallow-jsonschema>=0.12.0", "mashumaro>=3.11", + "msgpack>=1.0.8", "protobuf!=4.25.0", "pygments", "python-json-logger>=2.0.0", From 8873507050d2d3d1b05024a3caca2854a1ee12c7 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Mon, 16 Sep 2024 21:55:59 +0800 Subject: [PATCH 06/10] add notes Signed-off-by: Future-Outlier --- flytekit/core/promise.py | 4 ++++ flytekit/core/type_engine.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 568f7e5f68..c6ee6ee2aa 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -231,9 +231,13 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr self._op = op self._lhs = None self._rhs = None + import flytekit if isinstance(lhs, Promise): self._lhs = lhs if lhs.is_ready: + # if lhs.val.scalar and lhs.val.scalar.binary: + # primitive_val = TypeEngine.to_python_value(flytekit.current_context(), lhs.val, ) + # lhs.val.scalar.primitive = if lhs.val.scalar is None or lhs.val.scalar.primitive is None: union = lhs.val.scalar.union if union and union.value.scalar: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9699202a6a..d484a1010d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -58,7 +58,6 @@ def _default_flytekit_decoder(data: bytes) -> Any: return msgpack.unpackb(data, raw=False, strict_map_key=False) - class BatchSize: """ This is used to annotate a FlyteDirectory when we want to download/upload the contents of the 
directory in batches. For example, From d89ed08b145c8b4b73be1ab37d9284e7736c4c96 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Mon, 16 Sep 2024 23:46:30 +0800 Subject: [PATCH 07/10] nit Signed-off-by: Future-Outlier --- flytekit/core/promise.py | 5 +---- flytekit/core/type_engine.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index c6ee6ee2aa..4f0c923b5e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -231,13 +231,10 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr self._op = op self._lhs = None self._rhs = None - import flytekit + if isinstance(lhs, Promise): self._lhs = lhs if lhs.is_ready: - # if lhs.val.scalar and lhs.val.scalar.binary: - # primitive_val = TypeEngine.to_python_value(flytekit.current_context(), lhs.val, ) - # lhs.val.scalar.primitive = if lhs.val.scalar is None or lhs.val.scalar.primitive is None: union = lhs.val.scalar.union if union and union.value.scalar: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d484a1010d..9699202a6a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -58,6 +58,7 @@ def _default_flytekit_decoder(data: bytes) -> Any: return msgpack.unpackb(data, raw=False, strict_map_key=False) + class BatchSize: """ This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. 
For example, From 442b2bf7a7c221777afa37793fd598e4e1ea4ca9 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 17 Sep 2024 00:19:16 +0800 Subject: [PATCH 08/10] support flytefile Signed-off-by: Future-Outlier --- flytekit/types/file/file.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index ba6af4a7dd..9b317797b8 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -9,6 +9,7 @@ from typing import cast from urllib.parse import unquote +import msgpack from dataclasses_json import config from marshmallow import fields from mashumaro.mixins.json import DataClassJSONMixin @@ -20,7 +21,7 @@ from flytekit.loggers import logger from flytekit.models.core import types as _core_types from flytekit.models.core.types import BlobType -from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar, Binary from flytekit.models.types import LiteralType from flytekit.types.pickle.pickle import FlytePickleTransformer @@ -518,9 +519,37 @@ def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, s return {"ContentEncoding": "gzip"} return {} + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: typing.Type[T]) -> typing.Optional[T]: + if binary_idl_object.tag == "msgpack": + python_val = msgpack.loads(binary_idl_object.value) + path = python_val.get("path", None) + if path is None: + raise ValueError("FlyteFile's path should not be None") + + return FlyteFilePathTransformer().to_python_value( + FlyteContextManager.current_context(), + Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata( + type=_core_types.BlobType( + format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ), + uri=path, + ) + ) + ), + expected_python_type, + ) + else: + raise TypeTransformerFailedError(f"Unsupported 
binary format {binary_idl_object.tag}") def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) + try: uri = lv.scalar.blob.uri except AttributeError: From 05a58794c984ff7da4a6a0c40f626d659a7fb8db Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 17 Sep 2024 14:18:59 +0800 Subject: [PATCH 09/10] better comments for default annotation, and don't use self._fix_dataclass_int in binary idl in dataclass transformer Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 9699202a6a..d662bb6a5d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -52,9 +52,9 @@ TITLE = "title" -# In mashumaro, default encoder use strict type = False, but default decoder use strict type = True. -# This is for case like Dict[int, str], where the key is int, but it's serialized as string. -# If we don't use strict_map_key=False, the decoder will raise error for strict types. +# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. +# This is relevant for cases like Dict[int, str], where the key is an int, it's not supported when strict_map_key=False. +# If strict_map_key=False is not used, the decoder will raise an error when trying to decode strict types. 
def _default_flytekit_decoder(data: bytes) -> Any: return msgpack.unpackb(data, raw=False, strict_map_key=False) @@ -732,8 +732,7 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[ self._msgpack_decoder[expected_python_type] = decoder dc = decoder.decode(binary_idl_object.value) - dc = self._fix_structured_dataset_type(expected_python_type, dc) - return self._fix_dataclass_int(expected_python_type, dc) + return self._fix_structured_dataset_type(expected_python_type, dc) else: raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") From 59851685214fb7bf9c9a9fb9c405b0a032f1c0f6 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 17 Sep 2024 14:35:59 +0800 Subject: [PATCH 10/10] improve decoder's comments Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d662bb6a5d..4baba05456 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -53,8 +53,8 @@ # In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. -# This is relevant for cases like Dict[int, str], where the key is an int, it's not supported when strict_map_key=False. -# If strict_map_key=False is not used, the decoder will raise an error when trying to decode strict types. +# This is relevant for cases like Dict[int, str]. +# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed. def _default_flytekit_decoder(data: bytes) -> Any: return msgpack.unpackb(data, raw=False, strict_map_key=False)