Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core feature] Convert List[Any] to a single pickle file #1535

Merged
2 changes: 2 additions & 0 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def extract_value(
if len(input_val) == 0:
raise
sub_type = type(input_val[0])
if ListTransformer.is_batchable(python_type):
Yicheng-Lu-llll marked this conversation as resolved.
Show resolved Hide resolved
return TypeEngine.to_literal(ctx, input_val, python_type, lt)
literal_list = [extract_value(ctx, v, sub_type, lt.collection_type) for v in input_val]
return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=literal_list))
elif isinstance(input_val, dict):
Expand Down
45 changes: 39 additions & 6 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,22 +968,55 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
except Exception as e:
raise ValueError(f"Type of Generic List type is not supported, {e}")

@staticmethod
def is_batchable(t: Type):
Yicheng-Lu-llll marked this conversation as resolved.
Show resolved Hide resolved
"""
This function determines whether a given list is batchable or not.
A batchable list consists only FlytePickle objects.
"""
from flytekit.types.pickle import FlytePickle

if get_origin(t) is Annotated:
return ListTransformer.is_batchable(get_args(t)[0])
if get_origin(t) is list:
subtype = get_args(t)[0]
if subtype == FlytePickle or (hasattr(subtype, "__origin__") and subtype.__origin__ == FlytePickle):
return True
return False

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
from flytekit.types.pickle import FlytePickle
Yicheng-Lu-llll marked this conversation as resolved.
Show resolved Hide resolved

if type(python_val) != list:
raise TypeTransformerFailedError("Expected a list")

t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
if ListTransformer.is_batchable(python_type):
batchSize = len(python_val) # default batch size
Yicheng-Lu-llll marked this conversation as resolved.
Show resolved Hide resolved
# parse annotated to get the number of items saved in a pickle file.
if get_origin(python_type) is Annotated:
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, int):
batchSize = annotation
break
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore
else:
t = self.get_sub_type(python_type)
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore
return Literal(collection=LiteralCollection(literals=lit_list))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore
from flytekit.types.pickle import FlytePickle
Yicheng-Lu-llll marked this conversation as resolved.
Show resolved Hide resolved

try:
lits = lv.collection.literals
except AttributeError:
raise TypeTransformerFailedError()

st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lits]
if self.is_batchable(expected_python_type):
batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits]
return [item for batch in batch_list for item in batch]
else:
st = self.get_sub_type(expected_python_type)
return [TypeEngine.to_python_value(ctx, x, st) for x in lits]

def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore
if literal_type.collection_type:
Expand Down Expand Up @@ -1044,7 +1077,7 @@ def _are_types_castable(upstream: LiteralType, downstream: LiteralType) -> bool:
if len(ucols) != len(dcols):
return False

for (u, d) in zip(ucols, dcols):
for u, d in zip(ucols, dcols):
if u.name != d.name:
return False

Expand Down
6 changes: 4 additions & 2 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from dataclasses_json import dataclass_json
from typing_extensions import Annotated

from flytekit import LaunchPlan, task, workflow
from flytekit.core import context_manager
Expand All @@ -14,6 +15,7 @@
translate_inputs_to_literals,
)
from flytekit.exceptions.user import FlyteAssertion
from flytekit.types.pickle import FlytePickle


def test_create_and_link_node():
Expand Down Expand Up @@ -92,7 +94,7 @@ def wf(i: int, j: int):

@pytest.mark.parametrize(
"input",
[2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3]],
[2.0, {"i": 1, "a": ["h", "e"]}, [1, 2, 3], ["foo"] * 5],
)
def test_translate_inputs_to_literals(input):
@dataclass_json
Expand All @@ -102,7 +104,7 @@ class MyDataclass(object):
a: typing.List[str]

@task
def t1(a: typing.Union[float, typing.List[int], MyDataclass]):
def t1(a: typing.Union[float, typing.List[int], MyDataclass, Annotated[typing.List[FlytePickle], 2]]):
print(a)

ctx = context_manager.FlyteContext.current_context()
Expand Down
49 changes: 48 additions & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from pandas._testing import assert_frame_equal
from typing_extensions import Annotated
from typing_extensions import Annotated, get_args, get_origin

from flytekit import kwtypes
from flytekit.core.annotation import FlyteAnnotation
Expand Down Expand Up @@ -1574,3 +1574,50 @@ def test_file_ext_with_flyte_file_wrong_type():
with pytest.raises(ValueError) as e:
FlyteFile[WRONG_TYPE]
assert str(e.value) == "Underlying type of File Extension must be of type <str>"


@pytest.mark.parametrize(
"python_val, python_type, expected_list_length",
[
# Case 1: List of FlytePickle objects with default batch size.
# (By default, the batch_size is set to the length of the whole list.)
# After converting to literal, the result will be [batched_FlytePickle(5 items)].
# Therefore, the expected list length is [1].
([{"foo"}] * 5, typing.List[FlytePickle], [1]),
# Case 2: List of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)].
# Therefore, the expected list length is [3].
(["foo"] * 5, Annotated[typing.List[FlytePickle], HashMethod(function=str), 2], [3]),
# Case 3: Nested list of FlytePickle objects with batch size 2.
# After converting to literal, the result will be
# [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]]
# Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched).
([["foo", "foo", "foo"]] * 2, typing.List[Annotated[typing.List[FlytePickle], 3]], [2, 1]),
],
)
def test_batch_pickle_list(python_val, python_type, expected_list_length):
ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(python_type)
lv = TypeEngine.to_literal(ctx, python_val, python_type, expected)

tmp_lv = lv
for length in expected_list_length:
# Check that after converting to literal, the length of the literal list is equal to:
# - the length of the original list divided by the batch size if not nested
# - the length of the original list if it contains a nested list
assert len(tmp_lv.collection.literals) == length
tmp_lv = tmp_lv.collection.literals[0]

pv = TypeEngine.to_python_value(ctx, lv, python_type)
# Check that after converting literal to Python value, the result is equal to the original python values.
assert pv == python_val
if get_origin(python_type) is Annotated:
pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0])
# Remove the annotation and check that after converting to Python value, the result is equal
# to the original input values. This is used to simulate the following case:
# @workflow
# def wf():
# data = task0() # task0() -> Annotated[typing.List[FlytePickle], 2]
# task1(data=data) # task1(data: typing.List[FlytePickle])
assert pv == python_val