From e3a258a01c82eea4c76ed21c0249bc542f111928 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 18 Sep 2024 22:05:09 +0800 Subject: [PATCH 1/7] [flytekit][1][Simple Type] Binary IDL With MessagePack Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 37 ++++++++++++++++++++++++++---------- pyproject.toml | 1 + 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d42e2c2a54..b658bf5054 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -15,8 +15,9 @@ from abc import ABC, abstractmethod from collections import OrderedDict from functools import lru_cache -from typing import Dict, List, NamedTuple, Optional, Type, cast +from typing import Any, Dict, List, NamedTuple, Optional, Type, cast +import msgpack from dataclasses_json import DataClassJsonMixin, dataclass_json from flyteidl.core import literals_pb2 from google.protobuf import json_format as _json_format @@ -26,6 +27,7 @@ from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from mashumaro.codecs.json import JSONDecoder, JSONEncoder +from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin @@ -42,15 +44,7 @@ from flytekit.models import types as _type_models from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel from flytekit.models.core import types as _core_types -from flytekit.models.literals import ( - Literal, - LiteralCollection, - LiteralMap, - Primitive, - Scalar, - Union, - Void, -) +from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType T = typing.TypeVar("T") @@ -58,6 +52,13 @@ TITLE = "title" +# In Mashumaro, the default encoder uses strict_map_key=False, while the default decoder uses strict_map_key=True. +# This is relevant for cases like Dict[int, str]. +# If strict_map_key=False is not used, the decoder will raise an error when trying to decode keys that are not strictly typed.` +def _default_flytekit_decoder(data: bytes) -> Any: + return msgpack.unpackb(data, raw=False, strict_map_key=False) + + class BatchSize: """ This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example, @@ -129,6 +130,8 @@ def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True): self._t = t self._name = name self._type_assertions_enabled = enable_type_assertions + self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {} + self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {} @property def name(self): @@ -221,6 +224,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Conversion to python value expected type {expected_python_type} from literal not implemented" ) + def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> Optional[T]: + if binary_idl_object.tag == "msgpack": + try: + decoder = self._msgpack_decoder[expected_python_type] + except KeyError: + decoder = MessagePackDecoder(expected_python_type, pre_decoder_func=_default_flytekit_decoder) + self._msgpack_decoder[expected_python_type] = decoder + return decoder.decode(binary_idl_object.value) + else: + raise TypeTransformerFailedError(f"Unsupported binary format {binary_idl_object.tag}") + def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div @@ -271,6 +285,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Cannot convert to type {expected_python_type}, only {self._type} is supported" ) + if lv.scalar and lv.scalar.binary: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal res = self._from_literal_transformer(lv) if type(res) != self._type: diff --git a/pyproject.toml b/pyproject.toml index 8e8fcef90f..9d183b6a9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "marshmallow-enum", "marshmallow-jsonschema>=0.12.0", "mashumaro>=3.11", + "msgpack>=1.1.0", "protobuf!=4.25.0", "pygments", "python-json-logger>=2.0.0", From 3562f0c54d8c6a2a3e041825159f14767d047deb Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 18 Sep 2024 22:29:30 +0800 Subject: [PATCH 2/7] Add Tests Signed-off-by: Future-Outlier --- .../unit/core/test_type_engine_binary_idl.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/flytekit/unit/core/test_type_engine_binary_idl.py diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py new file mode 100644 index 0000000000..1e2fa3801b --- /dev/null +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -0,0 +1,73 @@ +import msgpack +from datetime import datetime, date, timedelta +from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder + +from flytekit.models.literals import Binary, Literal, Scalar +from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +""" +a: int = 1 +b: float = 1.0 +c: bool = True +d: str = "hello" +e: datetime = field(default_factory=datetime.now) +f: date = field(default_factory=date.today) +g: timedelta = field(default_factory=lambda: timedelta(days=1)) +""" +def test_simple_type_transformer(): + ctx = FlyteContextManager.current_context() + + int_input = 20240918 + encoder = MessagePackEncoder(int) + int_msgpack_bytes = encoder.encode(int_input) + lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack"))) + int_output = TypeEngine.to_python_value(ctx, lv, int) + assert int_input == int_output + + float_input = 2024.0918 + encoder = MessagePackEncoder(float) + float_msgpack_bytes = encoder.encode(float_input) + lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack"))) + float_output = TypeEngine.to_python_value(ctx, lv, float) + assert float_input == float_output + + bool_input = True + encoder = MessagePackEncoder(bool) + bool_msgpack_bytes = encoder.encode(bool_input) + lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) + bool_output = TypeEngine.to_python_value(ctx, lv, bool) + assert bool_input == bool_output + + bool_input = False + bool_msgpack_bytes = encoder.encode(bool_input) + lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) + bool_output = TypeEngine.to_python_value(ctx, lv, bool) + assert bool_input == bool_output + + str_input = "hello" + encoder = MessagePackEncoder(str) + str_msgpack_bytes = encoder.encode(str_input) + lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack"))) + str_output = TypeEngine.to_python_value(ctx, lv, str) + assert str_input == str_output + + datetime_input = datetime.now() + encoder = MessagePackEncoder(datetime) + datetime_msgpack_bytes = encoder.encode(datetime_input) + lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack"))) + datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) + assert datetime_input == datetime_output + + date_input = date.today() + encoder = MessagePackEncoder(date) + date_msgpack_bytes = encoder.encode(date_input) + lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack"))) + date_output = TypeEngine.to_python_value(ctx, lv, date) + assert date_input == date_output + + timedelta_input = timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1) + encoder = MessagePackEncoder(timedelta) + timedelta_msgpack_bytes = encoder.encode(timedelta_input) + lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) + timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) + assert timedelta_input == timedelta_output From f93b44181db62f86e9f114b5b28cbb5817f56ff4 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 18 Sep 2024 22:30:34 +0800 Subject: [PATCH 3/7] remove unused import Signed-off-by: Future-Outlier --- .../unit/core/test_type_engine_binary_idl.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py index 1e2fa3801b..8b99a5c819 100644 --- a/tests/flytekit/unit/core/test_type_engine_binary_idl.py +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -1,19 +1,10 @@ -import msgpack from datetime import datetime, date, timedelta -from mashumaro.codecs.msgpack import MessagePackDecoder, MessagePackEncoder +from mashumaro.codecs.msgpack import MessagePackEncoder from flytekit.models.literals import Binary, Literal, Scalar from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine -""" -a: int = 1 -b: float = 1.0 -c: bool = True -d: str = "hello" -e: datetime = field(default_factory=datetime.now) -f: date = field(default_factory=date.today) -g: timedelta = field(default_factory=lambda: timedelta(days=1)) -""" + def test_simple_type_transformer(): ctx = FlyteContextManager.current_context() From c05a905b241fa825b1a8d78d57aa3e4b9d27becd Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 18 Sep 2024 23:43:19 +0800 Subject: [PATCH 4/7] [flytekit][2][untyped dict] Binary IDL With MessagePack Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 15 ++--- tests/flytekit/unit/core/test_local_cache.py | 2 +- .../unit/core/test_type_engine_binary_idl.py | 56 +++++++++++++++++++ tests/flytekit/unit/core/test_type_hints.py | 11 ++-- 4 files changed, 71 insertions(+), 13 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index b658bf5054..d26c102b60 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1706,17 +1706,15 @@ def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple: return None, None @staticmethod - def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: + def dict_to_binary_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Literal: """ Creates a flyte-specific ``Literal`` value from a native python dictionary. """ from flytekit.types.pickle import FlytePickle try: - return Literal( - scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())), - metadata={"format": "json"}, - ) + msgpack_bytes = msgpack.dumps(v) + return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag="msgpack"))) except TypeError as e: if allow_pickle: remote_path = FlytePickle.to_pickle(ctx, v) @@ -1726,7 +1724,7 @@ def dict_to_generic_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> L ), metadata={"format": "pickle"}, ) - raise e + raise TypeTransformerFailedError(f"Cannot convert {v} to Flyte Literal.\n" f"Error Message: {e}") @staticmethod def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]: @@ -1777,7 +1775,7 @@ def to_literal( allow_pickle, base_type = DictTransformer.is_pickle(python_type) if expected and expected.simple and expected.simple == SimpleType.STRUCT: - return self.dict_to_generic_literal(ctx, python_val, allow_pickle) + return self.dict_to_binary_literal(ctx, python_val, allow_pickle) lit_map = {} for k, v in python_val.items(): @@ -1794,6 +1792,9 @@ def to_literal( return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: + if lv and lv.scalar and lv.scalar.binary is not None: + return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore + if lv and lv.map and lv.map.literals is not None: tp = self.dict_types(expected_python_type) diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 72a95ac2dd..cf3e90e338 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -529,7 +529,7 @@ def test_stable_cache_key(): } ) key = _calculate_cache_key("task_name_1", "31415", lm) - assert key == "task_name_1-31415-189e755a8f41c006889c291fcaedb4eb" + assert key == "task_name_1-31415-e3a85f91467d1e1f721ebe8129b2de31" @pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py index 8b99a5c819..318765c484 100644 --- a/tests/flytekit/unit/core/test_type_engine_binary_idl.py +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -1,4 +1,6 @@ from datetime import datetime, date, timedelta + +import msgpack from mashumaro.codecs.msgpack import MessagePackEncoder from flytekit.models.literals import Binary, Literal, Scalar @@ -62,3 +64,57 @@ def test_simple_type_transformer(): lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) assert timedelta_input == timedelta_output + +def test_untyped_dict(): + ctx = FlyteContextManager.current_context() + + dict_inputs = [ + # Basic key-value combinations with int, str, bool, float + {1: "a", "key": 2.5, True: False, 3.14: 100}, + {"a": 1, 2: "b", 3.5: True, False: 3.1415}, + + { + 1: [1, "a", 2.5, False], + "key_list": ["str", 3.14, True, 42], + True: [False, 2.718, "test"], + }, + + { + "nested_dict": {1: 2, "key": "value", True: 3.14, False: "string"}, + 3.14: {"pi": 3.14, "e": 2.718, 42: True}, + }, + + { + "list_in_dict": [ + {"inner_dict_1": [1, 2.5, "a"], "inner_dict_2": [True, False, 3.14]}, + [1, 2, 3, {"nested_list_dict": [False, "test"]}], + ] + }, + + { + "complex_nested": { + 1: {"nested_dict": {True: [1, "a", 2.5]}}, + "string_key": {False: {3.14: {"deep": [1, "deep_value"]}}}, + } + }, + + { + "list_of_dicts": [{"a": 1, "b": 2}, {"key1": "value1", "key2": "value2"}], + 10: [{"nested_list": [1, "value", 3.14]}, {"another_list": [True, False]}], + }, + + # More nested combinations of list and dict + { + "outer_list": [ + [1, 2, 3], + {"inner_dict": {"key1": [True, "string", 3.14], "key2": [1, 2.5]}}, # Dict inside list + ], + "another_dict": {"key1": {"subkey": [1, 2, "str"]}, "key2": [False, 3.14, "test"]}, + }, + ] + + for dict_input in dict_inputs: + dict_msgpack_bytes = msgpack.dumps(dict_input) + lv = Literal(scalar=Scalar(binary=Binary(value=dict_msgpack_bytes, tag="msgpack"))) + dict_output = TypeEngine.to_python_value(ctx, lv, dict) + assert dict_input == dict_output diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 0e7b88bd08..c4b325a1a4 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -11,10 +11,10 @@ from dataclasses import dataclass from enum import Enum +import msgpack import pytest from dataclasses_json import DataClassJsonMixin from google.protobuf.struct_pb2 import Struct -from mashumaro.codecs.json import JSONEncoder, JSONDecoder from typing_extensions import Annotated, get_origin import flytekit @@ -37,6 +37,7 @@ from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types from flytekit.models.interface import Parameter +from flytekit.models.literals import Binary from flytekit.models.task import Resources as _resource_models from flytekit.models.types import LiteralType, SimpleType from flytekit.tools.translator import get_serializable @@ -1488,7 +1489,7 @@ def t2(a: dict) -> str: guessed_types = {"a": pt} ctx = context_manager.FlyteContext.current_context() lm = TypeEngine.dict_to_literal_map(ctx, d=input_map, type_hints=guessed_types) - assert isinstance(lm.literals["a"].scalar.generic, Struct) + assert isinstance(lm.literals["a"].scalar.binary, Binary) output_lm = t2.dispatch_execute(ctx, lm) str_value = output_lm.literals["o0"].scalar.primitive.string_value @@ -1521,9 +1522,9 @@ def t2() -> dict: ctx = context_manager.FlyteContextManager.current_context() output_lm = t2.dispatch_execute(ctx, _literal_models.LiteralMap(literals={})) - expected_struct = Struct() - expected_struct.update({"k1": "v1", "k2": 3, "4": {"one": [1, "two", [3]]}}) - assert output_lm.literals["o0"].scalar.generic == expected_struct + msgpack_bytes = msgpack.dumps({"k1": "v1", "k2": 3, 4: {"one": [1, "two", [3]]}}) + binary_idl_obj = Binary(value=msgpack_bytes, tag="msgpack") + assert output_lm.literals["o0"].scalar.binary == binary_idl_obj @pytest.mark.skipif(sys.version_info < (3, 9), reason="Use of dict hints is only supported in Python 3.9+") From b1bf20c2358db2a189e5e33beb492794d23d5e5b Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 19 Sep 2024 00:30:53 +0800 Subject: [PATCH 5/7] Fix Tests Signed-off-by: Future-Outlier --- plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py | 3 +++ plugins/flytekit-openai/tests/openai_batch/test_agent.py | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py index baf26fdffa..3046b07dd9 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py @@ -1,5 +1,7 @@ from datetime import datetime, timedelta from unittest import mock +import msgpack +import base64 import pytest from flyteidl.core.execution_pb2 import TaskExecution @@ -161,6 +163,7 @@ async def test_agent(mock_boto_call, mock_return_value): if "pickle_check" in mock_return_value[0][0]: assert "pickle_file" in outputs["result"] else: + outputs["result"] = msgpack.loads(base64.b64decode(outputs["result"])) assert ( outputs["result"]["EndpointConfigArn"] == "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config" diff --git a/plugins/flytekit-openai/tests/openai_batch/test_agent.py b/plugins/flytekit-openai/tests/openai_batch/test_agent.py index d9352e918b..476ca5c8ba 100644 --- a/plugins/flytekit-openai/tests/openai_batch/test_agent.py +++ b/plugins/flytekit-openai/tests/openai_batch/test_agent.py @@ -1,7 +1,8 @@ from datetime import timedelta from unittest import mock from unittest.mock import AsyncMock - +import msgpack +import base64 import pytest from flyteidl.core.execution_pb2 import TaskExecution from flytekitplugins.openai.batch.agent import BatchEndpointMetadata @@ -159,7 +160,7 @@ async def test_openai_batch_agent(mock_retrieve, mock_create, mock_context): outputs = literal_map_string_repr(resource.outputs) result = outputs["result"] - assert result == batch_retrieve_result.to_dict() + assert msgpack.loads(base64.b64decode(result)) == batch_retrieve_result.to_dict() # Status: Failed mock_retrieve.return_value = batch_retrieve_result_failure From 6b59d8989f1182078df4512ce304e566c163ab09 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 19 Sep 2024 14:15:58 +0800 Subject: [PATCH 6/7] fix test_offloaded_literal Signed-off-by: Future-Outlier --- tests/flytekit/unit/core/test_type_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index a8e4cd31a8..7e7439e796 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3218,10 +3218,10 @@ def test_union_file_directory(): (typing.List[Color], [Color.RED, Color.GREEN, Color.BLUE]), (typing.List[Annotated[int, "tag"]], [1, 2, 3]), (typing.List[Annotated[str, "tag"]], ["a", "b", "c"]), - (typing.Dict[int, str], {"1": "a", "2": "b", "3": "c"}), + (typing.Dict[int, str], {1: "a", 2: "b", 3: "c"}), (typing.Dict[str, int], {"a": 1, "b": 2, "c": 3}), (typing.Dict[str, typing.List[int]], {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), - (typing.Dict[str, typing.Dict[int, str]], {"a": {"1": "a", "2": "b", "3": "c"}, "b": {"4": "d", "5": "e", "6": "f"}}), + (typing.Dict[str, typing.Dict[int, str]], {"a": {1: "a", 2: "b", 3: "c"}, "b": {4: "d", 5: "e", 6: "f"}}), (typing.Union[int, str], 42), (typing.Union[int, str], "hello"), (typing.Union[typing.List[int], typing.List[str]], [1, 2, 3]), From bcaf5735287524331e0d0af1ba25e94273a415ad Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 19 Sep 2024 15:08:19 +0800 Subject: [PATCH 7/7] Add more tests Signed-off-by: Future-Outlier --- .../unit/core/test_type_engine_binary_idl.py | 96 +++++++++++-------- 1 file changed, 55 insertions(+), 41 deletions(-) diff --git a/tests/flytekit/unit/core/test_type_engine_binary_idl.py b/tests/flytekit/unit/core/test_type_engine_binary_idl.py index 8b99a5c819..f5b377f3dd 100644 --- a/tests/flytekit/unit/core/test_type_engine_binary_idl.py +++ b/tests/flytekit/unit/core/test_type_engine_binary_idl.py @@ -8,57 +8,71 @@ def test_simple_type_transformer(): ctx = FlyteContextManager.current_context() - int_input = 20240918 + int_inputs = [1, 2, 20240918, -1, -2, -20240918] encoder = MessagePackEncoder(int) - int_msgpack_bytes = encoder.encode(int_input) - lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack"))) - int_output = TypeEngine.to_python_value(ctx, lv, int) - assert int_input == int_output + for int_input in int_inputs: + int_msgpack_bytes = encoder.encode(int_input) + lv = Literal(scalar=Scalar(binary=Binary(value=int_msgpack_bytes, tag="msgpack"))) + int_output = TypeEngine.to_python_value(ctx, lv, int) + assert int_input == int_output - float_input = 2024.0918 + float_inputs = [2024.0918, 5.0, -2024.0918, -5.0] encoder = MessagePackEncoder(float) - float_msgpack_bytes = encoder.encode(float_input) - lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack"))) - float_output = TypeEngine.to_python_value(ctx, lv, float) - assert float_input == float_output + for float_input in float_inputs: + float_msgpack_bytes = encoder.encode(float_input) + lv = Literal(scalar=Scalar(binary=Binary(value=float_msgpack_bytes, tag="msgpack"))) + float_output = TypeEngine.to_python_value(ctx, lv, float) + assert float_input == float_output - bool_input = True + bool_inputs = [True, False] encoder = MessagePackEncoder(bool) - bool_msgpack_bytes = encoder.encode(bool_input) - lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) - bool_output = TypeEngine.to_python_value(ctx, lv, bool) - assert bool_input == bool_output + for bool_input in bool_inputs: + bool_msgpack_bytes = encoder.encode(bool_input) + lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) + bool_output = TypeEngine.to_python_value(ctx, lv, bool) + assert bool_input == bool_output - bool_input = False - bool_msgpack_bytes = encoder.encode(bool_input) - lv = Literal(scalar=Scalar(binary=Binary(value=bool_msgpack_bytes, tag="msgpack"))) - bool_output = TypeEngine.to_python_value(ctx, lv, bool) - assert bool_input == bool_output - - str_input = "hello" + str_inputs = ["hello", "world", "flyte", "kit", "is", "awesome"] encoder = MessagePackEncoder(str) - str_msgpack_bytes = encoder.encode(str_input) - lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack"))) - str_output = TypeEngine.to_python_value(ctx, lv, str) - assert str_input == str_output + for str_input in str_inputs: + str_msgpack_bytes = encoder.encode(str_input) + lv = Literal(scalar=Scalar(binary=Binary(value=str_msgpack_bytes, tag="msgpack"))) + str_output = TypeEngine.to_python_value(ctx, lv, str) + assert str_input == str_output - datetime_input = datetime.now() + datetime_inputs = [datetime.now(), + datetime(2024, 9, 18), + datetime(2024, 9, 18, 1), + datetime(2024, 9, 18, 1, 1), + datetime(2024, 9, 18, 1, 1, 1), + datetime(2024, 9, 18, 1, 1, 1, 1)] encoder = MessagePackEncoder(datetime) - datetime_msgpack_bytes = encoder.encode(datetime_input) - lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack"))) - datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) - assert datetime_input == datetime_output + for datetime_input in datetime_inputs: + datetime_msgpack_bytes = encoder.encode(datetime_input) + lv = Literal(scalar=Scalar(binary=Binary(value=datetime_msgpack_bytes, tag="msgpack"))) + datetime_output = TypeEngine.to_python_value(ctx, lv, datetime) + assert datetime_input == datetime_output - date_input = date.today() + date_inputs = [date.today(), + date(2024, 9, 18)] encoder = MessagePackEncoder(date) - date_msgpack_bytes = encoder.encode(date_input) - lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack"))) - date_output = TypeEngine.to_python_value(ctx, lv, date) - assert date_input == date_output + for date_input in date_inputs: + date_msgpack_bytes = encoder.encode(date_input) + lv = Literal(scalar=Scalar(binary=Binary(value=date_msgpack_bytes, tag="msgpack"))) + date_output = TypeEngine.to_python_value(ctx, lv, date) + assert date_input == date_output - timedelta_input = timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1) + timedelta_inputs = [timedelta(days=1), + timedelta(days=1, seconds=1), + timedelta(days=1, seconds=1, microseconds=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1), + timedelta(days=1, seconds=1, microseconds=1, milliseconds=1, minutes=1, hours=1, weeks=1), + timedelta(days=-1, seconds=-1, microseconds=-1, milliseconds=-1, minutes=-1, hours=-1, weeks=-1)] encoder = MessagePackEncoder(timedelta) - timedelta_msgpack_bytes = encoder.encode(timedelta_input) - lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) - timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) - assert timedelta_input == timedelta_output + for timedelta_input in timedelta_inputs: + timedelta_msgpack_bytes = encoder.encode(timedelta_input) + lv = Literal(scalar=Scalar(binary=Binary(value=timedelta_msgpack_bytes, tag="msgpack"))) + timedelta_output = TypeEngine.to_python_value(ctx, lv, timedelta) + assert timedelta_input == timedelta_output