diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8a3e9195a8..eacf22c7ce 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -970,12 +970,18 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: @staticmethod def is_batchable(t: Type): + """ + This function determines whether a given list is batchable or not. + A batchable list consists only FlytePickle objects. + """ from flytekit.types.pickle import FlytePickle - if t == FlytePickle or (hasattr(t, "__origin__") and t.__origin__ == FlytePickle): - return True - if get_origin(t) is not None: - return any(map(ListTransformer.is_batchable, get_args(t))) + if get_origin(t) is Annotated: + return ListTransformer.is_batchable(get_args(t)[0]) + if get_origin(t) is list: + subtype = get_args(t)[0] + if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle): + return True return False def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 7dd1acc742..25817b68c2 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -94,7 +94,7 @@ def wf(i: int, j: int): @pytest.mark.parametrize( "input", - [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], [{"a": {0: "foo"}}] * 5], + [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], ["foo"] * 5], ) def test_translate_inputs_to_literals(input): @dataclass_json @@ -105,7 +105,7 @@ class MyDataclass(object): @task def t1( - a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2]] + a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]] ): print(a) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 90d2c5510c..ccd44a9d96 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -18,7 +18,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -1577,24 +1577,49 @@ def test_file_ext_with_flyte_file_wrong_type(): @pytest.mark.parametrize( - "python_val, python_type, batch_size", + "python_val, python_type, batch_size, expected_list_length", [ - ([{"a": {0: "foo"}}] * 5, typing.List[typing.Dict[str, FlytePickle]], 5), - ( - [{"a": {0: "foo"}}] * 5, - Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2], - 2, - ), + # Case 1: List of FlytePickle objects with default batch size. + # (By default, the batch_size is set to the length of the whole list.) + # After converting to literal, the result will be [batched_FlytePickle(5 items)]. + # Therefore, the expected list length is [1]. + ([{"foo"}] * 5, typing.List[FlytePickle], 5, [1]), + # Case 2: List of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. + # Therefore, the expected list length is [3]. + (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], 2, [3]), + # Case 3: Nested list of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] + # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). + ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], 2, [2, 1]), ], ) -def test_batch_pickle_list(python_val, python_type, batch_size): - from math import ceil - +def test_batch_pickle_list(python_val, python_type, batch_size, expected_list_length): ctx = FlyteContext.current_context() expected = TypeEngine.to_literal_type(python_type) lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) - # For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/3) = 3 chunks. - # By default, the batch_size is set to the length of the whole list. - assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) + + tmp_lv = lv + tmp_python_val = python_val + for length in expected_list_length: + # Check that after converting to literal, the length of the literal list is equal to: + # - the length of the original list divided by the batch size if not nested + # - the length of the original list if it contains a nested list + assert len(tmp_lv.collection.literals) == length + tmp_lv = tmp_lv.collection.literals[0] + tmp_python_val = tmp_python_val[0] + pv = TypeEngine.to_python_value(ctx, lv, python_type) + # Check that after converting literal to Python value, the result is equal to the original python values. assert pv == python_val + if get_origin(python_type) is Annotated: + pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) + # Remove the annotation and check that after converting to Python value, the result is equal + # to the original input values. This is used to simulate the following case: + # @workflow + # def wf(): + # data = task0() # task0() -> Annotated[typing.List[FlytePickle], 2] + # task1(data=data) # task1(data: typing.List[FlytePickle]) + assert pv == python_val