diff --git a/plugins/flytekit-kf-pytorch/README.md b/plugins/flytekit-kf-pytorch/README.md index 280fe687b6..7de27502bf 100644 --- a/plugins/flytekit-kf-pytorch/README.md +++ b/plugins/flytekit-kf-pytorch/README.md @@ -2,6 +2,9 @@ This plugin uses the Kubeflow PyTorch Operator and provides an extremely simplified interface for executing distributed training using various PyTorch backends. +This plugin can execute torch elastic training, which is equivalent to run `torchrun`. Elastic training can be executed +in a single Pod (without requiring the PyTorch operator, see below) as well as in a distributed multi-node manner. + To install the plugin, run the following command: ```bash diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py index aedb0b192f..771b300926 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/__init__.py @@ -10,4 +10,4 @@ PyTorch """ -from .task import PyTorch +from .task import Elastic, PyTorch diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py deleted file mode 100644 index 517f4a9eb6..0000000000 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/models.py +++ /dev/null @@ -1,23 +0,0 @@ -from flyteidl.plugins import pytorch_pb2 as _pytorch_task - -from flytekit.models import common as _common - - -class PyTorchJob(_common.FlyteIdlEntity): - def __init__(self, workers_count): - self._workers_count = workers_count - - @property - def workers_count(self): - return self._workers_count - - def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask( - workers=self.workers_count, - ) - - @classmethod - def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 4b0bde78b0..d79fa785f6 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -2,16 +2,20 @@ This Plugin adds the capability of running distributed pytorch training to Flyte using backend plugins, natively on Kubernetes. It leverages `Pytorch Job `_ Plugin from kubeflow. """ +import os from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Union +import cloudpickle +from flyteidl.plugins.pytorch_pb2 import DistributedPyTorchTrainingTask, ElasticConfig from google.protobuf.json_format import MessageToDict +import flytekit from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings -from flytekit.extend import TaskPlugins +from flytekit.extend import IgnoreOutputs, TaskPlugins -from .models import PyTorchJob +TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`." @dataclass @@ -29,6 +33,31 @@ class PyTorch(object): num_workers: int +@dataclass +class Elastic(object): + """ + Configuration for `torch elastic training `_. + + Use this to run single- or multi-node distributed pytorch elastic training on k8s. + + Single-node elastic training is executed in a k8s pod when `nnodes` is set to 1. + Multi-node training is executed otherwise using a `Pytorch Job `_. + + Args: + nnodes (Union[int, str]): Number of nodes, or the range of nodes in form :. + nproc_per_node (Union[int, str]): Number of workers per node. Supported values are [auto, cpu, gpu, int]. + start_method (str): Multiprocessing start method to use when creating workers. + monitor_interval (int): Interval, in seconds, to monitor the state of workers. + max_restarts (int): Maximum number of worker group restarts before failing. + """ + + nnodes: Union[int, str] = 1 + nproc_per_node: Union[int, str] = "auto" + start_method: str = "spawn" + monitor_interval: int = 5 + max_restarts: int = 0 + + class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): """ Plugin that submits a PyTorchJob (see https://github.com/kubeflow/pytorch-operator) @@ -46,9 +75,171 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = PyTorchJob(workers_count=self.task_config.num_workers) - return MessageToDict(job.to_flyte_idl()) + job = DistributedPyTorchTrainingTask(workers=self.task_config.num_workers) + return MessageToDict(job) # Register the Pytorch Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask) + + +def spawn_helper(fn: bytes, kwargs) -> Any: + """Help to spawn worker processes. + + The purpose of this function is to 1) be pickleable so that it can be used with + the multiprocessing start method `spawn` and 2) to call a cloudpickle-serialized + function passed to it. This function itself doesn't have to be pickleable. Without + such a helper task functions, which are not pickleable, couldn't be used with the + start method `spawn`. + + Args: + fn (bytes): Cloudpickle-serialized target function to be executed in the worker process. + + Returns: + The return value of the received target function. + """ + fn = cloudpickle.loads(fn) + return_val = fn(**kwargs) + return return_val + + +class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]): + """ + Plugin for distributed training with torch elastic/torchrun (see + https://pytorch.org/docs/stable/elastic/run.html). + """ + + _ELASTIC_TASK_TYPE = "pytorch" + _ELASTIC_TASK_TYPE_STANDALONE = "python-task" + + def __init__(self, task_config: Elastic, task_function: Callable, **kwargs): + task_type = self._ELASTIC_TASK_TYPE_STANDALONE if task_config.nnodes == 1 else self._ELASTIC_TASK_TYPE + + super(PytorchElasticFunctionTask, self).__init__( + task_config=task_config, + task_type=task_type, + task_function=task_function, + **kwargs, + ) + try: + from torch.distributed import run + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.task_config.nnodes)) + + """ + c10d is the backend recommended by torch elastic. + https://pytorch.org/docs/stable/elastic/run.html#note-on-rendezvous-backend + + For c10d, no backend server has to be deployed. + https://pytorch.org/docs/stable/elastic/run.html#deployment + Instead, the workers will use the master's address as the rendezvous point. + """ + self.rdzv_backend = "c10d" + + def _execute(self, **kwargs) -> Any: + """ + This helper method will be invoked to execute the task. + + + Returns: + The result of rank zero. + """ + try: + from torch.distributed import run + from torch.distributed.launcher.api import LaunchConfig, elastic_launch + except ImportError: + raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) + + if isinstance(self.task_config.nproc_per_node, str): + nproc = run.determine_local_world_size(self.task_config.nproc_per_node) + else: + nproc = self.task_config.nproc_per_node + + config = LaunchConfig( + run_id=flytekit.current_context().execution_id.name, + min_nodes=self.min_nodes, + max_nodes=self.max_nodes, + nproc_per_node=nproc, + rdzv_backend=self.rdzv_backend, # rdzv settings + rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"), + max_restarts=self.task_config.max_restarts, + monitor_interval=self.task_config.monitor_interval, + start_method=self.task_config.start_method, + ) + + if self.task_config.start_method == "spawn": + """ + We use cloudpickle to serialize the non-pickleable task function. + The torch elastic launcher then launches the spawn_helper function (which is pickleable) + instead of the task function. This helper function, in the child-process, then deserializes + the task function, again with cloudpickle, and executes it. + """ + launcher_target_func = spawn_helper + + dumped_target_function = cloudpickle.dumps(self._task_function) + launcher_args = (dumped_target_function, kwargs) + elif self.task_config.start_method == "fork": + """ + The torch elastic launcher doesn't support passing kwargs to the target function, + only args. Flyte only works with kwargs. Thus, we create a closure which already has + the task kwargs bound. We tell the torch elastic launcher to start this function in + the child processes. + """ + + def fn_partial(): + """Closure of the task function with kwargs already bound.""" + return self._task_function(**kwargs) + + launcher_target_func = fn_partial + launcher_args = () + + else: + raise Exception("Bad start method") + + out = elastic_launch( + config=config, + entrypoint=launcher_target_func, + )(*launcher_args) + + # `out` is a dictionary of rank (not local rank) -> result + # Rank 0 returns the result of the task function + if 0 in out: + return out[0] + else: + raise IgnoreOutputs() + + def execute(self, **kwargs) -> Any: + """ + This method will be invoked to execute the task. + + Handles the exception scope for the `_execute` method. + """ + from flytekit.exceptions import scopes as exception_scopes + + return exception_scopes.user_entry_point(self._execute)(**kwargs) + + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + if self.task_config.nnodes == 1: + """ + Torch elastic distributed training is executed in a normal k8s pod so that this + works without the kubeflow train operator. + """ + return super().get_custom(settings) + else: + elastic_config = ElasticConfig( + rdzv_backend=self.rdzv_backend, + min_replicas=self.min_nodes, + max_replicas=self.max_nodes, + nproc_per_node=self.task_config.nproc_per_node, + max_restarts=self.task_config.max_restarts, + ) + job = DistributedPyTorchTrainingTask( + workers=self.max_nodes, + elastic_config=elastic_config, + ) + return MessageToDict(job) + + +# Register the PytorchElastic Plugin into the flytekit core plugin system +TaskPlugins.register_pythontask_plugin(Elastic, PytorchElasticFunctionTask) diff --git a/plugins/flytekit-kf-pytorch/requirements.txt b/plugins/flytekit-kf-pytorch/requirements.txt index 882b8864a9..46973017ac 100644 --- a/plugins/flytekit-kf-pytorch/requirements.txt +++ b/plugins/flytekit-kf-pytorch/requirements.txt @@ -1,42 +1,91 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements.in # -e file:.#egg=flytekitplugins-kfpytorch # via -r requirements.in +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.1.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth certifi==2022.12.7 - # via requests + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography + # via + # azure-datalake-store + # cryptography chardet==5.1.0 # via binaryornot charset-normalizer==3.1.0 - # via requests + # via + # aiohttp + # requests click==8.1.3 # via # cookiecutter # flytekit cloudpickle==2.2.1 - # via flytekit + # via + # flytekit + # flytekitplugins-kfpytorch cookiecutter==2.1.1 # via flytekit -croniter==1.3.8 +croniter==1.3.14 # via flytekit -cryptography==39.0.2 - # via pyopenssl +cryptography==40.0.2 + # via + # adal + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 - # via retry + # via gcsfs deprecated==1.2.13 # via flytekit -diskcache==5.4.0 +diskcache==5.6.1 # via flytekit docker==6.0.1 # via flytekit @@ -44,31 +93,72 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.14 - # via flytekit -flytekit==1.3.1 +flyteidl==1.3.19 + # via + # flytekit + # flytekitplugins-kfpytorch +flytekit==1.5.0 # via flytekitplugins-kfpytorch +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.4.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.4.0 + # via flytekit gitdb==4.0.10 # via gitpython gitpython==3.1.31 # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.17.3 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via google-cloud-storage +google-cloud-storage==2.8.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage googleapis-common-protos==1.59.0 # via # flyteidl # flytekit + # google-api-core # grpcio-status -grpcio==1.51.3 +grpcio==1.54.0 # via # flytekit # grpcio-status -grpcio-status==1.51.3 +grpcio-status==1.54.0 # via flytekit idna==3.4 - # via requests -importlib-metadata==6.1.0 + # via + # requests + # yarl +importlib-metadata==6.6.0 # via # flytekit # keyring +isodate==0.6.1 + # via azure-storage-blob jaraco-classes==3.2.3 # via keyring jinja2==3.1.2 @@ -77,10 +167,14 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +jmespath==1.0.1 + # via botocore joblib==1.2.0 # via flytekit keyring==23.13.1 # via flytekit +kubernetes==26.1.0 + # via flytekit markupsafe==2.1.2 # via jinja2 marshmallow==3.19.0 @@ -94,42 +188,68 @@ marshmallow-jsonschema==0.13.0 # via flytekit more-itertools==9.1.0 # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 # via typing-inspect natsort==8.3.1 # via flytekit -numpy==1.23.5 +numpy==1.24.3 # via # flytekit # pandas # pyarrow -packaging==23.0 +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 # via # docker # marshmallow pandas==1.5.3 # via flytekit -protobuf==4.22.1 +portalocker==2.7.0 + # via msal-extensions +protobuf==4.22.3 # via # flyteidl + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==23.0.0 +pyjwt[crypto]==2.6.0 + # via + # adal + # msal +pyopenssl==23.1.1 # via flytekit python-dateutil==2.8.2 # via + # adal # arrow + # botocore # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.7 # via flytekit @@ -137,7 +257,7 @@ python-slugify==8.0.1 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.7.1 +pytz==2023.3 # via # flytekit # pandas @@ -145,21 +265,43 @@ pyyaml==6.0 # via # cookiecutter # flytekit + # kubernetes # responses -regex==2022.10.31 +regex==2023.3.23 # via docker-image-py requests==2.28.2 # via + # adal + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes responses==0.23.1 # via flytekit -retry==0.9.2 +rsa==4.9 + # via google-auth +s3fs==2023.4.0 # via flytekit six==1.16.0 - # via python-dateutil + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil smmap==5.0.0 # via gitdb sortedcontainers==2.4.0 @@ -168,27 +310,40 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -types-pyyaml==6.0.12.8 +types-pyyaml==6.0.12.9 # via responses typing-extensions==4.5.0 # via + # aioitertools + # azure-core + # azure-storage-blob # flytekit # typing-inspect typing-inspect==0.8.0 # via dataclasses-json urllib3==1.26.15 # via + # botocore # docker # flytekit + # kubernetes # requests # responses websocket-client==1.5.1 - # via docker + # via + # docker + # kubernetes wheel==0.40.0 # via flytekit wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit +yarl==1.9.2 + # via aiohttp zipp==3.15.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index 5eb1d4c43f..23543ac7bc 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] +plugin_requires = ["cloudpickle", "flytekit>=1.3.0,<2.0.0", "flyteidl>=1.3.19"] __version__ = "0.0.0+develop" @@ -17,6 +17,9 @@ namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, + extras_require={ + "elastic": ["torch>=1.9.0"], + }, license="apache2", python_requires=">=3.8", classifiers=[ diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py new file mode 100644 index 0000000000..2ca6c9cc65 --- /dev/null +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -0,0 +1,67 @@ +import os +import typing +from dataclasses import dataclass + +import pytest +import torch +import torch.distributed as dist +from dataclasses_json import dataclass_json +from flytekitplugins.kfpytorch.task import Elastic + +from flytekit import task, workflow + + +@dataclass_json +@dataclass +class Config: + lr: float = 1e-5 + bs: int = 64 + name: str = "foo" + + +def dist_communicate() -> int: + """Communicate between distributed workers.""" + rank = torch.distributed.get_rank() + world_size = dist.get_world_size() + tensor = torch.tensor([5], dtype=torch.int64) + 2 * rank + world_size + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + return tensor.item() + + +def train(config: Config) -> typing.Tuple[str, Config, torch.nn.Module, int]: + """Mock training a model using torch-elastic for test purposes.""" + dist.init_process_group(backend="gloo") + + local_rank = os.environ["LOCAL_RANK"] + + out_model = torch.nn.Linear(1000, int(local_rank) + 1) + config.name = "elastic-test" + + distributed_result = dist_communicate() + + return f"result from local rank {local_rank}", config, out_model, distributed_result + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +def test_end_to_end(start_method: str) -> None: + """Test that the workflow with elastic task runs end to end.""" + world_size = 2 + + train_task = task(train, task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method)) + + @workflow + def wf(config: Config = Config()) -> typing.Tuple[str, Config, torch.nn.Module, int]: + return train_task(config=config) + + r, cfg, m, distributed_result = wf() + assert "result from local rank 0" in r + assert cfg.name == "elastic-test" + assert m.in_features == 1000 + assert m.out_features == 1 + """ + The distributed result is calculated by the workers of the elastic train + task by performing a `dist.all_reduce` operation. The correct result can + only be obtained if the distributed process group is initialized correctly. + """ + assert distributed_result == sum([5 + 2 * rank + world_size for rank in range(world_size)])