Skip to content

Commit

Permalink
Add BatchSize
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Apr 3, 2023
1 parent 8a21673 commit 0977021
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 12 deletions.
12 changes: 8 additions & 4 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,14 +989,14 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
raise TypeTransformerFailedError("Expected a list")

if ListTransformer.is_batchable(python_type):
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import BatchSize, FlytePickle

batchSize = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if get_origin(python_type) is Annotated:
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, int):
batchSize = annotation
if isinstance(annotation, BatchSize):
batchSize = annotation.val
break
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore
else:
Expand All @@ -1013,7 +1013,11 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
from flytekit.types.pickle import FlytePickle

batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits]
return [item for batch in batch_list for item in batch]
if len(batch_list) > 0 and type(batch_list[0]) is list:
# Make it have backward compatibility. The upstream task may use old version of Flytekit that
# won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first.
return [item for batch in batch_list for item in batch]
return batch_list
else:
st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lits]
Expand Down
2 changes: 1 addition & 1 deletion flytekit/types/pickle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
FlytePickle
"""

from .pickle import FlytePickle
from .pickle import BatchSize, FlytePickle
13 changes: 13 additions & 0 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@
T = typing.TypeVar("T")


class BatchSize:
"""
Flyte-specific object used to wrap the hash function for a specific type
"""

def __init__(self, val: int):
self._val = val

@property
def val(self) -> int:
return self._val


class FlytePickle(typing.Generic[T]):
"""
This type is only used by flytekit internally. User should not use this type.
Expand Down
3 changes: 2 additions & 1 deletion tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from flytekit.exceptions.user import FlyteAssertion
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import BatchSize


def test_create_and_link_node():
Expand Down Expand Up @@ -104,7 +105,7 @@ class MyDataclass(object):
a: typing.List[str]

@task
def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]]):
def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], BatchSize(2)]]):
print(a)

ctx = context_manager.FlyteContext.current_context()
Expand Down
15 changes: 9 additions & 6 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from flytekit.types.file import FileExt, JPEGImageFile
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer, noop
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer
from flytekit.types.schema import FlyteSchema
from flytekit.types.schema.types_pandas import PandasDataFrameTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset
Expand Down Expand Up @@ -1584,8 +1584,11 @@ def test_is_batchable():
assert ListTransformer.is_batchable(typing.List[typing.List[FlytePickle]]) is False

assert ListTransformer.is_batchable(typing.List[FlytePickle]) is True
assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], 3]) is True
assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), 3]) is True
assert ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], BatchSize(3)]) is True
assert (
ListTransformer.is_batchable(Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(3)])
is True
)


@pytest.mark.parametrize(
Expand All @@ -1600,12 +1603,12 @@ def test_is_batchable():
# After converting to literal, the result will be
# [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)].
# Therefore, the expected list length is [3].
(["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], [3]),
(["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], [3]),
# Case 3: Nested list of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]]
# Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched).
([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], [2, 1]),
([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], [2, 1]),
],
)
def test_batch_pickle_list(python_val, python_type, expected_list_length):
Expand All @@ -1630,6 +1633,6 @@ def test_batch_pickle_list(python_val, python_type, expected_list_length):
# to the original input values. This is used to simulate the following case:
# @workflow
# def wf():
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], 2]
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)]
# task1(data=data) # task1(data: typing.List[FlytePickle])
assert pv == python_val

0 comments on commit 0977021

Please sign in to comment.