diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 861909eedd..dafee55a4b 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,8 +15,9 @@ from abc import ABC, abstractmethod from collections import OrderedDict from functools import lru_cache -from typing import Dict, List, NamedTuple, Optional, Type, cast +from typing import Any, Dict, List, NamedTuple, Optional, Type, cast +import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format @@ -26,6 +27,7 @@ from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from mashumaro.codecs.json import JSONDecoder, JSONEncoder +from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin @@ -42,15 +44,7 @@ from flytekit.models import types as _type_models from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel from flytekit.models.core import types as _core_types -from flytekit.models.literals import ( - Literal, - LiteralCollection, - LiteralMap, - Primitive, - Scalar, - Union, - Void, -) +from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType T = typing.TypeVar("T") @@ -58,6 +52,13 @@ TITLE = "title" +# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. +# This is relevant for cases like Dict[int, str]. 
+# If strict_map_key=False is not used, the decoder will raise an error when decoding map keys that are not str or bytes (e.g. the int keys of a Dict[int, str]). +def _default_flytekit_decoder(data: bytes) -> Any: + return msgpack.unpackb(data, raw=False, strict_map_key=False) + + class BatchSize: """ This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example, @@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True): self._t = t self._name = name self._type_assertions_enabled = enable_type_assertions + self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {} + self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {} @property def name(self): @@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Conversion to python value expected type {expected_python_type} from literal not implemented" ) + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + if binary_idl_object.tag == "msgpack": + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) + self._msgpack_decoder[expected_python_type] = decoder + return decoder.decode(binary_idl_object.value) + else: + raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") + def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Cannot convert to type {expected_python_type}, only {self._type} is supported" ) + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + try: # todo(maximsmol): this is 
quite ugly and each transformer should really check their Literal res = self._from_literal_transformer(lv) if type(res) != self._type: @@ -1697,17 +1714,15 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: return None, None @staticmethod - def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: + def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: """ Creates a flyte-specific ``Literal`` value from a native python dictionary. """ from flytekit.types.pickle import FlytePickle try: - return Literal( - scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())), - metadata={"format": "json"}, - ) + msgpack_bytes = msgpack.dumps(v) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) except TypeError as e: if allow_pickle: remote_path = FlytePickle.to_pickle(ctx, v) @@ -1717,7 +1732,7 @@ def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> L ), metadata={"format": "pickle"}, ) - raise e + raise TypeTransformerFailedError(f"Cannot convert {v} to Flyte Literal.\n" f"Error Message: {e}") @staticmethod def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]: @@ -1768,7 +1783,7 @@ def to_literal( allow_pickle, base_type = DictTransformer.is_pickle(python_type) if expected and expected.simple and expected.simple == SimpleType.STRUCT: - return self.dict_to_generic_literal(ctx, python_val, allow_pickle) + return self.dict_to_binary_literal(ctx, python_val, allow_pickle) lit_map = {} for k, v in python_val.items(): @@ -1785,6 +1800,9 @@ def to_literal( return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: + if lv and lv.scalar and lv.scalar.binary is not None: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + if lv and lv.map and lv.map.literals is not None: tp = 
self.dict_types(expected_python_type) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index baf26fdffa..3046b07dd9 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -1,5 +1,7 @@ from datetime import datetime, timedelta from unittest import mock +import msgpack +import base64 import pytest from flyteidl.core.execution_pb2 import TaskExecution @@ -161,6 +163,7 @@ async def test_agent(mock_boto_call, mock_return_value): if "pickle_check" in mock_return_value[0][0]: assert "pickle_file" in outputs["result"] else: + outputs["result"] = msgpack.loads(base64.b64decode(outputs["result"])) assert ( outputs["result"]["EndpointConfigArn"] == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" diff --git a/plugins/flytekit-openai/tests/openai_batch/test_agent.py b/plugins/flytekit-openai/tests/openai_batch/test_agent.py index d9352e918b..476ca5c8ba 100644 --- a/plugins/flytekit-openai/tests/openai_batch/test_agent.py +++ b/plugins/flytekit-openai/tests/openai_batch/test_agent.py @@ -1,7 +1,8 @@ from datetime import timedelta from unittest import mock from unittest.mock import AsyncMock - +import msgpack +import base64 import pytest from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.openai.batch.agent import BatchEndpointMetadata @@ -159,7 +160,7 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context): outputs = literal_map_string_repr(resource.outputs) result = outputs["result"] - assert result == batch_retrieve_result.to_dict() + assert msgpack.loads(base64.b64decode(result)) == batch_retrieve_result.to_dict() # Status: Failed mock_retrieve.return_value = batch_retrieve_result_failure diff --git a/pyproject.toml b/pyproject.toml index ba2cc46e83..82af81ee21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ 
dependencies = [ "marshmallow-enum", "marshmallow-jsonschema>=0.12.0", "mashumaro>=3.11", + "msgpack>=1.1.0", "protobuf!=4.25.0", "pygments", "python-json-logger>=2.0.0", diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 72a95ac2dd..cf3e90e338 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -529,7 +529,7 @@ def test_stable_cache_key(): } ) key = _calculate_cache_key("task_name_1", "31415", lm) - assert key == "task_name_1-31415-189e755a8f41c006889c291fcaedb4eb" + assert key == "task_name_1-31415-e3a85f91467d1e1f721ebe8129b2de31" @pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index a8e4cd31a8..7e7439e796 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3218,10 +3218,10 @@ def test_union_file_directory(): (typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]), (typing.List[Annotated[int, "tag"]], [1, 2, 3]), (typing.List[Annotated[str, "tag"]], ["a", "b", "c"]), - (typing.Dict[int, str], {"1": "a", "2": "b", "3": "c"}), + (typing.Dict[int, str], {1: "a", 2: "b", 3: "c"}), (typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}), (typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), - (typing.Dict[str, typing.Dict[int, str]], {"a": {"1": "a", "2": "b", "3": "c"}, "b": {"4": "d", "5": "e", "6": "f"}}), + (typing.Dict[str, typing.Dict[int, str]], {"a": {1: "a", 2: "b", 3: "c"}, "b": {4: "d", 5: "e", 6: "f"}}), (typing.Union[int, str], 42), (typing.Union[int, str], "hello"), (typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]), diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py new file mode 100644 index 0000000000..34010e37cf --- 
/dev/null +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -0,0 +1,134 @@ +from datetime import datetime, date, timedelta + +import msgpack +from mashumaro.codecs.msgpack import MessagePackEncoder + +from flytekit.models.literals import Binary, Literal, Scalar +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine + +def test_simple_type_transformer(): + ctx = FlyteContextManager.current_context() + + int_inputs = [1, 2, 20240918, -1, -2, -20240918] + encoder = MessagePackEncoder(int) + for int_input in int_inputs: + int_msgpack_bytes = encoder.encode(int_input) + lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack"))) + int_output = TypeEngine.to_python_value(ctx, lv, int) + assert int_input == int_output + + float_inputs = [2024.0918, 5.0, -2024.0918, -5.0] + encoder = MessagePackEncoder(float) + for float_input in float_inputs: + float_msgpack_bytes = encoder.encode(float_input) + lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack"))) + float_output = TypeEngine.to_python_value(ctx, lv, float) + assert float_input == float_output + + bool_inputs = [True, False] + encoder = MessagePackEncoder(bool) + for bool_input in bool_inputs: + bool_msgpack_bytes = encoder.encode(bool_input) + lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) + bool_output = TypeEngine.to_python_value(ctx, lv, bool) + assert bool_input == bool_output + + str_inputs = ["hello", "world", "flyte", "kit", "is", "awesome"] + encoder = MessagePackEncoder(str) + for str_input in str_inputs: + str_msgpack_bytes = encoder.encode(str_input) + lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack"))) + str_output = TypeEngine.to_python_value(ctx, lv, str) + assert str_input == str_output + + datetime_inputs = [datetime.now(), + datetime(2024, 9, 18), + datetime(2024, 9, 18, 1), + datetime(2024, 9, 18, 1, 1), + 
datetime(2024, 9, 18, 1, 1, 1), + datetime(2024, 9, 18, 1, 1, 1, 1)] + encoder = MessagePackEncoder(datetime) + for datetime_input in datetime_inputs: + datetime_msgpack_bytes = encoder.encode(datetime_input) + lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack"))) + datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) + assert datetime_input == datetime_output + + date_inputs = [date.today(), + date(2024, 9, 18)] + encoder = MessagePackEncoder(date) + for date_input in date_inputs: + date_msgpack_bytes = encoder.encode(date_input) + lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack"))) + date_output = TypeEngine.to_python_value(ctx, lv, date) + assert date_input == date_output + + timedelta_inputs = [timedelta(days=1), + timedelta(days=1, seconds=1), + timedelta(days=1, seconds=1, microseconds=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1), + timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1)] + encoder = MessagePackEncoder(timedelta) + for timedelta_input in timedelta_inputs: + timedelta_msgpack_bytes = encoder.encode(timedelta_input) + lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) + timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) + assert timedelta_input == timedelta_output + +def test_untyped_dict(): + ctx = FlyteContextManager.current_context() + + dict_inputs = [ + # Basic key-value combinations with int, str, bool, float + {1: "a", "key": 2.5, True: False, 3.14: 100}, + {"a": 1, 2: "b", 3.5: True, False: 3.1415}, + + { + 1: [1, "a", 2.5, False], + "key_list": ["str", 3.14, True, 42], + True: [False, 2.718, 
"test"], + }, + + { + "nested_dict": {1: 2, "key": "value", True: 3.14, False: "string"}, + 3.14: {"pi": 3.14, "e": 2.718, 42: True}, + }, + + { + "list_in_dict": [ + {"inner_dict_1": [1, 2.5, "a"], "inner_dict_2": [True, False, 3.14]}, + [1, 2, 3, {"nested_list_dict": [False, "test"]}], + ] + }, + + { + "complex_nested": { + 1: {"nested_dict": {True: [1, "a", 2.5]}}, + "string_key": {False: {3.14: {"deep": [1, "deep_value"]}}}, + } + }, + + { + "list_of_dicts": [{"a": 1, "b": 2}, {"key1": "value1", "key2": "value2"}], + 10: [{"nested_list": [1, "value", 3.14]}, {"another_list": [True, False]}], + }, + + # More nested combinations of list and dict + { + "outer_list": [ + [1, 2, 3], + {"inner_dict": {"key1": [True, "string", 3.14], "key2": [1, 2.5]}}, # Dict inside list + ], + "another_dict": {"key1": {"subkey": [1, 2, "str"]}, "key2": [False, 3.14, "test"]}, + }, + ] + + for dict_input in dict_inputs: + dict_msgpack_bytes = msgpack.dumps(dict_input) + lv = Literal(scalar=Scalar(binary=Binary(value=dict_msgpack_bytes, tag="msgpack"))) + dict_output = TypeEngine.to_python_value(ctx, lv, dict) + assert dict_input == dict_output diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 0e7b88bd08..c4b325a1a4 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -11,10 +11,10 @@ from dataclasses import dataclass from enum import Enum +import msgpack import pytest from dataclasses_json import DataClassJsonMixin from google.protobuf.struct_pb2 import Struct -from mashumaro.codecs.json import JSONEncoder, JSONDecoder from typing_extensions import Annotated, get_origin import flytekit @@ -37,6 +37,7 @@ from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter +from flytekit.models.literals import Binary from flytekit.models.task import Resources as _resource_models from 
flytekit.models.types import LiteralType, SimpleType from flytekit.tools.translator import get_serializable @@ -1488,7 +1489,7 @@ def t2(a: dict) -> str: guessed_types = {"a": pt} ctx = context_manager.FlyteContext.current_context() lm = TypeEngine.dict_to_literal_map(ctx, d=input_map, type_hints=guessed_types) - assert isinstance(lm.literals["a"].scalar.generic, Struct) + assert isinstance(lm.literals["a"].scalar.binary, Binary) output_lm = t2.dispatch_execute(ctx, lm) str_value = output_lm.literals["o0"].scalar.primitive.string_value @@ -1521,9 +1522,9 @@ def t2() -> dict: ctx = context_manager.FlyteContextManager.current_context() output_lm = t2.dispatch_execute(ctx, _literal_models.LiteralMap(literals={})) - expected_struct = Struct() - expected_struct.update({"k1": "v1", "k2": 3, "4": {"one": [1, "two", [3]]}}) - assert output_lm.literals["o0"].scalar.generic == expected_struct + msgpack_bytes = msgpack.dumps({"k1": "v1", "k2": 3, 4: {"one": [1, "two", [3]]}}) + binary_idl_obj = Binary(value=msgpack_bytes, tag="msgpack") + assert output_lm.literals["o0"].scalar.binary == binary_idl_obj @pytest.mark.skipif(sys.version_info < (3, 9), reason="Use of dict hints is only supported in Python 3.9+")