Skip to content

Commit

Permalink
Fix map_task sensitive to argument order (flyteorg#1888)
Browse files Browse the repository at this point in the history
* Fix _raw_execute for getting correct len of input value.

Signed-off-by: Chao-Heng Lee <[email protected]>

* Add test.

Signed-off-by: Chao-Heng Lee <[email protected]>

* also update with array_node_map_task.

Signed-off-by: Chao-Heng Lee <[email protected]>

* rename test.

Signed-off-by: Chao-Heng Lee <[email protected]>

---------

Signed-off-by: Chao-Heng Lee <[email protected]>
  • Loading branch information
chaohengstudent authored and ringohoffman committed Nov 24, 2023
1 parent f213654 commit 08872b5
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 12 deletions.
14 changes: 8 additions & 6 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,15 @@ def _raw_execute(self, **kwargs) -> Any:
outputs_expected = False
outputs = []

any_input_key = (
list(self.python_function_task.interface.inputs.keys())[0]
if self.python_function_task.interface.inputs.items() is not None
else None
)
mapped_input_value_len = 0
if self._run_task.interface.inputs.items():
for k in self._run_task.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
mapped_input_value_len = len(v)
break

for i in range(len(kwargs[any_input_key])):
for i in range(mapped_input_value_len):
single_instance_inputs = {}
for k in self.interface.inputs.keys():
v = kwargs[k]
Expand Down
14 changes: 8 additions & 6 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,15 @@ def _raw_execute(self, **kwargs) -> Any:
outputs_expected = False
outputs = []

any_input_key = (
list(self._run_task.interface.inputs.keys())[0]
if self._run_task.interface.inputs.items() is not None
else None
)
mapped_input_value_len = 0
if self._run_task.interface.inputs.items():
for k in self._run_task.interface.inputs.keys():
v = kwargs[k]
if isinstance(v, list) and k not in self.bound_inputs:
mapped_input_value_len = len(v)
break

for i in range(len(kwargs[any_input_key])):
for i in range(mapped_input_value_len):
single_instance_inputs = {}
for k in self.interface.inputs.keys():
v = kwargs[k]
Expand Down
24 changes: 24 additions & 0 deletions tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,27 @@ def many_outputs(a: int) -> (int, str):

with pytest.raises(ValueError):
_ = array_node_map_task(many_outputs)


def test_parameter_order():
@task()
def task1(a: int, b: float, c: str) -> str:
return f"{a} - {b} - {c}"

@task()
def task2(b: float, c: str, a: int) -> str:
return f"{a} - {b} - {c}"

@task()
def task3(c: str, a: int, b: float) -> str:
return f"{a} - {b} - {c}"

param_a = [1, 2, 3]
param_b = [0.1, 0.2, 0.3]
param_c = "c"

m1 = array_node_map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b)
m2 = array_node_map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b)
m3 = array_node_map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]
24 changes: 24 additions & 0 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,27 @@ def my_wf1() -> typing.List[type_t]:
return map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4])

my_wf1()


def test_map_task_parameter_order():
@task()
def task1(a: int, b: float, c: str) -> str:
return f"{a} - {b} - {c}"

@task()
def task2(b: float, c: str, a: int) -> str:
return f"{a} - {b} - {c}"

@task()
def task3(c: str, a: int, b: float) -> str:
return f"{a} - {b} - {c}"

param_a = [1, 2, 3]
param_b = [0.1, 0.2, 0.3]
param_c = "c"

m1 = map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b)
m2 = map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b)
m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b)

assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"]

0 comments on commit 08872b5

Please sign in to comment.