Skip to content

Commit

Permalink
Merge branch 'feature/mpi-plugin' of github.com:flyteorg/flytekit int…
Browse files Browse the repository at this point in the history
…o feature/mpi-plugin
  • Loading branch information
yindia committed Oct 11, 2021
2 parents 15b361b + 5938b6c commit 078c7c4
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
5 changes: 3 additions & 2 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from flyteidl.core import compiler_pb2 as _compiler
from flyteidl.core import literals_pb2 as _literals_pb2
from flyteidl.core import tasks_pb2 as _core_task
from flyteidl.plugins import mpi_pb2 as _mpi_task
from flyteidl.plugins import pytorch_pb2 as _pytorch_task
from flyteidl.plugins import spark_pb2 as _spark_task
from flyteidl.plugins import mpi_pb2 as _mpi_task
from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -1150,6 +1150,7 @@ def from_flyte_idl(cls, pb2_object):
chief_replicas_count=pb2_object.chief_replicas,
)


class MPIJob(_common.FlyteIdlEntity):
def __init__(self, num_workers, num_launcher_replicas, slots):
self._num_workers = num_workers
Expand Down Expand Up @@ -1179,4 +1180,4 @@ def from_flyte_idl(cls, pb2_object):
num_workers=pb2_object.num_workers,
num_launcher_replicas=pb2_object.num_launcher_replicas,
slots=pb2_object.slots,
)
)
23 changes: 9 additions & 14 deletions plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
Kubernetes. It leverages `TF Job <https://github.com/kubeflow/mpi-operator>`_ Plugin from kubeflow.
"""
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, List
from typing import Any, Callable, Dict, List

from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask, Resources
from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins
from flytekit.models import task as model

Expand All @@ -25,19 +25,11 @@ class MPIJob(object):
num_launcher_replicas: Number of launcher server replicas to use
slots: Number of slots per worker used in hostfile
per_replica_requests: [optional] lower-bound resources for each replica spawned for this job
(i.e. both for (main)master and workers). Default is set by platform-level configuration.
per_replica_limits: [optional] upper-bound resources for each replica spawned for this job. If not specified
the scheduled resource may not have all the resources
"""

slots: int
num_launcher_replicas: int = 1
num_workers: int = 1
per_replica_requests: Optional[Resources] = None
per_replica_limits: Optional[Resources] = None


class MPIFunctionTask(PythonFunctionTask[MPIJob]):
Expand Down Expand Up @@ -73,15 +65,18 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs):
task_config=task_config,
task_function=task_function,
task_type=self._MPI_JOB_TASK_TYPE,
**{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits}
**kwargs,
)

def get_command(self, settings: SerializationSettings) -> List[str]:
cmd = super().get_command(settings)
num_procs = self.task_config.num_workers * self.task_config.slots
mpi_cmd = self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] + ["python",
settings.entrypoint_settings.path,
"pyflyte-execute"] + cmd[1:]
mpi_cmd = (
self._MPI_BASE_COMMAND
+ ["-np", f"{num_procs}"]
+ ["python", settings.entrypoint_settings.path, "pyflyte-execute"]
+ cmd[1:]
)
# the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile
return mpi_cmd

Expand Down
9 changes: 7 additions & 2 deletions plugins/flytekit-kf-mpi/tests/test_mpi_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@


def test_mpi_task():
@task(task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1, per_replica_requests=Resources(cpu="1")), cache=True, cache_version="1")
@task(
task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1),
requests=Resources(cpu="1"),
cache=True,
cache_version="1",
)
def my_mpi_task(x: int, y: str) -> int:
return x

Expand All @@ -24,7 +29,7 @@ def my_mpi_task(x: int, y: str) -> int:
entrypoint_settings=EntrypointSettings(path="/etc/my-entrypoint", command="my-entrypoint"),
)

assert my_mpi_task.get_custom(settings) == {"workers": 10}
assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1}
assert my_mpi_task.resources.limits == Resources()
assert my_mpi_task.resources.requests == Resources(cpu="1")
assert my_mpi_task.task_type == "mpi"

0 comments on commit 078c7c4

Please sign in to comment.