Skip to content

Commit

Permalink
feat(bindings): Task arguments default value binding
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5321
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed Jun 4, 2024
1 parent db2ff9e commit 6c19346
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 3 deletions.
12 changes: 9 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import collections
import inspect
import typing
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast
Expand Down Expand Up @@ -1068,7 +1069,14 @@ def create_and_link_node(
f"The default value for the optional type must be None, but got {_default}"
)
is_optional = True
if not is_optional:
if is_optional:
continue
if k in interface.inputs_with_defaults and interface.inputs_with_defaults[k][1] is not None:
default_val = interface.inputs_with_defaults[k][1]
if not isinstance(default_val, typing.Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
kwargs[k] = default_val
else:
from flytekit.core.base_task import Task

error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"
Expand All @@ -1081,8 +1089,6 @@ def create_and_link_node(
)

raise _user_exceptions.FlyteAssertion(error_msg)
else:
continue
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
Expand Down
8 changes: 8 additions & 0 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ def t2(a: typing.Optional[int] = None) -> typing.Optional[int]:
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 0

@task
def t_default_value(a: int = 1) -> int:
return a

p = create_and_link_node(ctx, t_default_value)
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 1


def test_create_and_link_node_from_remote():
@task
Expand Down
34 changes: 34 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.types import SimpleType
from flytekit.tools.translator import get_serializable
Expand Down Expand Up @@ -85,6 +86,39 @@ def raw_container_wf(val1: int, val2: int) -> int:
assert sum_spec.template.container.image == "alpine"


@pytest.mark.parametrize(
"arg_type,default_arg,input_arg,error",
[
(int | None, None, 100, None),
(int | None, 100, 200, ValueError),
(int, 0, 5, None),
(str, "", "a", None),
(list[int], [], [1, 2, 3], FlyteAssertion),
(dict[str, int], {}, {"a": 1}, FlyteAssertion),
],
)
def test_default_args_task(arg_type, default_arg, input_arg, error):
@task
def t1(a: arg_type = default_arg) -> arg_type:
return a

@workflow
def wf_no_input() -> arg_type:
return t1()

@workflow
def wf_with_input() -> arg_type:
return t1(a=input_arg)

if error:
with pytest.raises(error):
wf_no_input()
else:
assert wf_no_input() == default_arg

assert wf_with_input() == input_arg


def test_serialization_branch_complex():
@task
def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
Expand Down

0 comments on commit 6c19346

Please sign in to comment.