Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor resource limit in pytorch/tensorflow plugin #724

Merged
merged 3 commits into from
Oct 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
Kubernetes. It leverages the `PyTorch Job <https://github.com/kubeflow/pytorch-operator>`_ plugin from Kubeflow.
"""
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict

from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask, Resources
from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins
from flytekit.models.plugins import task as _task_model

Expand All @@ -22,16 +22,9 @@ class PyTorch(object):
num_workers: integer determining the number of worker replicas spawned in the cluster for this job
(in addition to 1 master).

per_replica_requests: [optional] lower-bound resources for each replica spawned for this job
yindia marked this conversation as resolved.
Show resolved Hide resolved
(i.e. both for (main)master and workers). Default is set by platform-level configuration.

per_replica_limits: [optional] upper-bound resources for each replica spawned for this job. If not specified,
the scheduled replica may not be granted all the resources it needs.
"""

num_workers: int
per_replica_requests: Optional[Resources] = None
per_replica_limits: Optional[Resources] = None


class PyTorchFunctionTask(PythonFunctionTask[PyTorch]):
Expand All @@ -47,7 +40,7 @@ def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs):
task_config,
task_function,
task_type=self._PYTORCH_TASK_TYPE,
**{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits}
**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
Expand Down
7 changes: 6 additions & 1 deletion plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@


def test_pytorch_task():
@task(task_config=PyTorch(num_workers=10, per_replica_requests=Resources(cpu="1")), cache=True, cache_version="1")
@task(
task_config=PyTorch(num_workers=10),
cache=True,
cache_version="1",
requests=Resources(cpu="1"),
)
def my_pytorch_task(x: int, y: str) -> int:
return x

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
Kubernetes. It leverages the `TF Job <https://github.com/kubeflow/tf-operator>`_ plugin from Kubeflow.
"""
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict

from google.protobuf.json_format import MessageToDict

from flytekit import PythonFunctionTask, Resources
from flytekit import PythonFunctionTask
from flytekit.extend import SerializationSettings, TaskPlugins
from flytekit.models.plugins import task as _task_model

Expand All @@ -26,18 +26,11 @@ class TfJob(object):

num_chief_replicas: Number of chief replicas to use

per_replica_requests: [optional] lower-bound resources for each replica spawned for this job
(i.e. both for (main)master and workers). Default is set by platform-level configuration.

per_replica_limits: [optional] upper-bound resources for each replica spawned for this job. If not specified,
the scheduled replica may not be granted all the resources it needs.
"""

num_workers: int
num_ps_replicas: int
num_chief_replicas: int
per_replica_requests: Optional[Resources] = None
per_replica_limits: Optional[Resources] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
Expand All @@ -53,7 +46,7 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
task_type=self._TF_JOB_TASK_TYPE,
task_config=task_config,
task_function=task_function,
**{**kwargs, "requests": task_config.per_replica_requests, "limits": task_config.per_replica_limits}
**kwargs,
)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

def test_tensorflow_task():
@task(
task_config=TfJob(
num_workers=10, per_replica_requests=Resources(cpu="1"), num_ps_replicas=1, num_chief_replicas=1
),
task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
cache=True,
requests=Resources(cpu="1"),
cache_version="1",
)
def my_tensorflow_task(x: int, y: str) -> int:
Expand Down