Skip to content

Commit

Permalink
pytorch plugin implementation (flyteorg#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
igorvalko authored May 28, 2020
1 parent 2fddff5 commit e908a59
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 2 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ If `@sidecar_task` is to be used, one should install the `sidecar` plugin.
pip install flytekit[sidecar]
```

### Pytorch

If `@pytorch_task` is to be used, one should install the `pytorch` plugin.

```bash
pip install flytekit[pytorch]
```

### Full Installation

To install all or multiple available plugins, one can specify them individually:
Expand Down
1 change: 1 addition & 0 deletions flytekit/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SdkTaskType(object):
SIDECAR_TASK = "sidecar"
SENSOR_TASK = "sensor-task"
PRESTO_TASK = "presto"
PYTORCH_TASK = "pytorch"

GLOBAL_INPUT_NODE_ID = ''

Expand Down
80 changes: 80 additions & 0 deletions flytekit/common/tasks/pytorch_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import absolute_import

try:
from inspect import getfullargspec as _getargspec
except ImportError:
from inspect import getargspec as _getargspec

import six as _six
from flytekit.common import constants as _constants
from flytekit.common.exceptions import scopes as _exception_scopes
from flytekit.common.tasks import output as _task_output, sdk_runnable as _sdk_runnable
from flytekit.common.types import helpers as _type_helpers
from flytekit.models import literals as _literal_models, task as _task_models
from google.protobuf.json_format import MessageToDict as _MessageToDict


class SdkRunnablePytorchContainer(_sdk_runnable.SdkRunnableContainer):

@property
def args(self):
"""
Override args to remove the injection of command prefixes
:rtype: list[Text]
"""
return self._args

class SdkPyTorchTask(_sdk_runnable.SdkRunnableTask):
def __init__(
self,
task_function,
task_type,
discovery_version,
retries,
interruptible,
deprecated,
discoverable,
timeout,
workers_count,
per_replica_storage_request,
per_replica_cpu_request,
per_replica_gpu_request,
per_replica_memory_request,
per_replica_storage_limit,
per_replica_cpu_limit,
per_replica_gpu_limit,
per_replica_memory_limit,
environment
):
pytorch_job = _task_models.PyTorchJob(
workers_count=workers_count
).to_flyte_idl()
super(SdkPyTorchTask, self).__init__(
task_function=task_function,
task_type=task_type,
discovery_version=discovery_version,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
storage_request=per_replica_storage_request,
cpu_request=per_replica_cpu_request,
gpu_request=per_replica_gpu_request,
memory_request=per_replica_memory_request,
storage_limit=per_replica_storage_limit,
cpu_limit=per_replica_cpu_limit,
gpu_limit=per_replica_gpu_limit,
memory_limit=per_replica_memory_limit,
discoverable=discoverable,
timeout=timeout,
environment=environment,
custom=_MessageToDict(pytorch_job)
)

def _get_container_definition(
self,
**kwargs
):
"""
:rtype: SdkRunnablePytorchContainer
"""
return super(SdkPyTorchTask, self)._get_container_definition(cls=SdkRunnablePytorchContainer, **kwargs)
22 changes: 22 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flyteidl.admin import task_pb2 as _admin_task
from flyteidl.core import tasks_pb2 as _core_task, literals_pb2 as _literals_pb2, compiler_pb2 as _compiler
from flyteidl.plugins import spark_pb2 as _spark_task
from flyteidl.plugins import pytorch_pb2 as _pytorch_task
from flytekit.plugins import flyteidl as _lazy_flyteidl
from google.protobuf import json_format as _json_format, struct_pb2 as _struct
from flytekit.sdk.spark_types import SparkType as _spark_type
Expand Down Expand Up @@ -804,3 +805,24 @@ def from_flyte_idl(cls, pb2_object):
pod_spec=pb2_object.pod_spec,
primary_container_name=pb2_object.primary_container_name,
)


class PyTorchJob(_common.FlyteIdlEntity):

def __init__(self, workers_count):
self._workers_count = workers_count

@property
def workers_count(self):
return self._workers_count

def to_flyte_idl(self):
return _pytorch_task.DistributedPyTorchTrainingTask(
workers=self.workers_count,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
return cls(
workers_count=pb2_object.workers,
)
8 changes: 8 additions & 0 deletions flytekit/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
hmsclient = _lazy_loader.lazy_load_module("hmsclient") # type: types.ModuleType
type(hmsclient).add_sub_module("genthrift.hive_metastore.ttypes")

torch = _lazy_loader.lazy_load_module("torch") # type: types.ModuleType

_lazy_loader.LazyLoadPlugin(
"spark",
["pyspark>=2.4.0,<3.0.0"],
Expand Down Expand Up @@ -46,3 +48,9 @@
],
[hmsclient]
)

_lazy_loader.LazyLoadPlugin(
"pytorch",
["torch>=1.0.0,<2.0.0"],
[torch]
)
157 changes: 156 additions & 1 deletion flytekit/sdk/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from flytekit.common import constants as _common_constants
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \
spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks
spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, \
sidecar_task as _sdk_sidecar_tasks, pytorch_task as _sdk_pytorch_tasks
from flytekit.common.tasks import task as _task
from flytekit.common.types import helpers as _type_helpers
from flytekit.sdk.spark_types import SparkType as _spark_type
Expand Down Expand Up @@ -1000,3 +1001,157 @@ def wrapper(fn):
return wrapper(_task_function)
else:
return wrapper


def pytorch_task(
_task_function=None,
cache_version='',
retries=0,
interruptible=False,
deprecated='',
cache=False,
timeout=None,
workers_count=1,
per_replica_storage_request="",
per_replica_cpu_request="",
per_replica_gpu_request="",
per_replica_memory_request="",
per_replica_storage_limit="",
per_replica_cpu_limit="",
per_replica_gpu_limit="",
per_replica_memory_limit="",
environment=None,
cls=None
):
"""
Decorator to create a Pytorch Task definition. This task will submit PyTorchJob (see https://github.com/kubeflow/pytorch-operator)
defined by the code within the _task_function to k8s cluster.
.. code-block:: python
@inputs(int_list=[Types.Integer])
@outputs(result=Types.Integer
@pytorch_task(
workers_count=2,
per_replica_cpu_request="500m",
per_replica_memory_request="4Gi",
per_replica_memory_limit="8Gi",
per_replica_gpu_limit="1",
)
def my_pytorch_job(wf_params, int_list, result):
pass
:param _task_function: this is the decorated method and shouldn't be declared explicitly. The function must
take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword
arguments are allowed for wrapped task functions.
:param Text cache_version: [optional] string representing logical version for discovery. This field should be
updated whenever the underlying algorithm changes.
.. note::
This argument is required to be a non-empty string if `cache` is True.
:param int retries: [optional] integer determining number of times task can be retried on
:py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults
to 0.
.. note::
If retries > 0, the task must be able to recover from any remote state created within the user code. It is
strongly recommended that tasks are written to be idempotent.
:param bool interruptible: [optional] boolean describing if the task is interruptible.
:param Text deprecated: [optional] string that should be provided if this task is deprecated. The string
will be logged as a warning so it should contain information regarding how to update to a newer task.
:param bool cache: [optional] boolean describing if the outputs of this task should be cached and
re-usable.
:param datetime.timedelta timeout: [optional] describes how long the task should be allowed to
run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run
indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout.
:param int workers_count: integer determining the number of worker replicas spawned in the cluster for this job
(in addition to 1 master).
:param Text per_replica_storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space
for each replica spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration.
.. note::
This is currently not supported by the platform.
:param Text per_replica_cpu_request: [optional] Kubernetes resource string for lower-bound of cores for each replica
spawned for this job (i.e. both for master and workers).
This can be set to a fractional portion of a CPU. Default is set by platform-level configuration.
TODO: Add links to resource string documentation for Kubernetes
:param Text per_replica_gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs for each
replica spawned for this job (i.e. both for master and workers).
Default is set by platform-level configuration.
TODO: Add links to resource string documentation for Kubernetes
:param Text per_replica_memory_request: [optional] Kubernetes resource string for lower-bound of physical memory
necessary for each replica spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration.
TODO: Add links to resource string documentation for Kubernetes
:param Text per_replica_storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space
for each replica spawned for this job (i.e. both for master and workers).
This amount is not guaranteed! If not specified, it is set equal to storage_request.
.. note::
This is currently not supported by the platform.
:param Text per_replica_cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for each replica
spawned for this job (i.e. both for master and workers).
This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified,
it is set equal to cpu_request.
:param Text per_replica_gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs for each
replica spawned for this job (i.e. both for master and workers).
This amount is not guaranteed! If not specified, it is set equal to gpu_request.
:param Text per_replica_memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory
necessary for each replica spawned for this job (i.e. both for master and workers).
This amount is not guaranteed! If not specified, it is set equal to memory_request.
:param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
:param cls: This can be used to override the task implementation with a user-defined extension. The class
provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to
inject bespoke logic into the base Flyte programming model.
:rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
"""
def wrapper(fn):
return (cls or _sdk_pytorch_tasks.SdkPyTorchTask)(
task_function=fn,
task_type=_common_constants.SdkTaskType.PYTORCH_TASK,
discovery_version=cache_version,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
discoverable=cache,
timeout=timeout or _datetime.timedelta(seconds=0),
workers_count=workers_count,
per_replica_storage_request=per_replica_storage_request,
per_replica_cpu_request=per_replica_cpu_request,
per_replica_gpu_request=per_replica_gpu_request,
per_replica_memory_request=per_replica_memory_request,
per_replica_storage_limit=per_replica_storage_limit,
per_replica_cpu_limit=per_replica_cpu_limit,
per_replica_gpu_limit=per_replica_gpu_limit,
per_replica_memory_limit=per_replica_memory_limit,
environment=environment or {}
)

if _task_function:
return wrapper(_task_function)
else:
return wrapper
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
]
},
install_requires=[
"flyteidl>=0.17.27,<1.0.0",
"flyteidl>=0.17.32,<1.0.0",
"click>=6.6,<8.0",
"croniter>=0.3.20,<4.0.0",
"deprecation>=2.0,<3.0",
Expand Down
45 changes: 45 additions & 0 deletions tests/flytekit/unit/sdk/tasks/test_pytorch_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from flytekit.sdk.tasks import pytorch_task, inputs, outputs
from flytekit.sdk.types import Types
from flytekit.common import constants as _common_constants
from flytekit.common.tasks import sdk_runnable as _sdk_runnable, pytorch_task as _pytorch_task
from flytekit.models import types as _type_models
from flytekit.models.core import identifier as _identifier
import datetime as _datetime


@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@pytorch_task(workers_count=1)
def simple_pytorch_task(wf_params, sc, in1, out1):
pass


simple_pytorch_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version")


def test_simple_pytorch_task():
assert isinstance(simple_pytorch_task, _pytorch_task.SdkPyTorchTask)
assert isinstance(simple_pytorch_task, _sdk_runnable.SdkRunnableTask)
assert simple_pytorch_task.interface.inputs['in1'].description == ''
assert simple_pytorch_task.interface.inputs['in1'].type == \
_type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
assert simple_pytorch_task.interface.outputs['out1'].description == ''
assert simple_pytorch_task.interface.outputs['out1'].type == \
_type_models.LiteralType(simple=_type_models.SimpleType.STRING)
assert simple_pytorch_task.type == _common_constants.SdkTaskType.PYTORCH_TASK
assert simple_pytorch_task.task_function_name == 'simple_pytorch_task'
assert simple_pytorch_task.task_module == __name__
assert simple_pytorch_task.metadata.timeout == _datetime.timedelta(seconds=0)
assert simple_pytorch_task.metadata.deprecated_error_message == ''
assert simple_pytorch_task.metadata.discoverable is False
assert simple_pytorch_task.metadata.discovery_version == ''
assert simple_pytorch_task.metadata.retries.retries == 0
assert len(simple_pytorch_task.container.resources.limits) == 0
assert len(simple_pytorch_task.container.resources.requests) == 0
assert simple_pytorch_task.custom['workers'] == 1
# Should strip out the venv component of the args.
assert simple_pytorch_task._get_container_definition().args[0] == 'pyflyte-execute'

pb2 = simple_pytorch_task.to_flyte_idl()
assert pb2.custom['workers'] == 1

0 comments on commit e908a59

Please sign in to comment.