Skip to content

Commit

Permalink
Set default values to map task template (#841)
Browse files Browse the repository at this point in the history
* Set sane defaults in map task templates

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove unused method

Signed-off-by: Eduardo Apolinario <[email protected]>

* Put ArrayJob.from_dict back

Signed-off-by: Eduardo Apolinario <[email protected]>

* Define parallelism=0 as unbounded

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove special case to handle 0

Signed-off-by: Eduardo Apolinario <[email protected]>

Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Feb 8, 2022
1 parent 9dcf320 commit 64add74
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
4 changes: 2 additions & 2 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _raw_execute(self, **kwargs) -> Any:
return outputs


def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_success_ratio: float = None, **kwargs):
def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs):
"""
Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of
any individual :py:class:`flytekit.PythonFunctionTask`.
Expand All @@ -231,7 +231,7 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_suc
:param task_function: This argument is implicitly passed and represents the repeatable function
:param concurrency: If specified, this limits the number of mapped tasks that can run in parallel to the given batch
size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
all inputs are processed.
all inputs are processed. If left unspecified (the default of 0), concurrency is unbounded.
:param min_success_ratio: If specified, this determines the minimum fraction of total jobs that must complete
successfully before this task is terminated and marked successful.
Expand Down
31 changes: 23 additions & 8 deletions flytekit/models/array_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,21 @@ def to_dict(self):
"""
:rtype: dict[T, Text]
"""
return _json_format.MessageToDict(
_array_job.ArrayJob(
array_job = None
if self.min_successes is not None:
array_job = _array_job.ArrayJob(
parallelism=self.parallelism,
size=self.size,
min_successes=self.min_successes,
)
)
elif self.min_success_ratio is not None:
array_job = _array_job.ArrayJob(
parallelism=self.parallelism,
size=self.size,
min_success_ratio=self.min_success_ratio,
)

return _json_format.MessageToDict(array_job)

@classmethod
def from_dict(cls, idl_dict):
Expand All @@ -86,8 +94,15 @@ def from_dict(cls, idl_dict):
"""
pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob())

return cls(
parallelism=pb2_object.parallelism,
size=pb2_object.size,
min_successes=pb2_object.min_successes,
)
if pb2_object.HasField("min_successes"):
return cls(
parallelism=pb2_object.parallelism,
size=pb2_object.size,
min_successes=pb2_object.min_successes,
)
else:
return cls(
parallelism=pb2_object.parallelism,
size=pb2_object.size,
min_success_ratio=pb2_object.min_success_ratio,
)
50 changes: 32 additions & 18 deletions tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
from flytekit.tools.translator import get_serializable


@pytest.fixture
def serialization_settings():
default_img = Image(name="default", fqn="test", tag="tag")
return context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


@task
def t1(a: int) -> str:
b = a + 2
Expand Down Expand Up @@ -54,18 +66,12 @@ def test_map_task_types():
_ = map_task(t1, metadata=TaskMetadata(retries=1))(a=["invalid", "args"])


def test_serialization():
def test_serialization(serialization_settings):
maptask = map_task(t1, metadata=TaskMetadata(retries=1))
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)
task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)

# By default all map_task tasks will have their custom fields set.
assert task_spec.template.custom["minSuccessRatio"] == 1.0
assert task_spec.template.type == "container_array"
assert task_spec.template.task_type_version == 1
assert task_spec.template.container.args == [
Expand All @@ -90,7 +96,23 @@ def test_serialization():
]


def test_serialization_workflow_def():
@pytest.mark.parametrize(
"custom_fields_dict, expected_custom_fields",
[
({}, {"minSuccessRatio": 1.0}),
({"concurrency": 99}, {"parallelism": "99", "minSuccessRatio": 1.0}),
({"min_success_ratio": 0.271828}, {"minSuccessRatio": 0.271828}),
({"concurrency": 42, "min_success_ratio": 0.31415}, {"parallelism": "42", "minSuccessRatio": 0.31415}),
],
)
def test_serialization_of_custom_fields(custom_fields_dict, expected_custom_fields, serialization_settings):
maptask = map_task(t1, **custom_fields_dict)
task_spec = get_serializable(OrderedDict(), serialization_settings, maptask)

assert task_spec.template.custom == expected_custom_fields


def test_serialization_workflow_def(serialization_settings):
@task
def complex_task(a: int) -> str:
b = a + 2
Expand All @@ -106,14 +128,6 @@ def w1(a: typing.List[int]) -> typing.List[str]:
def w2(a: typing.List[int]) -> typing.List[str]:
return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a)

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = context_manager.SerializationSettings(
project="project",
domain="domain",
version="version",
env=None,
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)
serialized_control_plane_entities = OrderedDict()
wf1_spec = get_serializable(serialized_control_plane_entities, serialization_settings, w1)
assert wf1_spec.template is not None
Expand Down

0 comments on commit 64add74

Please sign in to comment.