Refactor resource limit in pytorch/tensorflow plugin (#724)
* Refactor resource limit in pytorch/tensorflow plugin

Signed-off-by: Yuvraj <[email protected]>
yindia authored Oct 29, 2021
1 parent 87131a0 commit 19d977e
Showing 4 changed files with 14 additions and 24 deletions.
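
The user-facing effect of the refactor, as the test changes below show: the per_replica_requests / per_replica_limits fields disappear from the plugin configs (PyTorch, TfJob), and resources now travel through the standard requests / limits arguments of @task instead of plugin-specific fields. A minimal sketch of the new usage for the PyTorch plugin; the import path for PyTorch and the limits values are assumptions for illustration, the rest mirrors the updated test:

from flytekit import Resources, task
from flytekitplugins.kfpytorch import PyTorch  # assumed import path for the plugin config

# Old style (removed by this commit):
#   @task(task_config=PyTorch(num_workers=10, per_replica_requests=Resources(cpu="1")), ...)
# New style: resources are declared on the task itself, like on any other Flyte task.
@task(
    task_config=PyTorch(num_workers=10),
    cache=True,
    cache_version="1",
    requests=Resources(cpu="1"),           # lower-bound resources
    limits=Resources(cpu="2", mem="1Gi"),  # assumed upper bound, optional
)
def my_pytorch_task(x: int, y: str) -> int:
    return x
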
13 changes: 3 additions & 10 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -3,11 +3,11 @@
Kubernetes. It leverages `Pytorch Job <https://github.com/kubeflow/pytorch-operator>`_ Plugin from kubeflow.
"""
from dataclasses import dataclass
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict

from google.protobuf.json_format import MessageToDict

-from flytekit import PythonFunctionTask, Resources
+from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins
from flytekit.models.plugins import task as _task_model

@@ -22,16 +22,9 @@ class PyTorch(object):
num_workers: integer determining the number of worker replicas spawned in the cluster for this job
(in addition to 1 master).
-per_replica_requests: [optional] lower-bound resources for each replica spawned for this job
-(i.e. both for (main)master and workers). Default is set by platform-level configuration.
-per_replica_limits: [optional] upper-bound resources for each replica spawned for this job. If not specified
-the scheduled resource may not have all the resources
"""

num_workers: int
-per_replica_requests: Optional[Resources] = None
-per_replica_limits: Optional[Resources] = None


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
@@ -47,7 +40,7 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
task_config,
task_function,
task_type=self._PYTORCH_TASK_TYPE,
-**{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits}
+**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
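
After these hunks, the plugin config carries only the replica count; everything resource-related moves out of it. A sketch of the resulting class, reconstructed from the visible lines only (any fields or docstring text outside the shown hunks are not reflected here):

from dataclasses import dataclass


@dataclass
class PyTorch(object):
    """
    Args:
        num_workers: integer determining the number of worker replicas spawned in the cluster
            for this job (in addition to 1 master).
    """

    num_workers: int
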
7 changes: 6 additions & 1 deletion plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py
@@ -5,7 +5,12 @@


def test_pytorch_task():
-@task(task_config=PyTorch(num_workers=10, per_replica_requests=Resources(cpu="1")), cache=True, cache_version="1")
+@task(
+task_config=PyTorch(num_workers=10),
+cache=True,
+cache_version="1",
+requests=Resources(cpu="1"),
+)
def my_pytorch_task(x: int, y: str) -> int:
return x

13 changes: 3 additions & 10 deletions plugins/flytekit-kf-tensorflow/flytekitplugins/kftensorflow/task.py
@@ -3,11 +3,11 @@
Kubernetes. It leverages `TF Job <https://github.com/kubeflow/tf-operator>`_ Plugin from kubeflow.
"""
from dataclasses import dataclass
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict

from google.protobuf.json_format import MessageToDict

-from flytekit import PythonFunctionTask, Resources
+from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins
from flytekit.models.plugins import task as _task_model

@@ -26,18 +26,11 @@ class TfJob(object):
num_chief_replicas: Number of chief replicas to use
-per_replica_requests: [optional] lower-bound resources for each replica spawned for this job
-(i.e. both for (main)master and workers). Default is set by platform-level configuration.
-per_replica_limits: [optional] upper-bound resources for each replica spawned for this job. If not specified
-the scheduled resource may not have all the resources
"""

num_workers: int
num_ps_replicas: int
num_chief_replicas: int
-per_replica_requests: Optional[Resources] = None
-per_replica_limits: Optional[Resources] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
@@ -53,7 +46,7 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
task_type=self._TF_JOB_TASK_TYPE,
task_config=task_config,
task_function=task_function,
-**{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits}
+**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
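
The TfJob config is trimmed the same way; after this hunk it holds only the replica counts. A sketch reconstructed from the context lines shown above:

from dataclasses import dataclass


@dataclass
class TfJob(object):
    num_workers: int
    num_ps_replicas: int
    num_chief_replicas: int
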
5 changes: 2 additions & 3 deletions plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py
@@ -6,10 +6,9 @@

def test_tensorflow_task():
@task(
-task_config=TfJob(
-num_workers=10, per_replica_requests=Resources(cpu="1"), num_ps_replicas=1, num_chief_replicas=1
-),
+task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
cache=True,
+requests=Resources(cpu="1"),
cache_version="1",
)
def my_tensorflow_task(x: int, y: str) -> int:
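
Migrating an existing TensorFlow task therefore means moving the resource arguments out of TfJob and onto the @task call, with the same requests / limits split as any other Flyte task. A hedged sketch; the import path and the limits values are assumptions for illustration:

from flytekit import Resources, task
from flytekitplugins.kftensorflow import TfJob  # assumed import path for the plugin config


@task(
    task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
    cache=True,
    cache_version="1",
    requests=Resources(cpu="1"),           # lower-bound resources
    limits=Resources(cpu="2", mem="1Gi"),  # assumed upper bound, optional
)
def my_tensorflow_task(x: int, y: str) -> int:
    return x
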
