Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support optional input #989

Merged
merged 20 commits into from
Jun 20, 2022
10 changes: 9 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
from flytekit.models.types import SimpleType


def translate_inputs_to_literals(
Expand Down Expand Up @@ -871,7 +872,14 @@ def create_and_link_node(
for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if k not in kwargs:
raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
is_optional = False
if var.type.union_type:
for variant in var.type.union_type.variants:
if variant.simple == SimpleType.NONE:
kwargs[k] = None
is_optional = True
if not is_optional:
raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
Expand Down
6 changes: 6 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from flytekit.exceptions import user as user_exceptions
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals
from flytekit.models import types as _type_models
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel
from flytekit.models.core import types as _core_types
Expand Down Expand Up @@ -737,6 +738,11 @@ def literal_map_to_kwargs(
"""
Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
"""
# Assign default literal value (void) if python type is an optional type
for k, v in python_types.items():
if k not in lm.literals and get_origin(v) is typing.Union and type(None) in get_args(v):
lm.literals[k] = Literal(scalar=literals.Scalar(none_type=literals.Void()))

if len(lm.literals) != len(python_types):
raise ValueError(
f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}"
Expand Down
11 changes: 8 additions & 3 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.types import SimpleType

GLOBAL_START_NODE = Node(
id=_common_constants.GLOBAL_INPUT_NODE_ID,
Expand Down Expand Up @@ -243,10 +244,14 @@ def execute(self, **kwargs):
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
# This is done to support the invariant that Workflow local executions always work with Promise objects
# holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value.
for k, v in kwargs.items():
if not isinstance(v, Promise):
for k, v in self.interface.inputs.items():
if k not in kwargs and v.type.union_type:
for variant in v.type.union_type.variants:
if variant.simple == SimpleType.NONE:
kwargs[k] = None
if not isinstance(kwargs[k], Promise):
t = self.python_interface.inputs[k]
kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type))
kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, kwargs[k], t, v.type))

# The output of this will always be a combination of Python native values and Promises containing Flyte
# Literals.
Expand Down
16 changes: 16 additions & 0 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,19 @@ def my_wf3(a: int = 42) -> (int, str, str, str):
return x, y, u, v

assert my_wf2() == (44, "world-44", "world-5", "world-7")


def test_optional_input():
@task()
def t1(b: typing.Optional[int]) -> str:
return str(b)

@task()
def t2(c: str) -> str:
return c

@workflow
def wf(a: typing.Optional[int]) -> str:
return t2(c=t1())

assert wf() == str(None)
pingsutw marked this conversation as resolved.
Show resolved Hide resolved