Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support optional input #989

Merged
merged 20 commits into from
Jun 20, 2022
19 changes: 14 additions & 5 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
from collections import OrderedDict
from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union

from typing_extensions import get_args, get_origin

from flytekit.core import context_manager
from flytekit.core.docstring import Docstring
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteValidationException
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models.literals import Void
from flytekit.types.pickle import FlytePickle

T = typing.TypeVar("T")
Expand Down Expand Up @@ -182,11 +185,17 @@ def transform_inputs_to_parameters(
inputs_with_def = interface.inputs_with_defaults
for k, v in inputs_vars.items():
val, _default = inputs_with_def[k]
required = _default is None
default_lv = None
if _default is not None:
default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type)
params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required)
if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val):
from flytekit import Literal, Scalar

literal = Literal(scalar=Scalar(none_type=Void()))
params[k] = _interface_models.Parameter(var=v, default=literal, required=False)
else:
required = _default is None
default_lv = None
if _default is not None:
default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type)
params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required)
return _interface_models.ParameterMap(params)


Expand Down
16 changes: 15 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flytekit.models import types as type_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Primitive
from flytekit.models.types import SimpleType


def translate_inputs_to_literals(
Expand Down Expand Up @@ -881,7 +882,20 @@ def create_and_link_node(
for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if k not in kwargs:
raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
is_optional = False
if var.type.union_type:
for variant in var.type.union_type.variants:
if variant.simple == SimpleType.NONE:
val, _default = interface.inputs_with_defaults[k]
if _default is not None:
raise ValueError(
f"The default value for the optional type must be None, but got {_default}"
)
is_optional = True
if not is_optional:
raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type))
else:
continue
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,11 +737,11 @@ def literal_map_to_kwargs(
"""
Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task
"""
if len(lm.literals) != len(python_types):
if len(lm.literals) > len(python_types):
raise ValueError(
f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}"
)
return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()}
return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()}

@classmethod
def dict_to_literal_map(
Expand Down
29 changes: 29 additions & 0 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import typing

import pytest

from flytekit.core import launch_plan
from flytekit.core.task import task
from flytekit.core.workflow import workflow
Expand Down Expand Up @@ -169,3 +171,30 @@ def my_wf3(a: int = 42) -> (int, str, str, str):
return x, y, u, v

assert my_wf2() == (44, "world-44", "world-5", "world-7")


def test_optional_input():
@task()
def t1(b: typing.Optional[int] = None) -> typing.Optional[int]:
...

@task()
def t2(c: typing.Optional[int] = None) -> typing.Optional[int]:
...

@workflow
def wf(a: typing.Optional[int] = 1) -> typing.Optional[int]:
t1()
return t2(c=a)

assert wf() is None

@task()
def t3(c: typing.Optional[int] = 3) -> typing.Optional[int]:
...

with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"):

@workflow
def wf():
return t3()
11 changes: 11 additions & 0 deletions tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
transform_variable_map,
)
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Void
from flytekit.types.file import FlyteFile
from flytekit.types.pickle import FlytePickle

Expand Down Expand Up @@ -199,6 +200,16 @@ def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"]
assert our_interface.inputs == {"a": Annotated[int, "some annotation"]}
assert our_interface.outputs == {"o0": Annotated[int, "some annotation"]}

def z(a: typing.Optional[int] = None, b: typing.Optional[str] = None) -> typing.Tuple[int, str]:
...

our_interface = transform_function_to_interface(z)
params = transform_inputs_to_parameters(ctx, our_interface)
assert not params.parameters["a"].required
assert params.parameters["a"].default.scalar.none_type == Void()
assert not params.parameters["b"].required
assert params.parameters["b"].default.scalar.none_type == Void()


def test_parameters_with_docstring():
ctx = context_manager.FlyteContext.current_context()
Expand Down