
pytorch plugin implementation #112

Merged · 4 commits merged on May 28, 2020
Changes from 3 commits
8 changes: 8 additions & 0 deletions README.md
@@ -58,6 +58,14 @@ If `@sidecar_task` is to be used, one should install the `sidecar` plugin.
pip install flytekit[sidecar]
```

### PyTorch

If `@pytorch_task` is to be used, one should install the `pytorch` plugin.

```bash
pip install flytekit[pytorch]
```

### Full Installation

To install all or multiple available plugins, one can specify them individually:
1 change: 1 addition & 0 deletions flytekit/common/constants.py
@@ -21,6 +21,7 @@ class SdkTaskType(object):
SIDECAR_TASK = "sidecar"
SENSOR_TASK = "sensor-task"
PRESTO_TASK = "presto"
PYTORCH_TASK = "pytorch"

GLOBAL_INPUT_NODE_ID = ''

80 changes: 80 additions & 0 deletions flytekit/common/tasks/pytorch_task.py
@@ -0,0 +1,80 @@
from __future__ import absolute_import

try:
from inspect import getfullargspec as _getargspec
except ImportError:
from inspect import getargspec as _getargspec

import six as _six
from flytekit.common import constants as _constants
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.tasks import output as _task_output, sdk_runnable as _sdk_runnable
from flytekit.common.types import helpers as _type_helpers
from flytekit.models import literals as _literal_models, task as _task_models
from google.protobuf.json_format import MessageToDict as _MessageToDict


class SdkRunnablePytorchContainer(_sdk_runnable.SdkRunnableContainer):

@property
def args(self):
"""
Override args to remove the injection of command prefixes
:rtype: list[Text]
"""
return self._args

class SdkPyTorchTask(_sdk_runnable.SdkRunnableTask):
def __init__(
self,
task_function,
task_type,
discovery_version,
retries,
interruptible,
deprecated,
discoverable,
timeout,
workers_count,
instance_storage_request,
instance_cpu_request,
instance_gpu_request,
instance_memory_request,
instance_storage_limit,
instance_cpu_limit,
instance_gpu_limit,
instance_memory_limit,
environment
):
pytorch_job = _task_models.PyTorchJob(
workers_count=workers_count
).to_flyte_idl()
super(SdkPyTorchTask, self).__init__(
task_function=task_function,
task_type=task_type,
discovery_version=discovery_version,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
storage_request=instance_storage_request,
cpu_request=instance_cpu_request,
gpu_request=instance_gpu_request,
memory_request=instance_memory_request,
storage_limit=instance_storage_limit,
cpu_limit=instance_cpu_limit,
gpu_limit=instance_gpu_limit,
memory_limit=instance_memory_limit,
discoverable=discoverable,
timeout=timeout,
environment=environment,
custom=_MessageToDict(pytorch_job)
)

def _get_container_definition(
self,
**kwargs
):
"""
:rtype: SdkRunnablePytorchContainer
"""
return super(SdkPyTorchTask, self)._get_container_definition(cls=SdkRunnablePytorchContainer, **kwargs)
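For reference, a minimal sketch (assuming the `flyteidl>=0.17.32` bump from this PR's setup.py is installed) of the `custom` payload that `SdkPyTorchTask.__init__` produces via `MessageToDict`:

```python
# Sketch: serialize the PyTorchJob model the same way SdkPyTorchTask.__init__ does.
from google.protobuf.json_format import MessageToDict

from flytekit.models import task as _task_models

pytorch_job = _task_models.PyTorchJob(workers_count=2).to_flyte_idl()
print(MessageToDict(pytorch_job))  # expected output: {'workers': 2}
```

This is also why the unit test at the bottom of this diff asserts `simple_pytorch_task.custom['workers'] == 1`.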
22 changes: 22 additions & 0 deletions flytekit/models/task.py
@@ -6,6 +6,7 @@
from flyteidl.admin import task_pb2 as _admin_task
from flyteidl.core import tasks_pb2 as _core_task, literals_pb2 as _literals_pb2, compiler_pb2 as _compiler
from flyteidl.plugins import spark_pb2 as _spark_task
from flyteidl.plugins import pytorch_pb2 as _pytorch_task
from flytekit.plugins import flyteidl as _lazy_flyteidl
from google.protobuf import json_format as _json_format, struct_pb2 as _struct
from flytekit.sdk.spark_types import SparkType as _spark_type
@@ -804,3 +805,24 @@ def from_flyte_idl(cls, pb2_object):
pod_spec=pb2_object.pod_spec,
primary_container_name=pb2_object.primary_container_name,
)


class PyTorchJob(_common.FlyteIdlEntity):

def __init__(self, workers_count):
self._workers_count = workers_count

@property
def workers_count(self):
return self._workers_count

def to_flyte_idl(self):
return _pytorch_task.DistributedPyTorchTrainingTask(
workers=self.workers_count,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
return cls(
workers_count=pb2_object.workers,
)
8 changes: 8 additions & 0 deletions flytekit/plugins/__init__.py
@@ -17,6 +17,8 @@
hmsclient = _lazy_loader.lazy_load_module("hmsclient") # type: types.ModuleType
type(hmsclient).add_sub_module("genthrift.hive_metastore.ttypes")

torch = _lazy_loader.lazy_load_module("torch") # type: types.ModuleType

_lazy_loader.LazyLoadPlugin(
"spark",
["pyspark>=2.4.0,<3.0.0"],
@@ -46,3 +48,9 @@
],
[hmsclient]
)

_lazy_loader.LazyLoadPlugin(
"pytorch",
["torch>=1.0.0,<2.0.0"],
[torch]
)
156 changes: 155 additions & 1 deletion flytekit/sdk/tasks.py
@@ -6,7 +6,8 @@
from flytekit.common import constants as _common_constants
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \
spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, \
sidecar_task as _sdk_sidecar_tasks, pytorch_task as _sdk_pytorch_tasks
from flytekit.common.tasks import task as _task
from flytekit.common.types import helpers as _type_helpers
from flytekit.sdk.spark_types import SparkType as _spark_type
@@ -1000,3 +1001,156 @@ def wrapper(fn):
return wrapper(_task_function)
else:
return wrapper


def pytorch_task(
_task_function=None,
cache_version='',
retries=0,
interruptible=False,
deprecated='',
cache=False,
timeout=None,
workers_count=1,
instance_storage_request="",
Collaborator: Any reason to call it "instance_"? Is it to signify that this request is per worker instead of for the entire task?

Contributor (Author): Yes, you've got it right: it's to stress that resources are requested per instance (both for the worker(s) and the master).

Contributor: Maybe for the user we could provide an interface that says "master_cpu" / "worker_cpu", but for now map it internally. Eventually I think we want to use https://github.com/pytorch/elastic/tree/master/kubernetes, which will not require the master/worker model.

Contributor: Actually, thinking more about it, let's replace "instance" with "node" and leave it at that. We could move to elastic later on.

Contributor (Author): I see that in related docs (at least here and here) the processes in a distributed job are referred to as replicas. So maybe let's go with per_replica_*? Any objections?

Contributor: @EngHabu I think what @igorvalko is trying to do is use the same cpu/mem for both master and worker. Initially he had all resources as part of the CRD, and I said he could re-purpose our "container" resources for one of them, so Igor made them the same. Maybe the way to think about it is: does the container in "flyte" for a distributed job refer to the master or the worker?

Collaborator: I would like to avoid the Spark issue we had, where people would assign huge resources to the "driver" pod the same way they assigned them to the executor pods but ended up barely using any on the driver pod... By essentially forcing users to use the same for both, are we getting into the same situation?

Contributor (Author) (igorvalko, May 27, 2020): I have no hands-on experience with PyTorch, but it looks like the master in the PyTorch paradigm does the same kind of work as the workers, plus some extra communication arrangement. Spark takes a different approach: the driver is (most often) a lightweight coordinator, and all the heavy work happens on the executors. Hence I assume it is uncommon to request uneven resources for master and workers. Input from anyone with hands-on distributed PyTorch experience would be very welcome.

Contributor: From what I read, there are two possible paradigms: (1) the master node does what a worker does plus the coordination and communication, or (2) the master node does only the reduction. Either way, the master node needs the same or similar resource requirements as the workers.

Contributor: (I am assuming we are using torch.nn.parallel.DistributedDataParallel(), which builds on top of the torch.distributed package and should use the NCCL backend according to https://pytorch.org/docs/stable/distributed.html#which-backend-to-use.)
instance_cpu_request="",
instance_gpu_request="",
instance_memory_request="",
instance_storage_limit="",
instance_cpu_limit="",
instance_gpu_limit="",
instance_memory_limit="",
environment=None,
cls=None
):
"""
    Decorator to create a PyTorch task definition. This task will submit a PyTorchJob (see https://github.com/kubeflow/pytorch-operator),
    defined by the code within the _task_function, to the Kubernetes cluster.

.. code-block:: python

@inputs(int_list=[Types.Integer])
        @outputs(result=Types.Integer)
@pytorch_task(
workers_count=2,
instance_cpu_request="500m",
instance_memory_request="4Gi",
instance_memory_limit="8Gi",
instance_gpu_limit="1",
)
def my_pytorch_job(wf_params, int_list, result):
pass

:param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must
take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword
arguments are allowed for wrapped task functions.

:param Text cache_version: [optional] string representing logical version for discovery. This field should be
updated whenever the underlying algorithm changes.

.. note::

This argument is required to be a non-empty string if `cache` is True.

    :param int retries: [optional] integer determining the number of times the task can be retried on
:py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults
to 0.

.. note::

If retries > 0, the task must be able to recover from any remote state created within the user code. It is
strongly recommended that tasks are written to be idempotent.

:param bool interruptible: [optional] boolean describing if the task is interruptible.

:param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
will be logged as a warning so it should contain information regarding how to update to a newer task.

:param bool cache: [optional] boolean describing if the outputs of this task should be cached and
re-usable.

:param datetime.timedelta timeout: [optional] describes how long the task should be allowed to
run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run
indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout.

:param int workers_count: integer determining the number of worker replicas spawned in the cluster for this job.

:param Text instance_storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
for each instance spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration.

.. note::

This is currently not supported by the platform.

:param Text instance_cpu_request: [optional] Kubernetes resource string for lower-bound of cores for each instance
spawned for this job (i.e. both for master and workers).
This can be set to a fractional portion of a CPU. Default is set by platform-level configuration.

TODO: Add links to resource string documentation for Kubernetes

:param Text instance_gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs for each
instance spawned for this job (i.e. both for master and workers).
Default is set by platform-level configuration.

TODO: Add links to resource string documentation for Kubernetes

:param Text instance_memory_request: [optional] Kubernetes resource string for lower-bound of physical memory
necessary for each instance spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration.

TODO: Add links to resource string documentation for Kubernetes

:param Text instance_storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space
for each instance spawned for this job (i.e. both for master and workers).
This amount is not guaranteed! If not specified, it is set equal to storage_request.

.. note::

This is currently not supported by the platform.

:param Text instance_cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for each instance
spawned for this job (i.e. both for master and workers).
This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified,
it is set equal to cpu_request.

:param Text instance_gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs for each
instance spawned for this job (i.e. both for master and workers).
This amount is not guaranteed! If not specified, it is set equal to gpu_request.

:param Text instance_memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory
necessary for each instance spawned for this job (i.e. both for master and workers).
This amount is not guaranteed! If not specified, it is set equal to memory_request.

:param dict[Text,Text] environment: [optional] environment variables to set when executing this task.

:param cls: This can be used to override the task implementation with a user-defined extension. The class
provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to
inject bespoke logic into the base Flyte programming model.

:rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
"""
def wrapper(fn):
return (cls or _sdk_pytorch_tasks.SdkPyTorchTask)(
task_function=fn,
task_type=_common_constants.SdkTaskType.PYTORCH_TASK,
discovery_version=cache_version,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
discoverable=cache,
timeout=timeout or _datetime.timedelta(seconds=0),
workers_count=workers_count,
instance_storage_request=instance_storage_request,
instance_cpu_request=instance_cpu_request,
instance_gpu_request=instance_gpu_request,
instance_memory_request=instance_memory_request,
instance_storage_limit=instance_storage_limit,
instance_cpu_limit=instance_cpu_limit,
instance_gpu_limit=instance_gpu_limit,
instance_memory_limit=instance_memory_limit,
environment=environment or {}
)

if _task_function:
return wrapper(_task_function)
else:
return wrapper
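Putting it together, a hypothetical end-to-end task using the new decorator might look like the following (names and values are illustrative; the pytorch-operator injects MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so `init_process_group` can rely on the env:// defaults):

```python
from flytekit.sdk.tasks import inputs, outputs, pytorch_task
from flytekit.sdk.types import Types


@inputs(epochs=Types.Integer)
@outputs(final_loss=Types.Float)
@pytorch_task(
    workers_count=2,
    instance_cpu_request="500m",
    instance_memory_request="4Gi",
    instance_gpu_limit="1",
)
def distributed_train(wf_params, epochs, final_loss):
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # NCCL would typically be used with GPUs
    loss = torch.tensor(0.0)
    for _ in range(epochs):
        pass  # a real training step would go here
    final_loss.set(loss.item())
```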
2 changes: 1 addition & 1 deletion setup.py
@@ -30,7 +30,7 @@
]
},
install_requires=[
"flyteidl>=0.17.27,<1.0.0",
"flyteidl>=0.17.32,<1.0.0",
"click>=6.6,<8.0",
"croniter>=0.3.20,<4.0.0",
"deprecation>=2.0,<3.0",
45 changes: 45 additions & 0 deletions tests/flytekit/unit/sdk/tasks/test_pytorch_task.py
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from flytekit.sdk.tasks import pytorch_task, inputs, outputs
from flytekit.sdk.types import Types
from flytekit.common import constants as _common_constants
from flytekit.common.tasks import sdk_runnable as _sdk_runnable, pytorch_task as _pytorch_task
from flytekit.models import types as _type_models
from flytekit.models.core import identifier as _identifier
import datetime as _datetime


@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@pytorch_task(workers_count=1)
def simple_pytorch_task(wf_params, sc, in1, out1):
pass


simple_pytorch_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version")


def test_simple_pytorch_task():
assert isinstance(simple_pytorch_task, _pytorch_task.SdkPyTorchTask)
assert isinstance(simple_pytorch_task, _sdk_runnable.SdkRunnableTask)
assert simple_pytorch_task.interface.inputs['in1'].description == ''
assert simple_pytorch_task.interface.inputs['in1'].type == \
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
assert simple_pytorch_task.interface.outputs['out1'].description == ''
assert simple_pytorch_task.interface.outputs['out1'].type == \
_type_models.LiteralType(simple=_type_models.SimpleType.STRING)
assert simple_pytorch_task.type == _common_constants.SdkTaskType.PYTORCH_TASK
assert simple_pytorch_task.task_function_name == 'simple_pytorch_task'
assert simple_pytorch_task.task_module == __name__
assert simple_pytorch_task.metadata.timeout == _datetime.timedelta(seconds=0)
assert simple_pytorch_task.metadata.deprecated_error_message == ''
assert simple_pytorch_task.metadata.discoverable is False
assert simple_pytorch_task.metadata.discovery_version == ''
assert simple_pytorch_task.metadata.retries.retries == 0
assert len(simple_pytorch_task.container.resources.limits) == 0
assert len(simple_pytorch_task.container.resources.requests) == 0
assert simple_pytorch_task.custom['workers'] == 1
# Should strip out the venv component of the args.
assert simple_pytorch_task._get_container_definition().args[0] == 'pyflyte-execute'

pb2 = simple_pytorch_task.to_flyte_idl()
assert pb2.custom['workers'] == 1