diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 8be9d8ccae..90f246deb9 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -92,6 +92,12 @@ def extract_value( if len(input_val) == 0: raise sub_type = type(input_val[0]) + # To maintain consistency between translate_inputs_to_literals and ListTransformer.to_literal for batchable types, + # directly call ListTransformer.to_literal to batch process the list items. This is necessary because processing + # each list item separately could lead to errors since ListTransformer.to_python_value may treat the literal + # as it is batched for batchable types. + if ListTransformer.is_batchable(python_type): + return TypeEngine.to_literal(ctx, input_val, python_type, lt) literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list)) elif isinstance(input_val, dict): diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 306c4116ad..e5ffa6459c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -968,12 +968,40 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") + @staticmethod + def is_batchable(t: Type): + """ + This function evaluates whether the provided type is batchable or not. + It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle. + """ + from flytekit.types.pickle import FlytePickle + + if get_origin(t) is Annotated: + return ListTransformer.is_batchable(get_args(t)[0]) + if get_origin(t) is list: + subtype = get_args(t)[0] + if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle): + return True + return False + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") - t = self.get_sub_type(python_type) - lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore + if ListTransformer.is_batchable(python_type): + from flytekit.types.pickle.pickle import BatchSize, FlytePickle + + batchSize = len(python_val) # default batch size + # parse annotated to get the number of items saved in a pickle file. + if get_origin(python_type) is Annotated: + for annotation in get_args(python_type)[1:]: + if isinstance(annotation, BatchSize): + batchSize = annotation.val + break + lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore + else: + t = self.get_sub_type(python_type) + lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore @@ -981,9 +1009,18 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: lits = lv.collection.literals except AttributeError: raise TypeTransformerFailedError() - - st = self.get_sub_type(expected_python_type) - return [TypeEngine.to_python_value(ctx, x, st) for x in lits] + if self.is_batchable(expected_python_type): + from flytekit.types.pickle import FlytePickle + + batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] + if len(batch_list) > 0 and type(batch_list[0]) is list: + # Make it have backward compatibility. The upstream task may use old version of Flytekit that + # won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first. + return [item for batch in batch_list for item in batch] + return batch_list + else: + st = self.get_sub_type(expected_python_type) + return [TypeEngine.to_python_value(ctx, x, st) for x in lits] def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: @@ -1044,7 +1081,7 @@ def _are_types_castable(upstream: LiteralType, downstream: LiteralType) -> bool: if len(ucols) != len(dcols): return False - for (u, d) in zip(ucols, dcols): + for u, d in zip(ucols, dcols): if u.name != d.name: return False diff --git a/flytekit/types/pickle/__init__.py b/flytekit/types/pickle/__init__.py index 65604e67bb..e5bd1c056d 100644 --- a/flytekit/types/pickle/__init__.py +++ b/flytekit/types/pickle/__init__.py @@ -9,4 +9,4 @@ FlytePickle """ -from .pickle import FlytePickle +from .pickle import BatchSize, FlytePickle diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 3472dec7e6..3de75b765b 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -13,6 +13,19 @@ T = typing.TypeVar("T") +class BatchSize: + """ + Flyte-specific object used to wrap the hash function for a specific type + """ + + def __init__(self, val: int): + self._val = val + + @property + def val(self) -> int: + return self._val + + class FlytePickle(typing.Generic[T]): """ This type is only used by flytekit internally. User should not use this type. diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index 7ceec809b1..c45e200f95 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -14,7 +14,7 @@ from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType from flytekit.tools.translator import get_serializable -from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickle, FlytePickleTransformer default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -55,6 +55,11 @@ def test_get_literal_type(): ) +def test_batch_size(): + bs = BatchSize(5) + assert bs.val == 5 + + def test_nested(): class Foo(object): def __init__(self, number: int): diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index d8b043116e..88f85c9153 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -3,6 +3,7 @@ import pytest from dataclasses_json import dataclass_json +from typing_extensions import Annotated from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager @@ -14,6 +15,8 @@ translate_inputs_to_literals, ) from flytekit.exceptions.user import FlyteAssertion +from flytekit.types.pickle import FlytePickle +from flytekit.types.pickle.pickle import BatchSize def test_create_and_link_node(): @@ -92,7 +95,7 @@ def wf(i: int, j: int): @pytest.mark.parametrize( "input", - [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], + [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], ["foo"] * 5], ) def test_translate_inputs_to_literals(input): @dataclass_json @@ -102,7 +105,7 @@ class MyDataclass(object): a: typing.List[str] @task - def t1(a: typing.Union[float, typing.List[int], MyDataclass]): + def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], BatchSize(2)]]): print(a) ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 5570231b95..d70174c77a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -18,7 +18,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -51,7 +51,7 @@ from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.pickle import FlytePickle -from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer from flytekit.types.schema import FlyteSchema from flytekit.types.schema.types_pandas import PandasDataFrameTransformer from flytekit.types.structured.structured_dataset import StructuredDataset @@ -1574,3 +1574,65 @@ def test_file_ext_with_flyte_file_wrong_type(): with pytest.raises(ValueError) as e: FlyteFile[WRONG_TYPE] assert str(e.value) == "Underlying type of File Extension must be of type " + + +def test_is_batchable(): + assert ListTransformer.is_batchable(typing.List[int]) is False + assert ListTransformer.is_batchable(typing.List[str]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict[str, FlytePickle]]) is False + assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False + + assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True + assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True + assert ( + ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)]) + is True + ) + + +@pytest.mark.parametrize( + "python_val, python_type, expected_list_length", + [ + # Case 1: List of FlytePickle objects with default batch size. + # (By default, the batch_size is set to the length of the whole list.) + # After converting to literal, the result will be [batched_FlytePickle(5 items)]. + # Therefore, the expected list length is [1]. + ([{"foo"}] * 5, typing.List[FlytePickle], [1]), + # Case 2: List of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. + # Therefore, the expected list length is [3]. + (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]), + # Case 3: Nested list of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] + # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). + ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]), + ], +) +def test_batch_pickle_list(python_val, python_type, expected_list_length): + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(python_type) + lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + tmp_lv = lv + for length in expected_list_length: + # Check that after converting to literal, the length of the literal list is equal to: + # - the length of the original list divided by the batch size if not nested + # - the length of the original list if it contains a nested list + assert len(tmp_lv.collection.literals) == length + tmp_lv = tmp_lv.collection.literals[0] + + pv = TypeEngine.to_python_value(ctx, lv, python_type) + # Check that after converting literal to Python value, the result is equal to the original python values. + assert pv == python_val + if get_origin(python_type) is Annotated: + pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) + # Remove the annotation and check that after converting to Python value, the result is equal + # to the original input values. This is used to simulate the following case: + # @workflow + # def wf(): + # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] + # task1(data=data) # task1(data: typing.List[FlytePickle]) + assert pv == python_val