Skip to content

Commit

Permalink
improve
Browse files Browse the repository at this point in the history
Signed-off-by: Yicheng-Lu-llll <[email protected]>
  • Loading branch information
Yicheng-Lu-llll committed Mar 25, 2023
1 parent ecbb748 commit 0403c48
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 20 deletions.
14 changes: 10 additions & 4 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,12 +970,18 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:

@staticmethod
def is_batchable(t: Type):
"""
Determine whether the given list type is batchable.
A batchable list consists only of FlytePickle objects.
"""
from flytekit.types.pickle import FlytePickle

if t == FlytePickle or (hasattr(t, "__origin__") and t.__origin__ == FlytePickle):
return True
if get_origin(t) is not None:
return any(map(ListTransformer.is_batchable, get_args(t)))
if get_origin(t) is Annotated:
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle):
return True
return False

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def wf(i: int, j: int):

@pytest.mark.parametrize(
"input",
[2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], [{"a": {0: "foo"}}] * 5],
[2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], ["foo"] * 5],
)
def test_translate_inputs_to_literals(input):
@dataclass_json
Expand All @@ -105,7 +105,7 @@ class MyDataclass(object):

@task
def t1(
a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2]]
a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]]
):
print(a)

Expand Down
53 changes: 39 additions & 14 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from pandas._testing import assert_frame_equal
from typing_extensions import Annotated
from typing_extensions import Annotated, get_args, get_origin

from flytekit import kwtypes
from flytekit.core.annotation import FlyteAnnotation
Expand Down Expand Up @@ -1577,24 +1577,49 @@ def test_file_ext_with_flyte_file_wrong_type():


@pytest.mark.parametrize(
"python_val, python_type, batch_size",
"python_val, python_type, batch_size, expected_list_length",
[
([{"a": {0: "foo"}}] * 5, typing.List[typing.Dict[str, FlytePickle]], 5),
(
[{"a": {0: "foo"}}] * 5,
Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2],
2,
),
# Case 1: List of FlytePickle objects with default batch size.
# (By default, the batch_size is set to the length of the whole list.)
# After converting to literal, the result will be [batched_FlytePickle(5 items)].
# Therefore, the expected list length is [1].
([{"foo"}] * 5, typing.List[FlytePickle], 5, [1]),
# Case 2: List of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)].
# Therefore, the expected list length is [3].
(["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], 2, [3]),
# Case 3: Nested list of FlytePickle objects with batch size 3.
# After converting to literal, the result will be
# [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]]
# Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched).
([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], 2, [2, 1]),
],
)
def test_batch_pickle_list(python_val, python_type, batch_size):
from math import ceil

def test_batch_pickle_list(python_val, python_type, batch_size, expected_list_length):
ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(python_type)
lv = TypeEngine.to_literal(ctx, python_val, python_type, expected)
# For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/2) = 3 chunks.
# By default, the batch_size is set to the length of the whole list.
assert len(lv.collection.literals) == ceil(len(python_val) / batch_size)

tmp_lv = lv
tmp_python_val = python_val
for length in expected_list_length:
# Check that after converting to literal, the length of the literal list is equal to:
# - the length of the original list divided by the batch size, rounded up, if not nested
# - the length of the original list if it contains a nested list
assert len(tmp_lv.collection.literals) == length
tmp_lv = tmp_lv.collection.literals[0]
tmp_python_val = tmp_python_val[0]

pv = TypeEngine.to_python_value(ctx, lv, python_type)
# Check that after converting literal to Python value, the result is equal to the original python values.
assert pv == python_val
if get_origin(python_type) is Annotated:
pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0])
# Remove the annotation and check that after converting to Python value, the result is equal
# to the original input values. This is used to simulate the following case:
# @workflow
# def wf():
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], 2]
# task1(data=data) # task1(data: typing.List[FlytePickle])
assert pv == python_val

0 comments on commit 0403c48

Please sign in to comment.