Skip to content

Commit

Permalink
feat(bindings): Task arguments default value binding
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5321
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed May 16, 2024
1 parent 76fb7c3 commit 49f8c27
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
8 changes: 5 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,11 @@ def create_and_link_node(
f"The default value for the optional type must be None, but got {_default}"
)
is_optional = True
if not is_optional:
if is_optional:
continue
if k in interface.inputs_with_defaults and interface.inputs_with_defaults[k][1] is not None:
kwargs[k] = interface.inputs_with_defaults[k][1]
else:
from flytekit.core.base_task import Task

error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"
Expand All @@ -1081,8 +1085,6 @@ def create_and_link_node(
)

raise _user_exceptions.FlyteAssertion(error_msg)
else:
continue
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
# Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed
Expand Down
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ def t2(a: typing.Optional[int] = None) -> typing.Optional[int]:
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 0

@task
def t_default_value(a: int = 1) -> typing.Optional[int]:
return a

p = create_and_link_node(ctx, t_default_value)
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 1

@workflow
def wf_default_no_input() -> typing.Optional[int]:
return t_default_value()

@workflow
def wf_default_with_input() -> typing.Optional[int]:
return t_default_value(a=2)

assert wf_default_no_input() == 1
assert wf_default_with_input() == 2


def test_create_and_link_node_from_remote():
@task
Expand Down

0 comments on commit 49f8c27

Please sign in to comment.