Skip to content

Commit

Permalink
test(default-args): Add more tests according to Yee's recommendation
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5321
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed Jun 4, 2024
1 parent ab9e6a6 commit a0e679c
Show file tree
Hide file tree
Showing 8 changed files with 456 additions and 78 deletions.
37 changes: 18 additions & 19 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import collections
import inspect
import typing
from copy import deepcopy
from enum import Enum
from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast
from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args

from google.protobuf import struct_pb2 as _struct
from typing_extensions import Protocol, get_args
from typing_extensions import Protocol

from flytekit.core import constants as _common_constants
from flytekit.core import context_manager as _flyte_context
Expand All @@ -24,7 +23,13 @@
)
from flytekit.core.interface import Interface
from flytekit.core.node import Node
from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError
from flytekit.core.type_engine import (
DictTransformer,
ListTransformer,
TypeEngine,
TypeTransformerFailedError,
UnionTransformer,
)
from flytekit.exceptions import user as _user_exceptions
from flytekit.exceptions.user import FlytePromiseAttributeResolveException
from flytekit.loggers import logger
Expand Down Expand Up @@ -1058,27 +1063,21 @@ def create_and_link_node(

for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if var.type.simple == SimpleType.NONE:
raise TypeError("Arguments do not have type annotation")
if k not in kwargs:
is_optional = False
if var.type.union_type:
for variant in var.type.union_type.variants:
if variant.simple == SimpleType.NONE:
val, _default = interface.inputs_with_defaults[k]
if _default is not None:
raise ValueError(
f"The default value for the optional type must be None, but got {_default}"
)
is_optional = True
if is_optional:
continue
if k in interface.inputs_with_defaults and interface.inputs_with_defaults[k][1] is not None:
# interface.inputs_with_defaults[k][0] is the type of the default argument
# interface.inputs_with_defaults[k][1] is the value of the default argument
if k in interface.inputs_with_defaults and (
interface.inputs_with_defaults[k][1] is not None
or UnionTransformer.is_optional_type(interface.inputs_with_defaults[k][0])
):
default_val = interface.inputs_with_defaults[k][1]
if not isinstance(default_val, typing.Hashable):
if not isinstance(default_val, Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
kwargs[k] = default_val
else:
error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}"

raise _user_exceptions.FlyteAssertion(error_msg)
v = kwargs[k]
# This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,7 +1555,7 @@ def __init__(self):
super().__init__("Typed Union", typing.Union)

@staticmethod
def is_optional_type(t: Type[T]) -> bool:
def is_optional_type(t: Type) -> bool:
"""Return True if `t` is a Union or Optional type."""
return _is_union_type(t) or type(None) in get_args(t)

Expand Down
14 changes: 0 additions & 14 deletions tests/flytekit/unit/core/test_composition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Dict, List, NamedTuple, Optional, Union

import pytest

from flytekit.core import launch_plan
from flytekit.core.task import task
from flytekit.core.workflow import workflow
Expand Down Expand Up @@ -186,15 +184,3 @@ def wf(a: Optional[int] = 1) -> Optional[int]:
return t2(a=a)

assert wf() is None

with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"):

@task()
def t3(c: Optional[int] = 3) -> Optional[int]:
...

@workflow
def wf():
return t3()

wf()
37 changes: 37 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,43 @@ def ranged_int_to_str(a: int) -> typing.List[str]:
assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"]


@pytest.mark.parametrize(
"input_val,output_val",
[
(4, 0),
(5, 5),
],
)
def test_dynamic_local_default_args_task(input_val, output_val):
@task
def t1(a: int = 0) -> int:
return a

@dynamic
def dt(a: int) -> int:
if a % 2 == 0:
return t1()
return t1(a=a)

assert dt(a=input_val) == output_val

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": input_val})
dynamic_job_spec = dt.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec.nodes) == 1
assert len(dynamic_job_spec.tasks) == 1
assert dynamic_job_spec.nodes[0].inputs[0].binding.scalar.primitive is not None


def test_nested_dynamic_local():
@task
def t1(a: int) -> str:
Expand Down
8 changes: 0 additions & 8 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,6 @@ def t2(a: typing.Optional[int] = None) -> typing.Optional[int]:

p = create_and_link_node(ctx, t2)
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 0

@task
def t_default_value(a: int = 1) -> int:
return a

p = create_and_link_node(ctx, t_default_value)
assert p.ref.var == "o0"
assert len(p.ref.node.bindings) == 1


Expand Down
Loading

0 comments on commit a0e679c

Please sign in to comment.