Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using union type in subworkflow's input #1626

Merged
merged 3 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flytekit.core.node import Node
from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import literals as _literals_models
Expand Down Expand Up @@ -618,10 +619,18 @@ def binding_data_from_python_std(
f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
)

elif isinstance(t_value, list):
if expected_literal_type.collection_type is None:
raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}")
elif expected_literal_type.union_type is not None:
for i in range(len(expected_literal_type.union_type.variants)):
try:
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
return binding_data_from_python_std(ctx, lt_type, t_value, python_type)
except Exception:
logger.debug(
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
)

elif isinstance(t_value, list):
sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None
collection = _literals_models.BindingDataCollection(
bindings=[
Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,8 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"""
Converts a python value of a given type and expected ``LiteralType`` into a resolved ``Literal`` value.
"""
if isinstance(python_val, Literal):
return python_val
if python_val is None and expected.union_type is None:
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
Expand Down
32 changes: 2 additions & 30 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions import scopes as exception_scopes
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.loggers import logger
Expand Down Expand Up @@ -275,36 +275,8 @@ def compile(self, **kwargs):
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.

for k, v in kwargs.items():
if (type(v) == list and isinstance(v[0], Promise)) or (
type(v) == dict and (isinstance(list(v.keys())[0], Promise) or isinstance(list(v.values())[0], Promise))
):
# This is a special case where the user is passing in a list or dict of promises. We don't need to
# do anything here. It happens when a user is passing in a list or dict of outputs to a sub-workflow.
continue
if not isinstance(v, Promise):
t = self.python_interface.inputs[k]
try:
kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type))
except TypeTransformerFailedError as exc:
raise TypeError(
f"Failed to convert input argument '{k}' of workflow '{self.name}':\n {exc}"
) from exc

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
self.compile()
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
# A workflow function may return a task that doesn't return anything
# def wf():
# return t1()
# or it may not return at all
# def wf():
# t1()
# In the former case we get the task's VoidPromise, in the latter we get None
if isinstance(function_outputs, VoidPromise) or function_outputs is None:
if len(self.python_interface.outputs) != 0:
raise FlyteValueException(
Expand All @@ -313,7 +285,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
)
return VoidPromise(self.name)

# Because we should've already returned in the above check, we just raise an error here.
# Because we should've already returned in the above check, we just raise an error here.
if len(self.python_interface.outputs) == 0:
raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.")

Expand Down
88 changes: 68 additions & 20 deletions tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,39 +136,87 @@ def wf(b: int) -> nt:
assert x == (7, 7)


def test_sub_wf_multi_named_tuple_list():
nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)])

def test_sub_wf_varying_types():
@task
def t1(a: int) -> nt:
a = a + 2
return nt(a, a)

@task
def t1l(a: typing.List[int]) -> nt:
return nt(len(a), 5)
def t1l(
a: typing.List[typing.Dict[str, typing.List[int]]],
b: typing.Dict[str, typing.List[int]],
c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int],
d: int,
) -> str:
xx = ",".join([f"{k}:{v}" for d in a for k, v in d.items()])
yy = ",".join([f"{k}: {i}" for k, v in b.items() for i in v])
if isinstance(c, list):
zz = ",".join([f"{k}:{v}" for d in c for k, v in d.items()])
elif isinstance(c, dict):
zz = ",".join([f"{k}: {i}" for k, v in c.items() for i in v])
else:
zz = str(c)
return f"First: {xx} Second: {yy} Third: {zz} Int: {d}"

@task
def get_int() -> int:
return 1

@workflow
def subwf(a: typing.List[int]) -> nt:
return t1l(a=a)
def subwf(
a: typing.List[typing.Dict[str, typing.List[int]]],
b: typing.Dict[str, typing.List[int]],
c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]]],
d: int,
) -> str:
return t1l(a=a, b=b, c=c, d=d)

@workflow
def subwf2(b: typing.Dict[str, int]):
print(b)
def wf() -> str:
ds = [
{"first_map_a": [42], "first_map_b": [get_int(), 2]},
{
"second_map_c": [33],
"second_map_d": [9, 99],
},
]
ll = {
"ll_1": [get_int(), get_int(), get_int()],
"ll_2": [4, 5, 6],
}
out = subwf(a=ds, b=ll, c=ds, d=get_int())
return out

wf.compile()
x = wf()
expected = (
"First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
"Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
"Third: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
"Int: 1"
)
assert x == expected

@workflow
def wf() -> nt:
ll = [get_int(), get_int()]
out = subwf(a=ll)
subwf2(b={"b": get_int()})
return t1(a=out.named1)
def wf() -> str:
ds = [
{"first_map_a": [42], "first_map_b": [get_int(), 2]},
{
"second_map_c": [33],
"second_map_d": [9, 99],
},
]
ll = {
"ll_1": [get_int(), get_int(), get_int()],
"ll_2": [4, 5, 6],
}
out = subwf(a=ds, b=ll, c=ll, d=get_int())
return out

x = wf()
print(x)
expected = (
"First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
"Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
"Third: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
"Int: 1"
)
assert x == expected


def test_unexpected_outputs():
Expand Down