diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 557d621dd4..c4f71eb2d6 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1202,19 +1202,22 @@ def flyte_entity_call_handler( #. Start a local execution - This means that we're not already in a local workflow execution, which means that we should expect inputs to be native Python values and that we should return Python native values. """ - # Sanity checks - # Only keyword args allowed - if len(args) > 0: - raise _user_exceptions.FlyteAssertion( - f"When calling tasks, only keyword args are supported. " - f"Aborting execution as detected {len(args)} positional args {args}" - ) # Make sure arguments are part of interface for k, v in kwargs.items(): - if k not in cast(SupportsNodeCreation, entity).python_interface.inputs: - raise AssertionError( - f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'" - ) + if k not in entity.python_interface.inputs: + raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'") + + # Check if we have more arguments than expected + if len(args) > len(entity.python_interface.inputs): + raise AssertionError( + f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}" + ) + + # Convert args to kwargs + for arg, input_name in zip(args, entity.python_interface.inputs.keys()): + if input_name in kwargs: + raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'") + kwargs[input_name] = arg ctx = FlyteContextManager.current_context() if ctx.execution_state and ( @@ -1234,15 +1237,12 @@ def flyte_entity_call_handler( child_ctx.execution_state and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED ): - if ( - len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0 - or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0 - ): - output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys()) + if len(entity.python_interface.inputs) > 0 or len(entity.python_interface.outputs) > 0: + output_names = list(entity.python_interface.outputs.keys()) if len(output_names) == 0: return VoidPromise(entity.name) vals = [Promise(var, None) for var in output_names] - return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface) + return create_task_output(vals, entity.python_interface) else: return None return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) @@ -1255,7 +1255,7 @@ def flyte_entity_call_handler( cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) - expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) + expected_outputs = len(entity.python_interface.outputs) if expected_outputs == 0: if result is None or isinstance(result, VoidPromise): return None @@ -1268,10 +1268,10 @@ def flyte_entity_call_handler( if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( result is not None and expected_outputs == 1 ): - return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) + return create_native_named_tuple(ctx, result, entity.python_interface) raise AssertionError( f"Expected outputs and actual outputs do not match." f"Result {result}. " - f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}" + f"Python interface: {entity.python_interface}" ) diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 2fcf8bbd94..725e3e14fc 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -943,3 +943,126 @@ def wf_with_input() -> typing.Optional[typing.List[int]]: ) assert wf_with_input() == input_val + +def test_positional_args_task(): + arg1 = 5 + arg2 = 6 + ret = 17 + + @task + def t1(x: int, y: int) -> int: + return x + y * 2 + + @workflow + def wf_pure_positional_args() -> int: + return t1(arg1, arg2) + + @workflow + def wf_mixed_positional_and_keyword_args() -> int: + return t1(arg1, y=arg2) + + wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args) + wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args) + + arg1_binding = Scalar(primitive=Primitive(integer=arg1)) + arg2_binding = Scalar(primitive=Primitive(integer=arg2)) + output_type = LiteralType(simple=SimpleType.INTEGER) + + assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type + + + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_pure_positional_args() == ret + assert wf_mixed_positional_and_keyword_args() == ret + +def test_positional_args_workflow(): + arg1 = 5 + arg2 = 6 + ret = 17 + + @task + def t1(x: int, y: int) -> int: + return x + y * 2 + + @workflow + def sub_wf(x: int, y: int) -> int: + return t1(x=x, y=y) + + @workflow + def wf_pure_positional_args() -> int: + return sub_wf(arg1, arg2) + + @workflow + def wf_mixed_positional_and_keyword_args() -> int: + return sub_wf(arg1, y=arg2) + + wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args) + wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args) + + arg1_binding = Scalar(primitive=Primitive(integer=arg1)) + arg2_binding = Scalar(primitive=Primitive(integer=arg2)) + output_type = LiteralType(simple=SimpleType.INTEGER) + + assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding + assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding + assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type + + assert wf_pure_positional_args() == ret + assert wf_mixed_positional_and_keyword_args() == ret + +def test_positional_args_chained_tasks(): + @task + def t1(x: int, y: int) -> int: + return x + y * 2 + + @workflow + def wf() -> int: + x = t1(2, y = 3) + y = t1(3, 4) + return t1(x, y = y) + + assert wf() == 30 + +def test_positional_args_task_inputs_from_workflow_args(): + @task + def t1(x: int, y: int, z: int) -> int: + return x + y * 2 + z * 3 + + @workflow + def wf(x: int, y: int) -> int: + return t1(x, y=y, z=3) + + assert wf(1, 2) == 14 + +def test_unexpected_kwargs_task_raises_error(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Received unexpected keyword argument"): + t1(b=6) + +def test_too_many_positional_args_task_raises_error(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Received more arguments than expected"): + t1(1, 2) + +def test_both_positional_and_keyword_args_task_raises_error(): + @task + def t1(a: int) -> int: + return a + + with pytest.raises(AssertionError, match="Got multiple values for argument"): + t1(1, a=2)