forked from flyteorg/flytekit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added mpi plugin Signed-off-by: Yuvraj <[email protected]> Co-authored-by: Ketan Umare <[email protected]> Signed-off-by: Robert Everson <[email protected]>
- Loading branch information
1 parent
9f5aa51
commit 8f6de3b
Showing
12 changed files
with
251 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Flytekit Kubeflow MPI Plugin | ||
|
||
This plugin uses the Kubeflow MPI Operator and provides an extremely simplified interface for executing distributed training. | ||
|
||
To install the plugin, run the following command: | ||
|
||
```bash | ||
pip install flytekitplugins-kfmpi | ||
``` | ||
|
||
_Example coming soon!_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .task import MPIJob |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
""" | ||
This Plugin adds the capability of running distributed MPI training to Flyte using backend plugins, natively on | ||
Kubernetes. It leverages `MPI Job <https://github.com/kubeflow/mpi-operator>`_ Plugin from kubeflow. | ||
""" | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Dict, List | ||
|
||
from flyteidl.plugins import mpi_pb2 as _mpi_task | ||
from google.protobuf.json_format import MessageToDict | ||
|
||
from flytekit import PythonFunctionTask | ||
from flytekit.extend import SerializationSettings, TaskPlugins | ||
from flytekit.models import common as _common | ||
|
||
|
||
class MPIJobModel(_common.FlyteIdlEntity): | ||
"""Model definition for MPI the plugin | ||
Args: | ||
num_workers: integer determining the number of worker replicas spawned in the cluster for this job | ||
(in addition to 1 master). | ||
num_launcher_replicas: Number of launcher server replicas to use | ||
slots: Number of slots per worker used in hostfile | ||
.. note:: | ||
Please use resources=Resources(cpu="1"...) to specify per worker resource | ||
""" | ||
|
||
def __init__(self, num_workers, num_launcher_replicas, slots): | ||
self._num_workers = num_workers | ||
self._num_launcher_replicas = num_launcher_replicas | ||
self._slots = slots | ||
|
||
@property | ||
def num_workers(self): | ||
return self._num_workers | ||
|
||
@property | ||
def num_launcher_replicas(self): | ||
return self._num_launcher_replicas | ||
|
||
@property | ||
def slots(self): | ||
return self._slots | ||
|
||
def to_flyte_idl(self): | ||
return _mpi_task.DistributedMPITrainingTask( | ||
num_workers=self.num_workers, num_launcher_replicas=self.num_launcher_replicas, slots=self.slots | ||
) | ||
|
||
@classmethod | ||
def from_flyte_idl(cls, pb2_object): | ||
return cls( | ||
num_workers=pb2_object.num_workers, | ||
num_launcher_replicas=pb2_object.num_launcher_replicas, | ||
slots=pb2_object.slots, | ||
) | ||
|
||
|
||
@dataclass | ||
class MPIJob(object): | ||
""" | ||
Configuration for an executable `MPI Job <https://github.com/kubeflow/mpi-operator>`_. Use this | ||
to run distributed training on k8s with MPI | ||
Args: | ||
num_workers: integer determining the number of worker replicas spawned in the cluster for this job | ||
(in addition to 1 master). | ||
num_launcher_replicas: Number of launcher server replicas to use | ||
slots: Number of slots per worker used in hostfile | ||
""" | ||
|
||
slots: int | ||
num_launcher_replicas: int = 1 | ||
num_workers: int = 1 | ||
|
||
|
||
class MPIFunctionTask(PythonFunctionTask[MPIJob]): | ||
""" | ||
Plugin that submits a MPIJob (see https://github.com/kubeflow/mpi-operator) | ||
defined by the code within the _task_function to k8s cluster. | ||
""" | ||
|
||
_MPI_JOB_TASK_TYPE = "mpi" | ||
_MPI_BASE_COMMAND = [ | ||
"mpirun", | ||
"--allow-run-as-root", | ||
"-bind-to", | ||
"none", | ||
"-map-by", | ||
"slot", | ||
"-x", | ||
"LD_LIBRARY_PATH", | ||
"-x", | ||
"PATH", | ||
"-x", | ||
"NCCL_DEBUG=INFO", | ||
"-mca", | ||
"pml", | ||
"ob1", | ||
"-mca", | ||
"btl", | ||
"^openib", | ||
] | ||
|
||
def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): | ||
super().__init__( | ||
task_config=task_config, | ||
task_function=task_function, | ||
task_type=self._MPI_JOB_TASK_TYPE, | ||
**kwargs, | ||
) | ||
|
||
def get_command(self, settings: SerializationSettings) -> List[str]: | ||
cmd = super().get_command(settings) | ||
num_procs = self.task_config.num_workers * self.task_config.slots | ||
mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python", settings.entrypoint_settings.path] + cmd | ||
# the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile | ||
return mpi_cmd | ||
|
||
def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: | ||
job = MPIJobModel( | ||
num_workers=self.task_config.num_workers, | ||
num_launcher_replicas=self.task_config.num_launcher_replicas, | ||
slots=self.task_config.slots, | ||
) | ||
return MessageToDict(job.to_flyte_idl()) | ||
|
||
|
||
# Register the MPI Plugin into the flytekit core plugin system | ||
TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
. | ||
-e file:.#egg=flytekitplugins-kfmpi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# | ||
# This file is autogenerated by pip-compile with python 3.8 | ||
# To update, run: | ||
# | ||
# pip-compile requirements.in | ||
# | ||
-e file:.#egg=flytekitplugins-kfmpi | ||
# via -r requirements.in |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from setuptools import setup | ||
|
||
PLUGIN_NAME = "kfmpi" | ||
|
||
microlib_name = f"flytekitplugins-{PLUGIN_NAME}" | ||
|
||
plugin_requires = ["flytekit>=0.16.0b0,<1.0.0", "flyteidl>=0.21.4"] | ||
|
||
__version__ = "0.0.0+develop" | ||
|
||
setup( | ||
name=microlib_name, | ||
version=__version__, | ||
author="flyteorg", | ||
author_email="[email protected]", | ||
description="K8s based MPI plugin for flytekit", | ||
namespace_packages=["flytekitplugins"], | ||
packages=[f"flytekitplugins.{PLUGIN_NAME}"], | ||
install_requires=plugin_requires, | ||
license="apache2", | ||
python_requires=">=3.6", | ||
classifiers=[ | ||
"Intended Audience :: Science/Research", | ||
"Intended Audience :: Developers", | ||
"License :: OSI Approved :: Apache Software License", | ||
"Programming Language :: Python :: 3.7", | ||
"Programming Language :: Python :: 3.8", | ||
"Topic :: Scientific/Engineering", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Software Development", | ||
"Topic :: Software Development :: Libraries", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
], | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel | ||
|
||
from flytekit import Resources, task | ||
from flytekit.core.context_manager import EntrypointSettings | ||
from flytekit.extend import Image, ImageConfig, SerializationSettings | ||
|
||
|
||
def test_mpi_model_task(): | ||
job = MPIJobModel( | ||
num_workers=1, | ||
num_launcher_replicas=1, | ||
slots=1, | ||
) | ||
assert job.num_workers == 1 | ||
assert job.num_launcher_replicas == 1 | ||
assert job.slots == 1 | ||
assert job.from_flyte_idl(job.to_flyte_idl()) | ||
|
||
|
||
def test_mpi_task(): | ||
@task( | ||
task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1), | ||
requests=Resources(cpu="1"), | ||
cache=True, | ||
cache_version="1", | ||
) | ||
def my_mpi_task(x: int, y: str) -> int: | ||
return x | ||
|
||
assert my_mpi_task(x=10, y="hello") == 10 | ||
|
||
assert my_mpi_task.task_config is not None | ||
|
||
default_img = Image(name="default", fqn="test", tag="tag") | ||
settings = SerializationSettings( | ||
project="project", | ||
domain="domain", | ||
version="version", | ||
env={"FOO": "baz"}, | ||
image_config=ImageConfig(default_image=default_img, images=[default_img]), | ||
entrypoint_settings=EntrypointSettings(path="/etc/my-entrypoint", command="my-entrypoint"), | ||
) | ||
|
||
assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} | ||
assert my_mpi_task.task_type == "mpi" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters