Skip to content

Commit

Permalink
test(default-args): Add more tests according to Yee's recommendation
Browse files Browse the repository at this point in the history
Resolves: flyteorg/flyte#5321
Signed-off-by: Chi-Sheng Liu <[email protected]>
  • Loading branch information
MortalHappiness committed May 25, 2024
1 parent db89a06 commit 6042fcf
Show file tree
Hide file tree
Showing 4 changed files with 373 additions and 30 deletions.
16 changes: 3 additions & 13 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,20 +1058,10 @@ def create_and_link_node(

for k in sorted(interface.inputs):
var = typed_interface.inputs[k]
if var.type.simple == SimpleType.NONE:
raise TypeError("Arguments do not have type annotation or the type annotation is None")
if k not in kwargs:
is_optional = False
if var.type.union_type:
for variant in var.type.union_type.variants:
if variant.simple == SimpleType.NONE:
val, _default = interface.inputs_with_defaults[k]
if _default is not None:
raise ValueError(
f"The default value for the optional type must be None, but got {_default}"
)
is_optional = True
if is_optional:
continue
if k in interface.inputs_with_defaults and interface.inputs_with_defaults[k][1] is not None:
if k in interface.inputs_with_defaults:
default_val = interface.inputs_with_defaults[k][1]
if not isinstance(default_val, typing.Hashable):
raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument")
Expand Down
22 changes: 22 additions & 0 deletions kubeflow_tf_evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from flytekitplugins.kftensorflow import PS, Chief, Evaluator, TfJob, Worker

from flytekit import Resources, task

task_config = TfJob(
worker=Worker(replicas=2),
chief=Chief(replicas=1),
ps=PS(replicas=1),
evaluator=Evaluator(replicas=1),
)


@task(
task_config=task_config,
requests=Resources(cpu="1"),
)
def my_tensorflow_task(x: int, y: str) -> str:
return f"{x=}, {y=}"


if __name__ == "__main__":
print(my_tensorflow_task(x=10, y="hello"))
36 changes: 36 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,42 @@ def ranged_int_to_str(a: int) -> typing.List[str]:
assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"]


@pytest.mark.parametrize(
"input_val,output_val",
[
(4, 0),
(5, 5),
],
)
def test_dynamic_local_default_args_task(input_val, output_val):
@task
def t1(a: int = 0) -> int:
return a

@dynamic
def dt(a: int) -> int:
if a % 2 == 0:
return t1()
return t1(a=a)

assert dt(a=input_val) == output_val

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": input_val})
dynamic_job_spec = dt.dispatch_execute(ctx, input_literal_map)
assert len(dynamic_job_spec.nodes) == 1
assert len(dynamic_job_spec.tasks) == 1


def test_nested_dynamic_local():
@task
def t1(a: int) -> str:
Expand Down
Loading

0 comments on commit 6042fcf

Please sign in to comment.