From c7f20e67c11161d0afe77443db749ef80156e2d7 Mon Sep 17 00:00:00 2001 From: "Fabio M. Graetz, Ph.D" Date: Wed, 3 May 2023 03:29:30 +0200 Subject: [PATCH] Enable torch elastic training (torchrun) (#1603) Signed-off-by: Ketan Umare Signed-off-by: Fabio Graetz This PR brings [torch elastic training (`torchrun`)](https://pytorch.org/docs/stable/elastic/run.html) to the pytorch plugin: ```python from flytekitplugins.kfpytorch import Elastic @task( task_config=Elastic( replicas=4, nproc_per_node=4, ... ), ... ) def train(...): ... ``` https://github.com/flyteorg/flyte/issues/3614 --- plugins/flytekit-kf-pytorch/README.md | 3 + .../flytekitplugins/kfpytorch/__init__.py | 2 +- .../flytekitplugins/kfpytorch/models.py | 23 -- .../flytekitplugins/kfpytorch/task.py | 201 ++++++++++++++- plugins/flytekit-kf-pytorch/requirements.txt | 231 +++++++++++++++--- plugins/flytekit-kf-pytorch/setup.py | 5 +- .../tests/test_elastic_task.py | 67 +++++ 7 files changed, 462 insertions(+), 70 deletions(-) delete mode 100644 plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py create mode 100644 plugins/flytekit-kf-pytorch/tests/test_elastic_task.py diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md index 280fe687b6..7de27502bf 100644 --- a/plugins/flytekit-kf-pytorch/README.md +++ b/plugins/flytekit-kf-pytorch/README.md @@ -2,6 +2,9 @@ This plugin uses the Kubeflow PyTorch Operator and provides an extremely simplified interface for executing distributed training using various PyTorch backends. +This plugin can execute torch elastic training, which is equivalent to run `torchrun`. Elastic training can be executed +in a single Pod (without requiring the PyTorch operator, see below) as well as in a distributed multi-node manner. + To install the plugin, run the following command: ```bash diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index aedb0b192f..771b300926 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -10,4 +10,4 @@ PyTorch """ -from .task import PyTorch +from .task import Elastic, PyTorch diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py deleted file mode 100644 index 517f4a9eb6..0000000000 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py +++ /dev/null @@ -1,23 +0,0 @@ -from flyteidl.plugins import pytorch_pb2 as _pytorch_task - -from flytekit.models import common as _common - - -class PyTorchJob(_common.FlyteIdlEntity): - def __init__(self, workers_count): - self._workers_count = workers_count - - @property - def workers_count(self): - return self._workers_count - - def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask( - workers=self.workers_count, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 4b0bde78b0..d79fa785f6 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -2,16 +2,20 @@ This Plugin adds the capability of running distributed pytorch training to Flyte using backend plugins, natively on Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. """ +import os from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Union +import cloudpickle +from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask, ElasticConfig from google.protobuf.json_format import MessageToDict +import flytekit from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings -from flytekit.extend import TaskPlugins +from flytekit.extend import IgnoreOutputs, TaskPlugins -from .models import PyTorchJob +TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." @dataclass @@ -29,6 +33,31 @@ class PyTorch(object): num_workers: int +@dataclass +class Elastic(object): + """ + Configuration for `torch elastic training `_. + + Use this to run single- or multi-node distributed pytorch elastic training on k8s. + + Single-node elastic training is executed in a k8s pod when `nnodes` is set to 1. + Multi-node training is executed otherwise using a `Pytorch Job `_. + + Args: + nnodes (Union[int, str]): Number of nodes, or the range of nodes in form :. + nproc_per_node (Union[int, str]): Number of workers per node. Supported values are [auto, cpu, gpu, int]. + start_method (str): Multiprocessing start method to use when creating workers. + monitor_interval (int): Interval, in seconds, to monitor the state of workers. + max_restarts (int): Maximum number of worker group restarts before failing. + """ + + nnodes: Union[int, str] = 1 + nproc_per_node: Union[int, str] = "auto" + start_method: str = "spawn" + monitor_interval: int = 5 + max_restarts: int = 0 + + class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): """ Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator) @@ -46,9 +75,171 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = PyTorchJob(workers_count=self.task_config.num_workers) - return MessageToDict(job.to_flyte_idl()) + job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers) + return MessageToDict(job) # Register the Pytorch Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask) + + +def spawn_helper(fn: bytes, kwargs) -> Any: + """Help to spawn worker processes. + + The purpose of this function is to 1) be pickleable so that it can be used with + the multiprocessing start method `spawn` and 2) to call a cloudpickle-serialized + function passed to it. This function itself doesn't have to be pickleable. Without + such a helper task functions, which are not pickleable, couldn't be used with the + start method `spawn`. + + Args: + fn (bytes): Cloudpickle-serialized target function to be executed in the worker process. + + Returns: + The return value of the received target function. + """ + fn = cloudpickle.loads(fn) + return_val = fn(**kwargs) + return return_val + + +class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]): + """ + Plugin for distributed training with torch elastic/torchrun (see + https://pytorch.org/docs/stable/elastic/run.html). + """ + + _ELASTIC_TASK_TYPE = "pytorch" + _ELASTIC_TASK_TYPE_STANDALONE = "python-task" + + def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): + task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE + + super(PytorchElasticFunctionTask, self).__init__( + task_config=task_config, + task_type=task_type, + task_function=task_function, + **kwargs, + ) + try: + from torch.distributed import run + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.task_config.nnodes)) + + """ + c10d is the backend recommended by torch elastic. + https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend + + For c10d, no backend server has to be deployed. + https://pytorch.org/docs/stable/elastic/run.html#deployment + Instead, the workers will use the master's address as the rendezvous point. + """ + self.rdzv_backend = "c10d" + + def _execute(self, **kwargs) -> Any: + """ + This helper method will be invoked to execute the task. + + + Returns: + The result of rank zero. + """ + try: + from torch.distributed import run + from torch.distributed.launcher.api import LaunchConfig, elastic_launch + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + + if isinstance(self.task_config.nproc_per_node, str): + nproc = run.determine_local_world_size(self.task_config.nproc_per_node) + else: + nproc = self.task_config.nproc_per_node + + config = LaunchConfig( + run_id=flytekit.current_context().execution_id.name, + min_nodes=self.min_nodes, + max_nodes=self.max_nodes, + nproc_per_node=nproc, + rdzv_backend=self.rdzv_backend, # rdzv settings + rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"), + max_restarts=self.task_config.max_restarts, + monitor_interval=self.task_config.monitor_interval, + start_method=self.task_config.start_method, + ) + + if self.task_config.start_method == "spawn": + """ + We use cloudpickle to serialize the non-pickleable task function. + The torch elastic launcher then launches the spawn_helper function (which is pickleable) + instead of the task function. This helper function, in the child-process, then deserializes + the task function, again with cloudpickle, and executes it. + """ + launcher_target_func = spawn_helper + + dumped_target_function = cloudpickle.dumps(self._task_function) + launcher_args = (dumped_target_function, kwargs) + elif self.task_config.start_method == "fork": + """ + The torch elastic launcher doesn't support passing kwargs to the target function, + only args. Flyte only works with kwargs. Thus, we create a closure which already has + the task kwargs bound. We tell the torch elastic launcher to start this function in + the child processes. + """ + + def fn_partial(): + """Closure of the task function with kwargs already bound.""" + return self._task_function(**kwargs) + + launcher_target_func = fn_partial + launcher_args = () + + else: + raise Exception("Bad start method") + + out = elastic_launch( + config=config, + entrypoint=launcher_target_func, + )(*launcher_args) + + # `out` is a dictionary of rank (not local rank) -> result + # Rank 0 returns the result of the task function + if 0 in out: + return out[0] + else: + raise IgnoreOutputs() + + def execute(self, **kwargs) -> Any: + """ + This method will be invoked to execute the task. + + Handles the exception scope for the `_execute` method. + """ + from flytekit.exceptions import scopes as exception_scopes + + return exception_scopes.user_entry_point(self._execute)(**kwargs) + + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + if self.task_config.nnodes == 1: + """ + Torch elastic distributed training is executed in a normal k8s pod so that this + works without the kubeflow train operator. + """ + return super().get_custom(settings) + else: + elastic_config = ElasticConfig( + rdzv_backend=self.rdzv_backend, + min_replicas=self.min_nodes, + max_replicas=self.max_nodes, + nproc_per_node=self.task_config.nproc_per_node, + max_restarts=self.task_config.max_restarts, + ) + job = DistributedPyTorchTrainingTask( + workers=self.max_nodes, + elastic_config=elastic_config, + ) + return MessageToDict(job) + + +# Register the PytorchElastic Plugin into the flytekit core plugin system +TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask) diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 96fa577a3e..ac3c2c174a 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,44 +1,91 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfpytorch # via -r requirements.in +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.1.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot -charset-normalizer==3.0.1 - # via requests +charset-normalizer==3.1.0 + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit cloudpickle==2.2.1 - # via flytekit + # via + # flytekit + # flytekitplugins-kfpytorch cookiecutter==2.1.1 # via flytekit -croniter==1.3.8 +croniter==1.3.14 # via flytekit -cryptography==39.0.1 +cryptography==40.0.2 # via + # adal + # azure-identity + # azure-storage-blob + # msal + # pyjwt # pyopenssl - # secretstorage dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via gcsfs deprecated==1.2.13 # via flytekit -diskcache==5.4.0 +diskcache==5.6.1 # via flytekit docker==6.0.1 # via flytekit @@ -46,13 +93,55 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.2.9 - # via flytekit -flytekit==1.2.7 +flyteidl==1.2.10 + # via + # flytekit + # flytekitplugins-kfpytorch +flytekit==1.2.9 # via flytekitplugins-kfpytorch -googleapis-common-protos==1.58.0 +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.4.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.4.0 + # via flytekit +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.17.3 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.8.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.0 # via # flyteidl + # flytekit + # google-api-core # grpcio-status grpcio==1.48.2 # via @@ -61,14 +150,16 @@ grpcio==1.48.2 grpcio-status==1.48.2 # via flytekit idna==3.4 - # via requests -importlib-metadata==6.0.0 + # via + # requests + # yarl +importlib-metadata==6.6.0 # via # click # flytekit # keyring -importlib-resources==5.12.0 - # via keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring jeepney==0.8.0 @@ -81,10 +172,14 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +jmespath==1.0.1 + # via botocore joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -98,43 +193,68 @@ marshmallow-jsonschema==0.13.0 # via flytekit more-itertools==9.0.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect natsort==8.2.0 # via flytekit -numpy==1.21.6 +numpy==1.24.3 # via # flytekit # pandas # pyarrow -packaging==23.0 +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 # via # docker # marshmallow pandas==1.3.5 # via flytekit +portalocker==2.7.0 + # via msal-extensions protobuf==3.20.3 # via # flyteidl - # flytekit + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==23.0.0 +pyjwt[crypto]==2.6.0 + # via + # adal + # msal +pyopenssl==23.1.1 # via flytekit python-dateutil==2.8.2 # via + # adal # arrow + # botocore # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit @@ -142,7 +262,7 @@ python-slugify==8.0.0 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.7.1 +pytz==2023.3 # via # flytekit # pandas @@ -150,17 +270,34 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.10.31 + # kubernetes + # responses +regex==2023.3.23 # via docker-image-py requests==2.28.2 # via + # adal + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses -responses==0.22.0 +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.1 # via flytekit -retry==0.9.2 +rsa==4.9 + # via google-auth +s3fs==2023.4.0 # via flytekit secretstorage==3.3.3 # via keyring @@ -168,21 +305,27 @@ singledispatchmethod==1.0 # via flytekit six==1.16.0 # via - # grpcio + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes # python-dateutil +smmap==5.0.0 + # via gitdb sortedcontainers==2.4.0 # via flytekit statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -toml==0.10.2 - # via responses -types-toml==0.10.8.5 +types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 # via - # arrow + # aioitertools + # azure-core + # azure-storage-blob # flytekit # importlib-metadata # responses @@ -191,19 +334,27 @@ typing-inspect==0.8.0 # via dataclasses-json urllib3==1.26.14 # via + # botocore # docker # flytekit + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker -wheel==0.38.4 + # via + # docker + # kubernetes +wheel==0.40.0 # via flytekit wrapt==1.14.1 # via + # aiobotocore # deprecated # flytekit -zipp==3.14.0 - # via - # importlib-metadata - # importlib-resources +yarl==1.9.2 + # via aiohttp +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index c45e409567..a207b9381e 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.3.0,<2.0.0"] +plugin_requires = ["cloudpickle", "flytekit>=1.1.0b0,<1.3.0,<2.0.0", "flyteidl>=1.2.10,<1.3.0"] __version__ = "0.0.0+develop" @@ -17,6 +17,9 @@ namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, + extras_require={ + "elastic": ["torch>=1.9.0"], + }, license="apache2", python_requires=">=3.7", classifiers=[ diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py new file mode 100644 index 0000000000..2ca6c9cc65 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -0,0 +1,67 @@ +import os +import typing +from dataclasses import dataclass + +import pytest +import torch +import torch.distributed as dist +from dataclasses_json import dataclass_json +from flytekitplugins.kfpytorch.task import Elastic + +from flytekit import task, workflow + + +@dataclass_json +@dataclass +class Config: + lr: float = 1e-5 + bs: int = 64 + name: str = "foo" + + +def dist_communicate() -> int: + """Communicate between distributed workers.""" + rank = torch.distributed.get_rank() + world_size = dist.get_world_size() + tensor = torch.tensor([5], dtype=torch.int64) + 2 * rank + world_size + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + return tensor.item() + + +def train(config: Config) -> typing.Tuple[str, Config, torch.nn.Module, int]: + """Mock training a model using torch-elastic for test purposes.""" + dist.init_process_group(backend="gloo") + + local_rank = os.environ["LOCAL_RANK"] + + out_model = torch.nn.Linear(1000, int(local_rank) + 1) + config.name = "elastic-test" + + distributed_result = dist_communicate() + + return f"result from local rank {local_rank}", config, out_model, distributed_result + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +def test_end_to_end(start_method: str) -> None: + """Test that the workflow with elastic task runs end to end.""" + world_size = 2 + + train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) + + @workflow + def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]: + return train_task(config=config) + + r, cfg, m, distributed_result = wf() + assert "result from local rank 0" in r + assert cfg.name == "elastic-test" + assert m.in_features == 1000 + assert m.out_features == 1 + """ + The distributed result is calculated by the workers of the elastic train + task by performing a `dist.all_reduce` operation. The correct result can + only be obtained if the distributed process group is initialized correctly. + """ + assert distributed_result == sum([5 + 2 * rank + world_size for rank in range(world_size)])