add rdzv_configs to kfpytorch elastic (#1751)
* add rdzv_configs to kfpytorch elastic
Signed-off-by: Nan2018 <[email protected]>

* address cr comments
Signed-off-by: Nan2018 <[email protected]>
Nan2018 authored and Fabio Grätz committed Aug 14, 2023
1 parent 902af4c commit ee17dc7
Showing 2 changed files with 21 additions and 0 deletions.
4 changes: 4 additions & 0 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -123,13 +123,16 @@ class Elastic(object):
         start_method (str): Multiprocessing start method to use when creating workers.
         monitor_interval (int): Interval, in seconds, to monitor the state of workers.
         max_restarts (int): Maximum number of worker group restarts before failing.
+        rdzv_configs (Dict[str, Any]): Additional rendezvous configs to pass to torch elastic, e.g. `{"timeout": 1200, "join_timeout": 900}`.
+            See `torch.distributed.launcher.api.LaunchConfig` and `torch.distributed.elastic.rendezvous.dynamic_rendezvous.create_handler`.
     """
 
     nnodes: Union[int, str] = 1
     nproc_per_node: int = 1
     start_method: str = "spawn"
     monitor_interval: int = 5
     max_restarts: int = 0
+    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
 
 
 class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
@@ -295,6 +298,7 @@ def _execute(self, **kwargs) -> Any:
             max_nodes=self.max_nodes,
             nproc_per_node=self.task_config.nproc_per_node,
             rdzv_backend=self.rdzv_backend,  # rdzv settings
+            rdzv_configs=self.task_config.rdzv_configs,
             rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"),
             max_restarts=self.task_config.max_restarts,
             monitor_interval=self.task_config.monitor_interval,
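To illustrate the new field from the user side, here is a minimal sketch of an Elastic task that sets rendezvous configs (the task name, node counts, and timeout values are hypothetical, not part of this commit):

from flytekit import task
from flytekitplugins.kfpytorch import Elastic

@task(
    task_config=Elastic(
        nnodes=2,
        nproc_per_node=4,
        # Forwarded verbatim to torch elastic's rendezvous handler.
        rdzv_configs={"timeout": 1200, "join_timeout": 900},
    )
)
def train() -> None:
    ...  # hypothetical training body
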
17 changes: 17 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
@@ -1,6 +1,7 @@
 import os
 import typing
 from dataclasses import dataclass
+from unittest import mock
 
 import pytest
 import torch
@@ -95,3 +96,19 @@ def test_task(n: int):
         return n + 1
 
     test_task(n=1)
+
+
+@pytest.mark.parametrize("start_method", ["spawn", "fork"])
+def test_rdzv_configs(start_method: str) -> None:
+    """Test that rendezvous configs are passed to torch distributed."""
+    from torch.distributed.launcher.api import LaunchConfig
+
+    rdzv_configs = {"join_timeout": 10}
+
+    @task(task_config=Elastic(nnodes=1, nproc_per_node=2, start_method=start_method, rdzv_configs=rdzv_configs))
+    def test_task():
+        pass
+
+    with mock.patch("torch.distributed.launcher.api.LaunchConfig", side_effect=LaunchConfig) as mock_launch_config:
+        test_task()
+        assert mock_launch_config.call_args[1]["rdzv_configs"] == rdzv_configs

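For reference, the LaunchConfig that _execute builds with the new field is roughly equivalent to the following sketch (values illustrative; the rdzv backend shown here is an assumption about the plugin's default, and the endpoint fallback comes from the diff above):

from torch.distributed.launcher.api import LaunchConfig

config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=2,
    rdzv_backend="c10d",  # assumption: the plugin's default rendezvous backend
    rdzv_endpoint="localhost:0",  # PET_RDZV_ENDPOINT fallback from the diff
    rdzv_configs={"join_timeout": 900},  # the new pass-through field
    max_restarts=0,
    monitor_interval=5,
    start_method="spawn",
)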