Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using a list as an input for a subworkflow #1605

Merged
merged 11 commits into from
May 11, 2023
3 changes: 1 addition & 2 deletions Dockerfile.external-plugin-service
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

ARG VERSION
RUN pip install -U flytekit==$VERSION \
flytekitplugins-bigquery==$VERSION \
RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION

CMD pyflyte serve --port 8000
18 changes: 15 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flytekit.core.node import Node
from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import literals as _literals_models
Expand Down Expand Up @@ -618,10 +619,21 @@ def binding_data_from_python_std(
f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
)

elif isinstance(t_value, list):
if expected_literal_type.collection_type is None:
raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}")
elif expected_literal_type.union_type is not None:
for i in range(len(expected_literal_type.union_type.variants)):
try:
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
return binding_data_from_python_std(ctx, lt_type, t_value, python_type)
except Exception:
logger.debug(
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
)
raise AssertionError(
f"Failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants}."
)

elif isinstance(t_value, list):
sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None
collection = _literals_models.BindingDataCollection(
bindings=[
Expand Down
70 changes: 58 additions & 12 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from typing_extensions import get_args

from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
Expand Down Expand Up @@ -32,14 +35,16 @@
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError, UnionTransformer
from flytekit.exceptions import scopes as exception_scopes
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.documentation import Description, Documentation
from flytekit.models.types import TypeStructure

GLOBAL_START_NODE = Node(
id=_common_constants.GLOBAL_INPUT_NODE_ID,
Expand All @@ -49,6 +54,8 @@
flyte_entity=None,
)

T = typing.TypeVar("T")


class WorkflowFailurePolicy(Enum):
"""
Expand Down Expand Up @@ -272,24 +279,63 @@ def execute(self, **kwargs):
def compile(self, **kwargs):
pass

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
for k, v in kwargs.items():
if not isinstance(v, Promise):
t = self.python_interface.inputs[k]
def ensure_literal(
self, ctx, py_type: Type[T], input_type: type_models.LiteralType, python_value: Any
) -> _literal_models.Literal:
"""
This function will attempt to convert a python value to a literal. If the python value is a promise, it will
return the promise's value.
"""
if input_type.union_type is not None:
if python_value is None and UnionTransformer.is_optional_type(py_type):
return _literal_models.Literal(scalar=_literal_models.Scalar(none_type=_literal_models.Void()))
for i in range(len(input_type.union_type.variants)):
lt_type = input_type.union_type.variants[i]
python_type = get_args(py_type)[i]
try:
kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type))
final_lt = self.ensure_literal(ctx, python_type, lt_type, python_value)
lt_type._structure = TypeStructure(tag=TypeEngine.get_transformer(python_type).name)
return _literal_models.Literal(
scalar=_literal_models.Scalar(union=_literal_models.Union(value=final_lt, stored_type=lt_type))
)
except Exception as e:
logger.debug(f"Failed to convert {python_value} to {lt_type} with error {e}")
raise TypeError(f"Failed to convert {python_value} to {input_type}")
if isinstance(python_value, list) and input_type.collection_type:
collection_lit_type = input_type.collection_type
collection_py_type = get_args(py_type)[0]
xx = [self.ensure_literal(ctx, collection_py_type, collection_lit_type, pv) for pv in python_value]
return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=xx))
elif isinstance(python_value, dict) and input_type.map_value_type:
mapped_lit_type = input_type.map_value_type
mapped_py_type = get_args(py_type)[1]
xx = {k: self.ensure_literal(ctx, mapped_py_type, mapped_lit_type, v) for k, v in python_value.items()} # type: ignore
return _literal_models.Literal(map=_literal_models.LiteralMap(literals=xx))
# It is a scalar, convert to Promise if necessary.
else:
if isinstance(python_value, Promise):
return python_value.val
if not isinstance(python_value, Promise):
try:
res = TypeEngine.to_literal(ctx, python_value, py_type, input_type)
return res
except TypeTransformerFailedError as exc:
raise TypeError(
f"Failed to convert input argument '{k}' of workflow '{self.name}':\n {exc}"
f"Failed to convert input '{python_value}' of workflow '{self.name}':\n {exc}"
) from exc

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
for k, v in kwargs.items():
py_type = self.python_interface.inputs[k]
lit_type = self.interface.inputs[k].type
kwargs[k] = Promise(var=k, val=self.ensure_literal(ctx, py_type, lit_type, v))

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
self.compile()
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
# A workflow function may return a task that doesn't return anything
# def wf():
Expand Down
8 changes: 3 additions & 5 deletions tests/flytekit/unit/core/test_type_conversion_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,9 @@ def test_workflow_with_task_error(correct_input):
def test_workflow_with_input_error(incorrect_input):
with pytest.raises(
TypeError,
match=(
r"Encountered error while executing workflow '{}':\n"
r" Failed to convert input argument 'a' of workflow '.+':\n"
r" Expected value of type \<class 'int'\> but got .+ of type"
).format(wf_with_output_error.name),
match=(r"Encountered error while executing workflow '{}':\n" r" Failed to convert input").format(
wf_with_output_error.name
),
):
wf_with_output_error(a=incorrect_input)

Expand Down
83 changes: 83 additions & 0 deletions tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,89 @@ def wf(b: int) -> nt:
assert x == (7, 7)


def test_sub_wf_varying_types():
@task
def t1l(
a: typing.List[typing.Dict[str, typing.List[int]]],
b: typing.Dict[str, typing.List[int]],
c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int],
d: int,
) -> str:
xx = ",".join([f"{k}:{v}" for d in a for k, v in d.items()])
yy = ",".join([f"{k}: {i}" for k, v in b.items() for i in v])
if isinstance(c, list):
zz = ",".join([f"{k}:{v}" for d in c for k, v in d.items()])
elif isinstance(c, dict):
zz = ",".join([f"{k}: {i}" for k, v in c.items() for i in v])
else:
zz = str(c)
return f"First: {xx} Second: {yy} Third: {zz} Int: {d}"

@task
def get_int() -> int:
return 1

@workflow
def subwf(
a: typing.List[typing.Dict[str, typing.List[int]]],
b: typing.Dict[str, typing.List[int]],
c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]]],
d: int,
) -> str:
return t1l(a=a, b=b, c=c, d=d)

@workflow
def wf() -> str:
ds = [
{"first_map_a": [42], "first_map_b": [get_int(), 2]},
{
"second_map_c": [33],
"second_map_d": [9, 99],
},
]
ll = {
"ll_1": [get_int(), get_int(), get_int()],
"ll_2": [4, 5, 6],
}
out = subwf(a=ds, b=ll, c=ds, d=get_int())
return out

wf.compile()
x = wf()
expected = (
"First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
"Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
"Third: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
"Int: 1"
)
assert x == expected

@workflow
def wf() -> str:
ds = [
{"first_map_a": [42], "first_map_b": [get_int(), 2]},
{
"second_map_c": [33],
"second_map_d": [9, 99],
},
]
ll = {
"ll_1": [get_int(), get_int(), get_int()],
"ll_2": [4, 5, 6],
}
out = subwf(a=ds, b=ll, c=ll, d=get_int())
return out

x = wf()
expected = (
"First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
"Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
"Third: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
"Int: 1"
)
assert x == expected


def test_unexpected_outputs():
@task
def t1(a: int) -> int:
Expand Down