From 06b7b174a065d38bb844e6f995635cdbeccb81fc Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Mon, 6 Mar 2023 02:08:21 +0000 Subject: [PATCH 01/17] Convert List[Any] to a single pickle file --- flytekit/core/type_engine.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 306c4116ad..535598dc62 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -968,12 +968,34 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") + def getFlytePickle(self): + for _, transformer in TypeEngine._REGISTRY.items(): + if transformer.name == "FlytePickle": + return transformer + return None + + def conatinFlytePickle(self, t, FlytePickle): + if hasattr(t, "__origin__") and t.__origin__ == FlytePickle: + return True + elif get_origin(t) is not None: + return any(map(lambda x: self.conatinFlytePickle(x, FlytePickle), get_args(t))) + else: + return False + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") t = self.get_sub_type(python_type) - lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore + flytePickleTransformer = self.getFlytePickle() + flytePickle = flytePickleTransformer.python_type + batchSize = 1 # default batch size + if get_origin(python_type) is Annotated: + batchSize = get_args(python_type)[1] + if flytePickleTransformer is not None and self.conatinFlytePickle(python_type, flytePickle): + lit_list = [flytePickleTransformer.to_literal(ctx, python_val[i : i + batchSize], t, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore + else: + lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore @@ -981,9 +1003,14 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: lits = lv.collection.literals except AttributeError: raise TypeTransformerFailedError() - st = self.get_sub_type(expected_python_type) - return [TypeEngine.to_python_value(ctx, x, st) for x in lits] + flytePickleTransformer = self.getFlytePickle() + flytePickle = flytePickleTransformer.python_type + if flytePickleTransformer is not None and self.conatinFlytePickle(expected_python_type, flytePickle): + batchList = [flytePickleTransformer.to_python_value(ctx, batch, st) for batch in lits] + return [item for batch in batchList for item in batch] + else: + return [TypeEngine.to_python_value(ctx, x, st) for x in lits] def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: From 34936fa0505042cb0b7048053aabfa32c7d562c2 Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Mon, 6 Mar 2023 03:04:54 +0000 Subject: [PATCH 02/17] remove redundant code --- flytekit/core/type_engine.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 535598dc62..2f13c982d0 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -971,10 +971,10 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: def getFlytePickle(self): for _, transformer in TypeEngine._REGISTRY.items(): if transformer.name == "FlytePickle": - return transformer + return transformer.python_type return None - def conatinFlytePickle(self, t, FlytePickle): + def conatinFlytePickle(self, t: Type, FlytePickle: Type[T]): if hasattr(t, "__origin__") and t.__origin__ == FlytePickle: return True elif get_origin(t) is not None: @@ -987,13 +987,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise TypeTransformerFailedError("Expected a list") t = self.get_sub_type(python_type) - flytePickleTransformer = self.getFlytePickle() - flytePickle = flytePickleTransformer.python_type + flytePickle = self.getFlytePickle() batchSize = 1 # default batch size if get_origin(python_type) is Annotated: batchSize = get_args(python_type)[1] - if flytePickleTransformer is not None and self.conatinFlytePickle(python_type, flytePickle): - lit_list = [flytePickleTransformer.to_literal(ctx, python_val[i : i + batchSize], t, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore + if flytePickle is not None and self.conatinFlytePickle(python_type, flytePickle): + lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], flytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) @@ -1004,10 +1003,9 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: except AttributeError: raise TypeTransformerFailedError() st = self.get_sub_type(expected_python_type) - flytePickleTransformer = self.getFlytePickle() - flytePickle = flytePickleTransformer.python_type - if flytePickleTransformer is not None and self.conatinFlytePickle(expected_python_type, flytePickle): - batchList = [flytePickleTransformer.to_python_value(ctx, batch, st) for batch in lits] + flytePickle = self.getFlytePickle() + if flytePickle is not None and self.conatinFlytePickle(expected_python_type, flytePickle): + batchList = [TypeEngine.to_python_value(ctx, batch, flytePickle) for batch in lits] return [item for batch in batchList for item in batch] else: return [TypeEngine.to_python_value(ctx, x, st) for x in lits] From bc1202b52555ca28b362a0f26d7e131cdeb4ed5a Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Mon, 6 Mar 2023 03:17:42 +0000 Subject: [PATCH 03/17] keep batchSize only if type contain flytePickle --- flytekit/core/type_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 2f13c982d0..98a988e74c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -988,10 +988,10 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp t = self.get_sub_type(python_type) flytePickle = self.getFlytePickle() - batchSize = 1 # default batch size - if get_origin(python_type) is Annotated: - batchSize = get_args(python_type)[1] if flytePickle is not None and self.conatinFlytePickle(python_type, flytePickle): + batchSize = 1 # default batch size + if get_origin(python_type) is Annotated: + batchSize = get_args(python_type)[1] lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], flytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore From fbe9102c526f9a2d0dab51d69a4771aed3fb5686 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 6 Mar 2023 20:39:25 +0000 Subject: [PATCH 04/17] fix error --- flytekit/core/type_engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 98a988e74c..3d01903315 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -989,7 +989,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp t = self.get_sub_type(python_type) flytePickle = self.getFlytePickle() if flytePickle is not None and self.conatinFlytePickle(python_type, flytePickle): - batchSize = 1 # default batch size + batchSize = len(python_val) # default batch size if get_origin(python_type) is Annotated: batchSize = get_args(python_type)[1] lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], flytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore @@ -1006,7 +1006,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: flytePickle = self.getFlytePickle() if flytePickle is not None and self.conatinFlytePickle(expected_python_type, flytePickle): batchList = [TypeEngine.to_python_value(ctx, batch, flytePickle) for batch in lits] - return [item for batch in batchList for item in batch] + # TODO: to_literal and to_python_value is not symmetric + return [item for batch in batchList for item in batch] if type(batchList[0]) == list else batchList else: return [TypeEngine.to_python_value(ctx, x, st) for x in lits] From 6bb4fc5a4af108ce958ee6b8ed9e4fd872ed178d Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Tue, 7 Mar 2023 05:32:20 +0000 Subject: [PATCH 05/17] add batch support to translate_inputs_to_literals --- flytekit/core/promise.py | 3 +++ flytekit/core/type_engine.py | 27 +++++++++++---------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 8be9d8ccae..6d9273cd24 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -29,6 +29,7 @@ from flytekit.models.core import workflow as _workflow_model from flytekit.models.literals import Primitive from flytekit.models.types import SimpleType +from flytekit.types.pickle import FlytePickle def translate_inputs_to_literals( @@ -92,6 +93,8 @@ def extract_value( if len(input_val) == 0: raise sub_type = type(input_val[0]) + if ListTransformer.isBatchable(python_type, FlytePickle): + return TypeEngine.to_literal(ctx, input_val, python_type, lt) literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list)) elif isinstance(input_val, dict): diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3d01903315..55b3050cca 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -968,46 +968,41 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") - def getFlytePickle(self): - for _, transformer in TypeEngine._REGISTRY.items(): - if transformer.name == "FlytePickle": - return transformer.python_type - return None - - def conatinFlytePickle(self, t: Type, FlytePickle: Type[T]): + @staticmethod + def isBatchable(t: Type, FlytePickle: Type[T]): + print("hi!!!!!!!!!!!!!!",) if hasattr(t, "__origin__") and t.__origin__ == FlytePickle: return True elif get_origin(t) is not None: - return any(map(lambda x: self.conatinFlytePickle(x, FlytePickle), get_args(t))) + return any(map(lambda x: ListTransformer.isBatchable(x, FlytePickle), get_args(t))) else: return False def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + from flytekit.types.pickle import FlytePickle if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") t = self.get_sub_type(python_type) - flytePickle = self.getFlytePickle() - if flytePickle is not None and self.conatinFlytePickle(python_type, flytePickle): + if self.isBatchable(python_type, FlytePickle): batchSize = len(python_val) # default batch size if get_origin(python_type) is Annotated: batchSize = get_args(python_type)[1] - lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], flytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore + lit_list = [TypeEngine.to_literal(ctx, python_val[i: i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore + from flytekit.types.pickle import FlytePickle try: lits = lv.collection.literals except AttributeError: raise TypeTransformerFailedError() st = self.get_sub_type(expected_python_type) - flytePickle = self.getFlytePickle() - if flytePickle is not None and self.conatinFlytePickle(expected_python_type, flytePickle): - batchList = [TypeEngine.to_python_value(ctx, batch, flytePickle) for batch in lits] - # TODO: to_literal and to_python_value is not symmetric - return [item for batch in batchList for item in batch] if type(batchList[0]) == list else batchList + if self.isBatchable(expected_python_type, FlytePickle): + batchList = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] + return [item for batch in batchList for item in batch] else: return [TypeEngine.to_python_value(ctx, x, st) for x in lits] From 82d1f8539d13cbbe3a1b969456b8b93f4c1d5fe1 Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Thu, 9 Mar 2023 02:23:47 +0000 Subject: [PATCH 06/17] add ci test Signed-off-by: Yicheng-Lu-llll --- flytekit/core/promise.py | 3 +- flytekit/core/type_engine.py | 33 +++++++++++--------- tests/flytekit/unit/core/test_type_engine.py | 28 ++++++++++++++++- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 6d9273cd24..eaccd6c1d0 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -29,7 +29,6 @@ from flytekit.models.core import workflow as _workflow_model from flytekit.models.literals import Primitive from flytekit.models.types import SimpleType -from flytekit.types.pickle import FlytePickle def translate_inputs_to_literals( @@ -93,7 +92,7 @@ def extract_value( if len(input_val) == 0: raise sub_type = type(input_val[0]) - if ListTransformer.isBatchable(python_type, FlytePickle): + if ListTransformer.is_batchable(python_type): return TypeEngine.to_literal(ctx, input_val, python_type, lt) literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list)) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 55b3050cca..3f109ee858 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -969,41 +969,44 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: raise ValueError(f"Type of Generic List type is not supported, {e}") @staticmethod - def isBatchable(t: Type, FlytePickle: Type[T]): - print("hi!!!!!!!!!!!!!!",) - if hasattr(t, "__origin__") and t.__origin__ == FlytePickle: + def is_batchable(t: Type): + from flytekit.types.pickle import FlytePickle + + if t == FlytePickle or (hasattr(t, "__origin__") and t.__origin__ == FlytePickle): return True - elif get_origin(t) is not None: - return any(map(lambda x: ListTransformer.isBatchable(x, FlytePickle), get_args(t))) - else: - return False + if get_origin(t) is not None: + return any(map(ListTransformer.is_batchable, get_args(t))) + return False def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: from flytekit.types.pickle import FlytePickle + if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") - t = self.get_sub_type(python_type) - if self.isBatchable(python_type, FlytePickle): + if self.is_batchable(python_type): batchSize = len(python_val) # default batch size + # parse annotated to get the number of items saved in a pickle file. if get_origin(python_type) is Annotated: batchSize = get_args(python_type)[1] - lit_list = [TypeEngine.to_literal(ctx, python_val[i: i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore + lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: + t = self.get_sub_type(python_type) lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore from flytekit.types.pickle import FlytePickle + try: lits = lv.collection.literals except AttributeError: raise TypeTransformerFailedError() - st = self.get_sub_type(expected_python_type) - if self.isBatchable(expected_python_type, FlytePickle): - batchList = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] - return [item for batch in batchList for item in batch] + if self.is_batchable(expected_python_type): + batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] + return [item for batch in batch_list for item in batch] else: + st = self.get_sub_type(expected_python_type) return [TypeEngine.to_python_value(ctx, x, st) for x in lits] def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore @@ -1065,7 +1068,7 @@ def _are_types_castable(upstream: LiteralType, downstream: LiteralType) -> bool: if len(ucols) != len(dcols): return False - for (u, d) in zip(ucols, dcols): + for u, d in zip(ucols, dcols): if u.name != d.name: return False diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 842ae7a98c..25e3ac36a7 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -18,7 +18,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -1574,3 +1574,29 @@ def test_file_ext_with_flyte_file_wrong_type(): with pytest.raises(ValueError) as e: FlyteFile[WRONG_TYPE] assert str(e.value) == "Underlying type of File Extension must be of type " + + +def test_batch_pickle_list(): + from math import ceil + + python_val = [{"a": {0: "foo"}}] * 5 + python_type_list = [ + typing.List[typing.Dict[str, FlytePickle]], + Annotated[typing.List[typing.Dict[str, FlytePickle]], 2], + ] + + for python_type in python_type_list: + batch_size = len(python_val) + if get_origin(python_type) is Annotated: + batch_size = get_args(python_type)[1] + + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(python_type) + + lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) + # For example, if batch_size = 2, then the list should be split into ceil(5/3) = 3 chunks. + # By default, the batch_size is set to the length of the list. + assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) + + pv = TypeEngine.to_python_value(ctx, lv, python_type) + assert pv == python_val From 59b512b3c56ac665b08132fc6f56e793f652c8dd Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Thu, 9 Mar 2023 02:49:44 +0000 Subject: [PATCH 07/17] improve comment Signed-off-by: Yicheng-Lu-llll --- tests/flytekit/unit/core/test_type_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 25e3ac36a7..d87493eb05 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1594,8 +1594,8 @@ def test_batch_pickle_list(): expected = TypeEngine.to_literal_type(python_type) lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) - # For example, if batch_size = 2, then the list should be split into ceil(5/3) = 3 chunks. - # By default, the batch_size is set to the length of the list. + # For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/3) = 3 chunks. + # By default, the batch_size is set to the length of the whole list. assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) pv = TypeEngine.to_python_value(ctx, lv, python_type) From 3e5d5d9f2f7514a6a01e6052e964089012793b0e Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Thu, 9 Mar 2023 03:20:03 +0000 Subject: [PATCH 08/17] add more ci test Signed-off-by: Yicheng-Lu-llll --- tests/flytekit/unit/core/test_promise.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index d8b043116e..7dd1acc742 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -3,6 +3,7 @@ import pytest from dataclasses_json import dataclass_json +from typing_extensions import Annotated from flytekit import LaunchPlan, task, workflow from flytekit.core import context_manager @@ -14,6 +15,7 @@ translate_inputs_to_literals, ) from flytekit.exceptions.user import FlyteAssertion +from flytekit.types.pickle import FlytePickle def test_create_and_link_node(): @@ -92,7 +94,7 @@ def wf(i: int, j: int): @pytest.mark.parametrize( "input", - [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]], + [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], [{"a": {0: "foo"}}] * 5], ) def test_translate_inputs_to_literals(input): @dataclass_json @@ -102,7 +104,9 @@ class MyDataclass(object): a: typing.List[str] @task - def t1(a: typing.Union[float, typing.List[int], MyDataclass]): + def t1( + a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2]] + ): print(a) ctx = context_manager.FlyteContext.current_context() From a6ec8e69b7de95f10f9e08f363c5a01119fc022a Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Sat, 11 Mar 2023 07:15:45 +0000 Subject: [PATCH 09/17] improve Signed-off-by: Yicheng-Lu-llll --- flytekit/core/type_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3f109ee858..956477f854 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -984,10 +984,10 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") - if self.is_batchable(python_type): + if ListTransformer.is_batchable(python_type): batchSize = len(python_val) # default batch size # parse annotated to get the number of items saved in a pickle file. - if get_origin(python_type) is Annotated: + if get_origin(python_type) is Annotated and type(get_args(python_type)[1]) == int: batchSize = get_args(python_type)[1] lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: From 4b2c1a51811bdd06cfc1cf20d24274a6d9cfd474 Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Mon, 13 Mar 2023 01:54:40 +0000 Subject: [PATCH 10/17] handle HashMethod case Signed-off-by: Yicheng-Lu-llll --- flytekit/core/type_engine.py | 7 +++- tests/flytekit/unit/core/test_type_engine.py | 41 +++++++++----------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 956477f854..8a3e9195a8 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -987,8 +987,11 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp if ListTransformer.is_batchable(python_type): batchSize = len(python_val) # default batch size # parse annotated to get the number of items saved in a pickle file. - if get_origin(python_type) is Annotated and type(get_args(python_type)[1]) == int: - batchSize = get_args(python_type)[1] + if get_origin(python_type) is Annotated: + for annotation in get_args(python_type)[1:]: + if isinstance(annotation, int): + batchSize = annotation + break lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: t = self.get_sub_type(python_type) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d87493eb05..f04b61651d 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -18,7 +18,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -1576,27 +1576,22 @@ def test_file_ext_with_flyte_file_wrong_type(): assert str(e.value) == "Underlying type of File Extension must be of type " -def test_batch_pickle_list(): +@pytest.mark.parametrize( + "python_val, python_type, batch_size", + [ + ([{"a": {0: "foo"}}] * 5, typing.List[typing.Dict[str, FlytePickle]], 5), + ([{"a": {0: "foo"}}] * 5, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2], 2), + ([{"a": {0: "foo"}}] * 6, Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2], 2), + ], +) +def test_batch_pickle_list(python_val, python_type, batch_size): from math import ceil - python_val = [{"a": {0: "foo"}}] * 5 - python_type_list = [ - typing.List[typing.Dict[str, FlytePickle]], - Annotated[typing.List[typing.Dict[str, FlytePickle]], 2], - ] - - for python_type in python_type_list: - batch_size = len(python_val) - if get_origin(python_type) is Annotated: - batch_size = get_args(python_type)[1] - - ctx = FlyteContext.current_context() - expected = TypeEngine.to_literal_type(python_type) - - lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) - # For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/3) = 3 chunks. - # By default, the batch_size is set to the length of the whole list. - assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) - - pv = TypeEngine.to_python_value(ctx, lv, python_type) - assert pv == python_val + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(python_type) + lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) + # For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/3) = 3 chunks. + # By default, the batch_size is set to the length of the whole list. + assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) + pv = TypeEngine.to_python_value(ctx, lv, python_type) + assert pv == python_val From ecbb7480d64816d4a79882acb44a076b806b16cc Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Mon, 13 Mar 2023 04:38:22 +0000 Subject: [PATCH 11/17] improve format Signed-off-by: Yicheng-Lu-llll --- tests/flytekit/unit/core/test_type_engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index f04b61651d..90d2c5510c 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1580,8 +1580,11 @@ def test_file_ext_with_flyte_file_wrong_type(): "python_val, python_type, batch_size", [ ([{"a": {0: "foo"}}] * 5, typing.List[typing.Dict[str, FlytePickle]], 5), - ([{"a": {0: "foo"}}] * 5, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2], 2), - ([{"a": {0: "foo"}}] * 6, Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2], 2), + ( + [{"a": {0: "foo"}}] * 5, + Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2], + 2, + ), ], ) def test_batch_pickle_list(python_val, python_type, batch_size): From 0403c48307aa6ab764e1487f1e4b15f02d82d56f Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Sat, 25 Mar 2023 22:29:23 +0000 Subject: [PATCH 12/17] improve Signed-off-by: Yicheng-Lu-llll --- flytekit/core/type_engine.py | 14 ++++-- tests/flytekit/unit/core/test_promise.py | 4 +- tests/flytekit/unit/core/test_type_engine.py | 53 ++++++++++++++------ 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8a3e9195a8..eacf22c7ce 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -970,12 +970,18 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: @staticmethod def is_batchable(t: Type): + """ + This function determines whether a given list is batchable or not. + A batchable list consists only FlytePickle objects. + """ from flytekit.types.pickle import FlytePickle - if t == FlytePickle or (hasattr(t, "__origin__") and t.__origin__ == FlytePickle): - return True - if get_origin(t) is not None: - return any(map(ListTransformer.is_batchable, get_args(t))) + if get_origin(t) is Annotated: + return ListTransformer.is_batchable(get_args(t)[0]) + if get_origin(t) is list: + subtype = get_args(t)[0] + if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle): + return True return False def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 7dd1acc742..25817b68c2 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -94,7 +94,7 @@ def wf(i: int, j: int): @pytest.mark.parametrize( "input", - [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], [{"a": {0: "foo"}}] * 5], + [2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], ["foo"] * 5], ) def test_translate_inputs_to_literals(input): @dataclass_json @@ -105,7 +105,7 @@ class MyDataclass(object): @task def t1( - a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2]] + a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]] ): print(a) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 90d2c5510c..ccd44a9d96 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -18,7 +18,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated +from typing_extensions import Annotated, get_args, get_origin from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -1577,24 +1577,49 @@ def test_file_ext_with_flyte_file_wrong_type(): @pytest.mark.parametrize( - "python_val, python_type, batch_size", + "python_val, python_type, batch_size, expected_list_length", [ - ([{"a": {0: "foo"}}] * 5, typing.List[typing.Dict[str, FlytePickle]], 5), - ( - [{"a": {0: "foo"}}] * 5, - Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2], - 2, - ), + # Case 1: List of FlytePickle objects with default batch size. + # (By default, the batch_size is set to the length of the whole list.) + # After converting to literal, the result will be [batched_FlytePickle(5 items)]. + # Therefore, the expected list length is [1]. + ([{"foo"}] * 5, typing.List[FlytePickle], 5, [1]), + # Case 2: List of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. + # Therefore, the expected list length is [3]. + (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], 2, [3]), + # Case 3: Nested list of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] + # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). + ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], 2, [2, 1]), ], ) -def test_batch_pickle_list(python_val, python_type, batch_size): - from math import ceil - +def test_batch_pickle_list(python_val, python_type, batch_size, expected_list_length): ctx = FlyteContext.current_context() expected = TypeEngine.to_literal_type(python_type) lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) - # For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/3) = 3 chunks. - # By default, the batch_size is set to the length of the whole list. - assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) + + tmp_lv = lv + tmp_python_val = python_val + for length in expected_list_length: + # Check that after converting to literal, the length of the literal list is equal to: + # - the length of the original list divided by the batch size if not nested + # - the length of the original list if it contains a nested list + assert len(tmp_lv.collection.literals) == length + tmp_lv = tmp_lv.collection.literals[0] + tmp_python_val = tmp_python_val[0] + pv = TypeEngine.to_python_value(ctx, lv, python_type) + # Check that after converting literal to Python value, the result is equal to the original python values. assert pv == python_val + if get_origin(python_type) is Annotated: + pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) + # Remove the annotation and check that after converting to Python value, the result is equal + # to the original input values. This is used to simulate the following case: + # @workflow + # def wf(): + # data = task0() # task0() -> Annotated[typing.List[FlytePickle], 2] + # task1(data=data) # task1(data: typing.List[FlytePickle]) + assert pv == python_val From f9761f5f77d9d82e05ff7e664e51b8e66267e78b Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Mon, 27 Mar 2023 04:37:28 +0000 Subject: [PATCH 13/17] improve Signed-off-by: Yicheng-Lu-llll --- tests/flytekit/unit/core/test_promise.py | 4 +--- tests/flytekit/unit/core/test_type_engine.py | 13 ++++++------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 25817b68c2..533ae35c48 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -104,9 +104,7 @@ class MyDataclass(object): a: typing.List[str] @task - def t1( - a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]] - ): + def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]]): print(a) ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index ccd44a9d96..4dc5fd6feb 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1577,39 +1577,38 @@ def test_file_ext_with_flyte_file_wrong_type(): @pytest.mark.parametrize( - "python_val, python_type, batch_size, expected_list_length", + "python_val, python_type, expected_list_length", [ # Case 1: List of FlytePickle objects with default batch size. # (By default, the batch_size is set to the length of the whole list.) # After converting to literal, the result will be [batched_FlytePickle(5 items)]. # Therefore, the expected list length is [1]. - ([{"foo"}] * 5, typing.List[FlytePickle], 5, [1]), + ([{"foo"}] * 5, typing.List[FlytePickle], [1]), # Case 2: List of FlytePickle objects with batch size 2. # After converting to literal, the result will be # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. # Therefore, the expected list length is [3]. - (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], 2, [3]), + (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], [3]), # Case 3: Nested list of FlytePickle objects with batch size 2. # After converting to literal, the result will be # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). - ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], 2, [2, 1]), + ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], [2, 1]), ], ) -def test_batch_pickle_list(python_val, python_type, batch_size, expected_list_length): +def test_batch_pickle_list(python_val, python_type, expected_list_length): ctx = FlyteContext.current_context() expected = TypeEngine.to_literal_type(python_type) lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) tmp_lv = lv - tmp_python_val = python_val + for length in expected_list_length: # Check that after converting to literal, the length of the literal list is equal to: # - the length of the original list divided by the batch size if not nested # - the length of the original list if it contains a nested list assert len(tmp_lv.collection.literals) == length tmp_lv = tmp_lv.collection.literals[0] - tmp_python_val = tmp_python_val[0] pv = TypeEngine.to_python_value(ctx, lv, python_type) # Check that after converting literal to Python value, the result is equal to the original python values. From 159141bdf61f144ff8c1ddd071af69e02ae3e78a Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Mon, 27 Mar 2023 04:42:27 +0000 Subject: [PATCH 14/17] improve Signed-off-by: Yicheng-Lu-llll --- tests/flytekit/unit/core/test_type_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 4dc5fd6feb..49d8d3c36d 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1602,7 +1602,6 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length): lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) tmp_lv = lv - for length in expected_list_length: # Check that after converting to literal, the length of the literal list is equal to: # - the length of the original list divided by the batch size if not nested From 8a216736e60a7ecac6a15dc3f682d8cfaaabe3fb Mon Sep 17 00:00:00 2001 From: Yicheng-Lu-llll Date: Thu, 30 Mar 2023 04:07:14 +0000 Subject: [PATCH 15/17] add test_is_batchable Signed-off-by: Yicheng-Lu-llll --- flytekit/core/promise.py | 4 ++++ flytekit/core/type_engine.py | 12 ++++++------ tests/flytekit/unit/core/test_type_engine.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index eaccd6c1d0..90f246deb9 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -92,6 +92,10 @@ def extract_value( if len(input_val) == 0: raise sub_type = type(input_val[0]) + # To maintain consistency between translate_inputs_to_literals and ListTransformer.to_literal for batchable types, + # directly call ListTransformer.to_literal to batch process the list items. This is necessary because processing + # each list item separately could lead to errors since ListTransformer.to_python_value may treat the literal + # as it is batched for batchable types. if ListTransformer.is_batchable(python_type): return TypeEngine.to_literal(ctx, input_val, python_type, lt) literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val] diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index eacf22c7ce..302ef0e712 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -971,8 +971,8 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: @staticmethod def is_batchable(t: Type): """ - This function determines whether a given list is batchable or not. - A batchable list consists only FlytePickle objects. + This function evaluates whether the provided type is batchable or not. + It returns True only if the type is either List or Annotated(List) and the List subtype is FlytePickle. """ from flytekit.types.pickle import FlytePickle @@ -985,12 +985,12 @@ def is_batchable(t: Type): return False def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: - from flytekit.types.pickle import FlytePickle - if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") if ListTransformer.is_batchable(python_type): + from flytekit.types.pickle import FlytePickle + batchSize = len(python_val) # default batch size # parse annotated to get the number of items saved in a pickle file. if get_origin(python_type) is Annotated: @@ -1005,13 +1005,13 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return Literal(collection=LiteralCollection(literals=lit_list)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore - from flytekit.types.pickle import FlytePickle - try: lits = lv.collection.literals except AttributeError: raise TypeTransformerFailedError() if self.is_batchable(expected_python_type): + from flytekit.types.pickle import FlytePickle + batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] return [item for batch in batch_list for item in batch] else: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 49d8d3c36d..214cc309c2 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1576,6 +1576,18 @@ def test_file_ext_with_flyte_file_wrong_type(): assert str(e.value) == "Underlying type of File Extension must be of type " +def test_is_batchable(): + assert ListTransformer.is_batchable(typing.List[int]) is False + assert ListTransformer.is_batchable(typing.List[str]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict]) is False + assert ListTransformer.is_batchable(typing.List[typing.Dict[str, FlytePickle]]) is False + assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False + + assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True + assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], 3]) is True + assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), 3]) is True + + @pytest.mark.parametrize( "python_val, python_type, expected_list_length", [ From 0977021f5d7f74fca0e349b28e9c4c39258c87c0 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 3 Apr 2023 15:42:46 -0700 Subject: [PATCH 16/17] Add BatchSize Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 12 ++++++++---- flytekit/types/pickle/__init__.py | 2 +- flytekit/types/pickle/pickle.py | 13 +++++++++++++ tests/flytekit/unit/core/test_promise.py | 3 ++- tests/flytekit/unit/core/test_type_engine.py | 15 +++++++++------ 5 files changed, 33 insertions(+), 12 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 302ef0e712..e5ffa6459c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -989,14 +989,14 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise TypeTransformerFailedError("Expected a list") if ListTransformer.is_batchable(python_type): - from flytekit.types.pickle import FlytePickle + from flytekit.types.pickle.pickle import BatchSize, FlytePickle batchSize = len(python_val) # default batch size # parse annotated to get the number of items saved in a pickle file. if get_origin(python_type) is Annotated: for annotation in get_args(python_type)[1:]: - if isinstance(annotation, int): - batchSize = annotation + if isinstance(annotation, BatchSize): + batchSize = annotation.val break lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: @@ -1013,7 +1013,11 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: from flytekit.types.pickle import FlytePickle batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] - return [item for batch in batch_list for item in batch] + if len(batch_list) > 0 and type(batch_list[0]) is list: + # Make it have backward compatibility. The upstream task may use old version of Flytekit that + # won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first. + return [item for batch in batch_list for item in batch] + return batch_list else: st = self.get_sub_type(expected_python_type) return [TypeEngine.to_python_value(ctx, x, st) for x in lits] diff --git a/flytekit/types/pickle/__init__.py b/flytekit/types/pickle/__init__.py index 65604e67bb..e5bd1c056d 100644 --- a/flytekit/types/pickle/__init__.py +++ b/flytekit/types/pickle/__init__.py @@ -9,4 +9,4 @@ FlytePickle """ -from .pickle import FlytePickle +from .pickle import BatchSize, FlytePickle diff --git a/flytekit/types/pickle/pickle.py b/flytekit/types/pickle/pickle.py index 3472dec7e6..3de75b765b 100644 --- a/flytekit/types/pickle/pickle.py +++ b/flytekit/types/pickle/pickle.py @@ -13,6 +13,19 @@ T = typing.TypeVar("T") +class BatchSize: + """ + Flyte-specific object used to wrap the hash function for a specific type + """ + + def __init__(self, val: int): + self._val = val + + @property + def val(self) -> int: + return self._val + + class FlytePickle(typing.Generic[T]): """ This type is only used by flytekit internally. User should not use this type. diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 533ae35c48..88f85c9153 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -16,6 +16,7 @@ ) from flytekit.exceptions.user import FlyteAssertion from flytekit.types.pickle import FlytePickle +from flytekit.types.pickle.pickle import BatchSize def test_create_and_link_node(): @@ -104,7 +105,7 @@ class MyDataclass(object): a: typing.List[str] @task - def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]]): + def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], BatchSize(2)]]): print(a) ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 214cc309c2..49997347d0 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -51,7 +51,7 @@ from flytekit.types.file import FileExt, JPEGImageFile from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop from flytekit.types.pickle import FlytePickle -from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer from flytekit.types.schema import FlyteSchema from flytekit.types.schema.types_pandas import PandasDataFrameTransformer from flytekit.types.structured.structured_dataset import StructuredDataset @@ -1584,8 +1584,11 @@ def test_is_batchable(): assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True - assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], 3]) is True - assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), 3]) is True + assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True + assert ( + ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)]) + is True + ) @pytest.mark.parametrize( @@ -1600,12 +1603,12 @@ def test_is_batchable(): # After converting to literal, the result will be # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. # Therefore, the expected list length is [3]. - (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], [3]), + (["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]), # Case 3: Nested list of FlytePickle objects with batch size 2. # After converting to literal, the result will be # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). - ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], [2, 1]), + ([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]), ], ) def test_batch_pickle_list(python_val, python_type, expected_list_length): @@ -1630,6 +1633,6 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length): # to the original input values. This is used to simulate the following case: # @workflow # def wf(): - # data = task0() # task0() -> Annotated[typing.List[FlytePickle], 2] + # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] # task1(data=data) # task1(data: typing.List[FlytePickle]) assert pv == python_val From ea7e1aa39f58c724142ad38f16ef212a3f307d6f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 3 Apr 2023 16:26:20 -0700 Subject: [PATCH 17/17] test_batch_size Signed-off-by: Kevin Su --- tests/flytekit/unit/core/test_flyte_pickle.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index 7ceec809b1..c45e200f95 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -14,7 +14,7 @@ from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType from flytekit.tools.translator import get_serializable -from flytekit.types.pickle.pickle import FlytePickle, FlytePickleTransformer +from flytekit.types.pickle.pickle import BatchSize, FlytePickle, FlytePickleTransformer default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -55,6 +55,11 @@ def test_get_literal_type(): ) +def test_batch_size(): + bs = BatchSize(5) + assert bs.val == 5 + + def test_nested(): class Foo(object): def __init__(self, number: int):