Skip to content

Commit

Permalink
Add support for using a list as an input for a subworkflow (#1605)
Browse files: browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored and eapolinario committed May 16, 2023
1 parent 7409889 commit 396080c
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 23 deletions.
3 changes: 1 addition & 2 deletions Dockerfile.external-plugin-service
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ MAINTAINER Flyte Team <[email protected]>
LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit

ARG VERSION
RUN pip install -U flytekit==$VERSION \
flytekitplugins-bigquery==$VERSION \
RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION

CMD pyflyte serve --port 8000
20 changes: 16 additions & 4 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from flytekit.core.node import Node
from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import literals as _literals_models
Expand Down Expand Up @@ -612,11 +613,22 @@ def binding_data_from_python_std(
f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
)

elif isinstance(t_value, list):
if expected_literal_type.collection_type is None:
raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}")
elif expected_literal_type.union_type is not None:
for i in range(len(expected_literal_type.union_type.variants)):
try:
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
return binding_data_from_python_std(ctx, lt_type, t_value, python_type)
except Exception:
logger.debug(
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
)
raise AssertionError(
f"Failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants}."
)

sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None
elif isinstance(t_value, list):
sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None
collection = _literals_models.BindingDataCollection(
bindings=[
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value
Expand Down
70 changes: 58 additions & 12 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from typing_extensions import get_args

from flytekit.core import constants as _common_constants
from flytekit.core.base_task import PythonTask
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
Expand Down Expand Up @@ -32,14 +35,16 @@
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError, UnionTransformer
from flytekit.exceptions import scopes as exception_scopes
from flytekit.exceptions.user import FlyteValidationException, FlyteValueException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.documentation import Description, Documentation
from flytekit.models.types import TypeStructure

GLOBAL_START_NODE = Node(
id=_common_constants.GLOBAL_INPUT_NODE_ID,
Expand All @@ -49,6 +54,8 @@
flyte_entity=None,
)

T = typing.TypeVar("T")


class WorkflowFailurePolicy(Enum):
"""
Expand Down Expand Up @@ -270,24 +277,63 @@ def execute(self, **kwargs):
def compile(self, **kwargs):
pass

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
for k, v in kwargs.items():
if not isinstance(v, Promise):
t = self.python_interface.inputs[k]
def ensure_literal(
self, ctx, py_type: Type[T], input_type: type_models.LiteralType, python_value: Any
) -> _literal_models.Literal:
"""
This function will attempt to convert a python value to a literal. If the python value is a promise, it will
return the promise's value.
"""
if input_type.union_type is not None:
if python_value is None and UnionTransformer.is_optional_type(py_type):
return _literal_models.Literal(scalar=_literal_models.Scalar(none_type=_literal_models.Void()))
for i in range(len(input_type.union_type.variants)):
lt_type = input_type.union_type.variants[i]
python_type = get_args(py_type)[i]
try:
kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type))
final_lt = self.ensure_literal(ctx, python_type, lt_type, python_value)
lt_type._structure = TypeStructure(tag=TypeEngine.get_transformer(python_type).name)
return _literal_models.Literal(
scalar=_literal_models.Scalar(union=_literal_models.Union(value=final_lt, stored_type=lt_type))
)
except Exception as e:
logger.debug(f"Failed to convert {python_value} to {lt_type} with error {e}")
raise TypeError(f"Failed to convert {python_value} to {input_type}")
if isinstance(python_value, list) and input_type.collection_type:
collection_lit_type = input_type.collection_type
collection_py_type = get_args(py_type)[0]
xx = [self.ensure_literal(ctx, collection_py_type, collection_lit_type, pv) for pv in python_value]
return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=xx))
elif isinstance(python_value, dict) and input_type.map_value_type:
mapped_lit_type = input_type.map_value_type
mapped_py_type = get_args(py_type)[1]
xx = {k: self.ensure_literal(ctx, mapped_py_type, mapped_lit_type, v) for k, v in python_value.items()} # type: ignore
return _literal_models.Literal(map=_literal_models.LiteralMap(literals=xx))
# It is a scalar: unwrap the literal held by a Promise, otherwise convert the native value to a literal.
else:
if isinstance(python_value, Promise):
return python_value.val
if not isinstance(python_value, Promise):
try:
res = TypeEngine.to_literal(ctx, python_value, py_type, input_type)
return res
except TypeTransformerFailedError as exc:
raise TypeError(
f"Failed to convert input argument '{k}' of workflow '{self.name}':\n {exc}"
f"Failed to convert input '{python_value}' of workflow '{self.name}':\n {exc}"
) from exc

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
for k, v in kwargs.items():
py_type = self.python_interface.inputs[k]
lit_type = self.interface.inputs[k].type
kwargs[k] = Promise(var=k, val=self.ensure_literal(ctx, py_type, lit_type, v))

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
self.compile()
function_outputs = self.execute(**kwargs)

# First handle the empty return case.
# A workflow function may return a task that doesn't return anything
# def wf():
Expand Down
8 changes: 3 additions & 5 deletions tests/flytekit/unit/core/test_type_conversion_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,9 @@ def test_workflow_with_task_error(correct_input):
def test_workflow_with_input_error(incorrect_input):
with pytest.raises(
TypeError,
match=(
r"Encountered error while executing workflow '{}':\n"
r" Failed to convert input argument 'a' of workflow '.+':\n"
r" Expected value of type \<class 'int'\> but got .+ of type"
).format(wf_with_output_error.name),
match=(r"Encountered error while executing workflow '{}':\n" r" Failed to convert input").format(
wf_with_output_error.name
),
):
wf_with_output_error(a=incorrect_input)

Expand Down
83 changes: 83 additions & 0 deletions tests/flytekit/unit/core/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,89 @@ def wf(b: int) -> nt:
assert x == (7, 7)


def test_sub_wf_varying_types():
    """Check that a sub-workflow accepts nested list/dict inputs.

    The parent workflows build list and dict inputs that mix plain Python
    values with task outputs, and pass them to ``subwf`` (including through a
    Union-typed parameter). The rendered string returned by the inner task is
    compared against the expected text for both the list and the dict variant
    of the Union argument.
    """

    @task
    def t1l(
        a: typing.List[typing.Dict[str, typing.List[int]]],
        b: typing.Dict[str, typing.List[int]],
        c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int],
        d: int,
    ) -> str:
        # Render a list of dicts as "key:[values]" pairs.
        def render_maps(maps):
            return ",".join(f"{k}:{v}" for entry in maps for k, v in entry.items())

        # Render a dict of lists as one "key: item" pair per list element.
        def render_items(mapping):
            return ",".join(f"{k}: {i}" for k, v in mapping.items() for i in v)

        first = render_maps(a)
        second = render_items(b)
        # The Union argument is rendered according to whichever variant arrived.
        if isinstance(c, list):
            third = render_maps(c)
        elif isinstance(c, dict):
            third = render_items(c)
        else:
            third = str(c)
        return f"First: {first} Second: {second} Third: {third} Int: {d}"

    @task
    def get_int() -> int:
        return 1

    @workflow
    def subwf(
        a: typing.List[typing.Dict[str, typing.List[int]]],
        b: typing.Dict[str, typing.List[int]],
        c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]]],
        d: int,
    ) -> str:
        return t1l(a=a, b=b, c=c, d=d)

    # Variant 1: the Union-typed input ``c`` receives the list of dicts.
    @workflow
    def wf() -> str:
        nested_maps = [
            {"first_map_a": [42], "first_map_b": [get_int(), 2]},
            {"second_map_c": [33], "second_map_d": [9, 99]},
        ]
        int_lists = {
            "ll_1": [get_int(), get_int(), get_int()],
            "ll_2": [4, 5, 6],
        }
        return subwf(a=nested_maps, b=int_lists, c=nested_maps, d=get_int())

    wf.compile()
    result = wf()
    assert result == (
        "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
        "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
        "Third: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
        "Int: 1"
    )

    # Variant 2: the Union-typed input ``c`` receives the dict of lists.
    @workflow
    def wf() -> str:
        nested_maps = [
            {"first_map_a": [42], "first_map_b": [get_int(), 2]},
            {"second_map_c": [33], "second_map_d": [9, 99]},
        ]
        int_lists = {
            "ll_1": [get_int(), get_int(), get_int()],
            "ll_2": [4, 5, 6],
        }
        return subwf(a=nested_maps, b=int_lists, c=int_lists, d=get_int())

    result = wf()
    assert result == (
        "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] "
        "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
        "Third: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 "
        "Int: 1"
    )


def test_unexpected_outputs():
@task
def t1(a: int) -> int:
Expand Down

0 comments on commit 396080c

Please sign in to comment.