diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index fdd47d1741..a3b1b80dd8 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -7,12 +7,15 @@ from collections import OrderedDict from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing_extensions import get_args, get_origin + from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteValidationException from flytekit.loggers import logger from flytekit.models import interface as _interface_models +from flytekit.models.literals import Void from flytekit.types.pickle import FlytePickle T = typing.TypeVar("T") @@ -182,11 +185,17 @@ def transform_inputs_to_parameters( inputs_with_def = interface.inputs_with_defaults for k, v in inputs_vars.items(): val, _default = inputs_with_def[k] - required = _default is None - default_lv = None - if _default is not None: - default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) - params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) + if _default is None and get_origin(val) is typing.Union and type(None) in get_args(val): + from flytekit import Literal, Scalar + + literal = Literal(scalar=Scalar(none_type=Void())) + params[k] = _interface_models.Parameter(var=v, default=literal, required=False) + else: + required = _default is None + default_lv = None + if _default is not None: + default_lv = TypeEngine.to_literal(ctx, _default, python_type=interface.inputs[k], expected=v.type) + params[k] = _interface_models.Parameter(var=v, default=default_lv, required=required) return _interface_models.ParameterMap(params) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index e5877e21fe..6b0aafdc1b 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -23,6 +23,7 @@ from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.literals import Primitive +from flytekit.models.types import SimpleType def translate_inputs_to_literals( @@ -881,7 +882,20 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] if k not in kwargs: - raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + is_optional = False + if var.type.union_type: + for variant in var.type.union_type.variants: + if variant.simple == SimpleType.NONE: + val, _default = interface.inputs_with_defaults[k] + if _default is not None: + raise ValueError( + f"The default value for the optional type must be None, but got {_default}" + ) + is_optional = True + if not is_optional: + raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) + else: + continue v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 60c42ac339..d4ca18582f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -737,11 +737,11 @@ def literal_map_to_kwargs( """ Given a ``LiteralMap`` (usually an input into a task - intermediate), convert to kwargs for the task """ - if len(lm.literals) != len(python_types): + if len(lm.literals) > len(python_types): raise ValueError( f"Received more input values {len(lm.literals)}" f" than allowed by the input spec {len(python_types)}" ) - return {k: TypeEngine.to_python_value(ctx, lm.literals[k], v) for k, v in python_types.items()} + return {k: TypeEngine.to_python_value(ctx, lm.literals[k], python_types[k]) for k, v in lm.literals.items()} @classmethod def dict_to_literal_map( diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 37f5d10195..3963c77c8d 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -1,4 +1,6 @@ -import typing +from typing import Dict, List, NamedTuple, Optional, Union + +import pytest from flytekit.core import launch_plan from flytekit.core.task import task @@ -8,7 +10,7 @@ def test_wf1_with_subwf(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -33,7 +35,7 @@ def my_wf(a: int, b: str) -> (int, str, str): def test_single_named_output_subwf(): - nt = typing.NamedTuple("SubWfOutput", sub_int=int) + nt = NamedTuple("SubWfOutput", sub_int=int) @task def t1(a: int) -> nt: @@ -68,7 +70,7 @@ def my_wf(a: int) -> int: def test_lp_default_handling(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -130,7 +132,7 @@ def my_wf2(a: int, b: int = 42) -> (str, str, int, int): def test_wf1_with_lp_node(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> NamedTuple("OutputsBC", t1_int_output=int, c=str): a = a + 2 return a, "world-" + str(a) @@ -169,3 +171,30 @@ def my_wf3(a: int = 42) -> (int, str, str, str): return x, y, u, v assert my_wf2() == (44, "world-44", "world-5", "world-7") + + +def test_optional_input(): + @task() + def t1(a: Optional[int] = None, b: Optional[List[int]] = None, c: Optional[Dict[str, int]] = None) -> Optional[int]: + ... + + @task() + def t2(a: Union[int, Optional[List[int]], None] = None) -> Union[int, Optional[List[int]], None]: + ... + + @workflow + def wf(a: Optional[int] = 1) -> Optional[int]: + t1() + return t2(a=a) + + assert wf() is None + + with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"): + + @task() + def t3(c: Optional[int] = 3) -> Optional[int]: + ... + + @workflow + def wf(): + return t3() diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 2317bab1c2..62c45c346e 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -12,6 +12,7 @@ transform_variable_map, ) from flytekit.models.core import types as _core_types +from flytekit.models.literals import Void from flytekit.types.file import FlyteFile from flytekit.types.pickle import FlytePickle @@ -199,6 +200,20 @@ def z(a: Annotated[int, "some annotation"]) -> Annotated[int, "some annotation"] assert our_interface.inputs == {"a": Annotated[int, "some annotation"]} assert our_interface.outputs == {"o0": Annotated[int, "some annotation"]} + def z( + a: typing.Optional[int] = None, b: typing.Optional[str] = None, c: typing.Union[typing.List[int], None] = None + ) -> typing.Tuple[int, str]: + ... + + our_interface = transform_function_to_interface(z) + params = transform_inputs_to_parameters(ctx, our_interface) + assert not params.parameters["a"].required + assert params.parameters["a"].default.scalar.none_type == Void() + assert not params.parameters["b"].required + assert params.parameters["b"].default.scalar.none_type == Void() + assert not params.parameters["c"].required + assert params.parameters["c"].default.scalar.none_type == Void() + def test_parameters_with_docstring(): ctx = context_manager.FlyteContext.current_context()