Skip to content

Commit

Permalink
feat(bindings): Don't expand union type if default arg is used
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5321
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed Jun 4, 2024
1 parent a0e679c commit c9cb9df
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 6 deletions.
16 changes: 14 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ def binding_data_from_python_std(
t_value: Any,
t_value_type: type,
nodes: List[Node],
is_union_type_variants_expanded: bool = True,
) -> _literals_models.BindingData:
# This handles the case where the given value is the output of another task
if isinstance(t_value, Promise):
Expand All @@ -702,7 +703,7 @@ def binding_data_from_python_std(
f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
)

elif t_value is not None and expected_literal_type.union_type is not None:
elif t_value is not None and is_union_type_variants_expanded and expected_literal_type.union_type is not None:
for i in range(len(expected_literal_type.union_type.variants)):
try:
lt_type = expected_literal_type.union_type.variants[i]
Expand Down Expand Up @@ -769,9 +770,17 @@ def binding_from_python_std(
expected_literal_type: _type_models.LiteralType,
t_value: Any,
t_value_type: type,
is_union_type_variants_expanded: bool = True,
) -> Tuple[_literals_models.Binding, List[Node]]:
nodes: List[Node] = []
binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type, nodes)
binding_data = binding_data_from_python_std(
ctx,
expected_literal_type,
t_value,
t_value_type,
nodes,
is_union_type_variants_expanded=is_union_type_variants_expanded,
)
return _literals_models.Binding(var=var_name, binding=binding_data), nodes


Expand Down Expand Up @@ -1063,6 +1072,7 @@ def create_and_link_node(

for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
is_default_arg_used = False
if var.type.simple == SimpleType.NONE:
raise TypeError("Arguments do not have type annotation")
if k not in kwargs:
Expand All @@ -1076,6 +1086,7 @@ def create_and_link_node(
if not isinstance(default_val, Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
kwargs[k] = default_val
is_default_arg_used = True
else:
error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"
raise _user_exceptions.FlyteAssertion(error_msg)
Expand All @@ -1095,6 +1106,7 @@ def create_and_link_node(
expected_literal_type=var.type,
t_value=v,
t_value_type=interface.inputs[k],
is_union_type_variants_expanded=not is_default_arg_used,
)
bindings.append(b)
nodes.extend(n)
Expand Down
130 changes: 126 additions & 4 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.core.task import task
from flytekit.core.workflow import workflow
from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec
from flytekit.exceptions.user import FlyteAssertion
from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.literals import (
BindingData,
Expand Down Expand Up @@ -586,7 +586,7 @@ def wf_with_sub_wf() -> tuple[str, str]:
assert wf_with_sub_wf() == (default_val, input_val)


def test_default_args_task_optional_type_default_none():
def test_default_args_task_optional_int_type_default_none():
default_val = None
input_val = 100

Expand Down Expand Up @@ -648,7 +648,7 @@ def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]:
assert wf_with_sub_wf() == (default_val, input_val)


def test_default_args_task_optional_type_default_int():
def test_default_args_task_optional_int_type_default_int():
default_val = 10
input_val = 100

Expand All @@ -672,7 +672,17 @@ def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]:
wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)

assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar(
primitive=Primitive(integer=default_val)
union=Union(
value=Literal(
scalar=Scalar(
primitive=Primitive(integer=default_val),
),
),
stored_type=LiteralType(
simple=SimpleType.INTEGER,
structure=TypeStructure(tag="int"),
),
),
)
assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar(
primitive=Primitive(integer=input_val)
Expand Down Expand Up @@ -805,3 +815,115 @@ def wf_with_input() -> dict[str, int]:
)

assert wf_with_input() == input_val


def test_default_args_task_optional_list_type_default_none():
default_val = None
input_val = [1, 2, 3]

@task
def t1(a: typing.Optional[typing.List[int]] = default_val) -> typing.Optional[typing.List[int]]:
return a

@workflow
def wf_no_input() -> typing.Optional[typing.List[int]]:
return t1()

@workflow
def wf_with_input() -> typing.Optional[typing.List[int]]:
return t1(a=input_val)

@workflow
def wf_with_sub_wf() -> tuple[typing.Optional[typing.List[int]], typing.Optional[typing.List[int]]]:
return (wf_no_input(), wf_with_input())

wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input)
wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)

assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar(
union=Union(
value=Literal(
scalar=Scalar(
none_type=Void(),
),
),
stored_type=LiteralType(
simple=SimpleType.NONE,
structure=TypeStructure(tag="none"),
),
),
)
assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection(
bindings=[
BindingData(scalar=Scalar(primitive=Primitive(integer=1))),
BindingData(scalar=Scalar(primitive=Primitive(integer=2))),
BindingData(scalar=Scalar(primitive=Primitive(integer=3))),
]
)

output_type = LiteralType(
union_type=UnionType(
[
LiteralType(
collection_type=LiteralType(simple=SimpleType.INTEGER),
structure=TypeStructure(tag="Typed List"),
),
LiteralType(
simple=SimpleType.NONE,
structure=TypeStructure(tag="none"),
),
]
)
)
assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type
assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type

assert wf_no_input() == default_val
assert wf_with_input() == input_val
assert wf_with_sub_wf() == (default_val, input_val)


def test_default_args_task_optional_list_type_default_list():
input_val = [1, 2, 3]

@task
def t1(a: typing.Optional[typing.List[int]] = []) -> typing.Optional[typing.List[int]]:
return a

@workflow
def wf_no_input() -> typing.Optional[typing.List[int]]:
return t1()

@workflow
def wf_with_input() -> typing.Optional[typing.List[int]]:
return t1(a=input_val)

with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"):
get_serializable(OrderedDict(), serialization_settings, wf_no_input)

wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input)

assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection(
bindings=[
BindingData(scalar=Scalar(primitive=Primitive(integer=1))),
BindingData(scalar=Scalar(primitive=Primitive(integer=2))),
BindingData(scalar=Scalar(primitive=Primitive(integer=3))),
]
)

assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType(
union_type=UnionType(
[
LiteralType(
collection_type=LiteralType(simple=SimpleType.INTEGER),
structure=TypeStructure(tag="Typed List"),
),
LiteralType(
simple=SimpleType.NONE,
structure=TypeStructure(tag="none"),
),
]
)
)

assert wf_with_input() == input_val

0 comments on commit c9cb9df

Please sign in to comment.