Skip to content

Commit

Permalink
fix(bindings): Update according to PR #2460
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5321
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed Jun 6, 2024
1 parent b45a2aa commit b92519a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
7 changes: 0 additions & 7 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ def binding_data_from_python_std(
t_value: Any,
t_value_type: type,
nodes: List[Node],
is_union_type_variants_expanded: bool = True,
) -> _literals_models.BindingData:
# This handles the case where the given value is the output of another task
if isinstance(t_value, Promise):
Expand All @@ -713,7 +712,6 @@ def binding_data_from_python_std(
# If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is
# akin to what the Type Engine does when it finds a Union type (see the UnionTransformer), but we can't rely on
# that in this case, because of the mix and match of realized values, and Promises.
elif t_value is not None and is_union_type_variants_expanded and expected_literal_type.union_type is not None:
for i in range(len(expected_literal_type.union_type.variants)):
try:
lt_type = expected_literal_type.union_type.variants[i]
Expand Down Expand Up @@ -780,7 +778,6 @@ def binding_from_python_std(
expected_literal_type: _type_models.LiteralType,
t_value: Any,
t_value_type: type,
is_union_type_variants_expanded: bool = True,
) -> Tuple[_literals_models.Binding, List[Node]]:
nodes: List[Node] = []
binding_data = binding_data_from_python_std(
Expand All @@ -789,7 +786,6 @@ def binding_from_python_std(
t_value,
t_value_type,
nodes,
is_union_type_variants_expanded=is_union_type_variants_expanded,
)
return _literals_models.Binding(var=var_name, binding=binding_data), nodes

Expand Down Expand Up @@ -1082,7 +1078,6 @@ def create_and_link_node(

for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
is_default_arg_used = False
if var.type.simple == SimpleType.NONE:
raise TypeError("Arguments do not have type annotation")
if k not in kwargs:
Expand All @@ -1096,7 +1091,6 @@ def create_and_link_node(
if not isinstance(default_val, Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
kwargs[k] = default_val
is_default_arg_used = True
else:
error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"
raise _user_exceptions.FlyteAssertion(error_msg)
Expand All @@ -1116,7 +1110,6 @@ def create_and_link_node(
expected_literal_type=var.type,
t_value=v,
t_value_type=interface.inputs[k],
is_union_type_variants_expanded=not is_default_arg_used,
)
bindings.append(b)
nodes.extend(n)
Expand Down
20 changes: 18 additions & 2 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,15 @@ def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]:
),
)
assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar(
primitive=Primitive(integer=input_val)
union=Union(
value=Literal(
scalar=Scalar(primitive=Primitive(integer=input_val)),
),
stored_type=LiteralType(
simple=SimpleType.INTEGER,
structure=TypeStructure(tag="int"),
),
),
)

output_type = LiteralType(
Expand Down Expand Up @@ -685,7 +693,15 @@ def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]:
),
)
assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar(
primitive=Primitive(integer=input_val)
union=Union(
value=Literal(
scalar=Scalar(primitive=Primitive(integer=input_val)),
),
stored_type=LiteralType(
simple=SimpleType.INTEGER,
structure=TypeStructure(tag="int"),
),
),
)

output_type = LiteralType(
Expand Down

0 comments on commit b92519a

Please sign in to comment.