diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 176f48a4205..c0cecb6d769 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -11,10 +11,22 @@ from flytekit import PythonFunctionTask from flytekit.extend import SerializationSettings, TaskPlugins from flytekit.models import common as _common -from flytekit.models import model as _model +from flytekit.models import task as _task_model class MPIJobModel(_common.FlyteIdlEntity): + """It will define the model for MPI plugin + + Args: + num_workers: integer determining the number of worker replicas spawned in the cluster for this job + (in addition to 1 master). + + num_launcher_replicas: Number of launcher server replicas to use + + slots: Number of slots per worker used in hostfile + + """ + def __init__(self, num_workers, num_launcher_replicas, slots): self._num_workers = num_workers self._num_launcher_replicas = num_launcher_replicas @@ -70,8 +82,8 @@ class MPIJob(object): slots: int num_launcher_replicas: int = 1 num_workers: int = 1 - per_replica_requests: Optional[_model.Resources] = None - per_replica_limits: Optional[_model.Resources] = None + per_replica_requests: Optional[_task_model.Resources] = None + per_replica_limits: Optional[_task_model.Resources] = None class MPIFunctionTask(PythonFunctionTask[MPIJob]): @@ -107,7 +119,7 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): task_config=task_config, task_function=task_function, task_type=self._MPI_JOB_TASK_TYPE, - **{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits} + **{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits}, ) def get_command(self, settings: SerializationSettings) -> List[str]: @@ -116,7 +128,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: mpi_cmd = ( self._MPI_BASE_COMMAND + ["-np", f"{num_procs}"] - + ["python", settings.entrypoint_settings.path, "pyflyte-execute"] + + ["python", settings.entrypoint_settings.path] + cmd[1:] ) # the hostfile is set automatically by MPIOperator using env variable OMPI_MCA_orte_default_hostfile diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 03895b72e20..014237fc0f3 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,10 +1,22 @@ -from flytekitplugins.kfmpi.task import MPIJob +from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel from flytekit import Resources, task from flytekit.core.context_manager import EntrypointSettings from flytekit.extend import Image, ImageConfig, SerializationSettings +def test_mpi_model_task(): + job = MPIJobModel( + num_workers=1, + num_launcher_replicas=1, + slots=1, + ) + assert job.num_workers == 1 + assert job.num_launcher_replicas == 1 + assert job.slots == 1 + assert job.from_flyte_idl(job.to_flyte_idl()) + + def test_mpi_task(): @task( task_config=MPIJob(num_workers=10, num_launcher_replicas=10, slots=1), @@ -30,6 +42,4 @@ def my_mpi_task(x: int, y: str) -> int: ) assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} - assert my_mpi_task.resources.limits == Resources() - assert my_mpi_task.resources.requests == Resources(cpu="1") assert my_mpi_task.task_type == "mpi" diff --git a/setup.py b/setup.py index 4bbfeb3933c..ea84a5dcf5f 100644 --- a/setup.py +++ b/setup.py @@ -87,7 +87,6 @@ "wrapt>=1.0.0,<2.0.0", "retry==0.9.2", "dataclasses-json>=0.5.2", - "jsonschema==3.2.0", "marshmallow-jsonschema>=0.12.0", "natsort>=7.0.1", "dirhash>=0.2.1",