diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 176f48a4205..99247e4af31 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -11,10 +11,22 @@ from flytekit import PythonFunctionTask from flytekit.extend import SerializationSettings, TaskPlugins from flytekit.models import common as _common -from flytekit.models import model as _model +from flytekit.models import task as _task_model class MPIJobModel(_common.FlyteIdlEntity): + """It will define the model for MPI plugin + + Args: + num_workers: integer determining the number of worker replicas spawned in the cluster for this job + (in addition to 1 master). + + num_launcher_replicas: Number of launcher server replicas to use + + slots: Number of slots per worker used in hostfile + + """ + def __init__(self, num_workers, num_launcher_replicas, slots): self._num_workers = num_workers self._num_launcher_replicas = num_launcher_replicas @@ -70,8 +82,8 @@ class MPIJob(object): slots: int num_launcher_replicas: int = 1 num_workers: int = 1 - per_replica_requests: Optional[_model.Resources] = None - per_replica_limits: Optional[_model.Resources] = None + per_replica_requests: Optional[_task_model.Resources] = None + per_replica_limits: Optional[_task_model.Resources] = None class MPIFunctionTask(PythonFunctionTask[MPIJob]): @@ -107,7 +119,7 @@ def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): task_config=task_config, task_function=task_function, task_type=self._MPI_JOB_TASK_TYPE, - **{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits} + **{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits}, ) def get_command(self, settings: SerializationSettings) -> List[str]: diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 03895b72e20..2b2c3772d79 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -30,6 +30,4 @@ def my_mpi_task(x: int, y: str) -> int: ) assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} - assert my_mpi_task.resources.limits == Resources() - assert my_mpi_task.resources.requests == Resources(cpu="1") assert my_mpi_task.task_type == "mpi"