Skip to content

Commit

Permalink
Implement horovod task in mpi plugin (#1575)
Browse files Browse the repository at this point in the history
* Add horovod task to mpi plugin

Signed-off-by: byhsu <byhsu@linkedin.com>

* Remove unused

Signed-off-by: byhsu <byhsu@linkedin.com>

* add test

Signed-off-by: byhsu <byhsu@linkedin.com>

* inherit from mpifunctiontask

Signed-off-by: byhsu <byhsu@linkedin.com>

* fix format

Signed-off-by: byhsu <byhsu@linkedin.com>

* fix format

Signed-off-by: byhsu <byhsu@linkedin.com>

---------

Signed-off-by: byhsu <byhsu@linkedin.com>
Co-authored-by: byhsu <byhsu@linkedin.com>
ByronHsu and ByronHsu authored Apr 19, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 4251751 commit 7f3b389
Showing 3 changed files with 82 additions and 2 deletions.
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py
Original file line number Diff line number Diff line change
@@ -10,4 +10,4 @@
MPIJob
"""

from .task import MPIJob
from .task import HorovodJob, MPIJob
57 changes: 57 additions & 0 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
Original file line number Diff line number Diff line change
@@ -133,5 +133,62 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
return MessageToDict(job.to_flyte_idl())


@dataclass
class HorovodJob(object):
slots: int
num_launcher_replicas: int = 1
num_workers: int = 1


class HorovodFunctionTask(MPIFunctionTask):
"""
For more info, check out https://github.com/horovod/horovod
"""

# Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod.
ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config"
discovery_script_path = "/etc/mpi/discover_hosts.sh"

def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs):

super().__init__(
task_config=task_config,
task_function=task_function,
**kwargs,
)

def get_command(self, settings: SerializationSettings) -> List[str]:
cmd = super().get_command(settings)
mpi_cmd = self._get_horovod_prefix() + cmd
return mpi_cmd

def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
config = super().get_config(settings)
return {**config, "worker_spec_command": self.ssh_command}

def _get_horovod_prefix(self) -> List[str]:
np = self.task_config.num_workers * self.task_config.slots
base_cmd = [
"horovodrun",
"-np",
f"{np}",
"--verbose",
"--log-level",
"INFO",
"--network-interface",
"eth0",
"--min-np",
f"{np}",
"--max-np",
f"{np}",
"--slots-per-host",
f"{self.task_config.slots}",
"--host-discovery-script",
self.discovery_script_path,
]
return base_cmd


# Register the MPI Plugin into the flytekit core plugin system
TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask)
TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask)
25 changes: 24 additions & 1 deletion plugins/flytekit-kf-mpi/tests/test_mpi_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel
from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel

from flytekit import Resources, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
@@ -41,3 +41,26 @@ def my_mpi_task(x: int, y: str) -> int:

assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
assert my_mpi_task.task_type == "mpi"


def test_horovod_task():
@task(
task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1),
)
def my_horovod_task():
...

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
project="project",
domain="domain",
version="version",
env={"FOO": "baz"},
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)
cmd = my_horovod_task.get_command(settings)
assert "horovodrun" in cmd
config = my_horovod_task.get_config(settings)
assert "/usr/sbin/sshd" in config["worker_spec_command"]
custom = my_horovod_task.get_custom(settings)
assert isinstance(custom, dict) is True

0 comments on commit 7f3b389

Please sign in to comment.