Skip to content

Commit

Permalink
Tensorflow plugin implementation (from PR flyteorg#141) (flyteorg#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
wild-endeavor authored Sep 23, 2020
1 parent 463772e commit 8d5dd75
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 1 deletion.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ If `@pytorch_task` is to be used, one should install the `pytorch` plugin.
pip install "flytekit[pytorch]"
```

### TensorFlow

If `@tensorflow_task` is to be used, one should install the `tensorflow` plugin.

```bash
pip install "flytekit[tensorflow]"
```

### Full Installation

To install all or multiple available plugins, one can specify them individually:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import flytekit.plugins # noqa: F401

__version__ = "0.13.0b5"
__version__ = "0.13.0b6"

logger = _logging.getLogger("flytekit")

Expand Down
1 change: 1 addition & 0 deletions flytekit/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class SdkTaskType(object):
SENSOR_TASK = "sensor-task"
PRESTO_TASK = "presto"
PYTORCH_TASK = "pytorch"
TENSORFLOW_TASK = "tensorflow"
# Raw container task is just a name; it defaults to using the regular container task (like python etc.), but sets the data_config in the container
RAW_CONTAINER_TASK = "raw-container"
SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task"
Expand Down
69 changes: 69 additions & 0 deletions flytekit/common/tasks/tensorflow_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from google.protobuf.json_format import MessageToDict as _MessageToDict

from flytekit.common.tasks import sdk_runnable as _sdk_runnable
from flytekit.models import task as _task_models


class SdkRunnableTensorflowContainer(_sdk_runnable.SdkRunnableContainer):
    """
    Container definition used by TensorFlow tasks.

    It only differs from the parent runnable container in how ``args`` is exposed.
    """

    @property
    def args(self):
        """
        Return the stored argument list as-is, so that no command prefixes are injected in front of it.

        :rtype: list[Text]
        """
        return self._args


class SdkTensorFlowTask(_sdk_runnable.SdkRunnableTask):
    """
    Runnable task whose ``custom`` payload carries the replica counts of a TensorFlow job.
    """

    def __init__(
        self,
        task_function,
        task_type,
        cache_version,
        retries,
        interruptible,
        deprecated,
        cache,
        timeout,
        workers_count,
        ps_replicas_count,
        chief_replicas_count,
        per_replica_storage_request,
        per_replica_cpu_request,
        per_replica_gpu_request,
        per_replica_memory_request,
        per_replica_storage_limit,
        per_replica_cpu_limit,
        per_replica_gpu_limit,
        per_replica_memory_limit,
        environment,
    ):
        """
        :param task_function: the user function wrapped by this task.
        :param task_type: task type identifier, forwarded unchanged to the parent class.
        :param cache_version: forwarded to the parent class as ``discovery_version``.
        :param retries: forwarded unchanged to the parent class.
        :param interruptible: forwarded unchanged to the parent class.
        :param deprecated: forwarded unchanged to the parent class.
        :param cache: forwarded to the parent class as ``discoverable``.
        :param timeout: forwarded unchanged to the parent class.
        :param workers_count: number of worker replicas recorded in the job definition.
        :param ps_replicas_count: number of parameter server replicas recorded in the job definition.
        :param chief_replicas_count: number of chief replicas recorded in the job definition.
        :param per_replica_storage_request: forwarded to the parent class as ``storage_request``.
        :param per_replica_cpu_request: forwarded to the parent class as ``cpu_request``.
        :param per_replica_gpu_request: forwarded to the parent class as ``gpu_request``.
        :param per_replica_memory_request: forwarded to the parent class as ``memory_request``.
        :param per_replica_storage_limit: forwarded to the parent class as ``storage_limit``.
        :param per_replica_cpu_limit: forwarded to the parent class as ``cpu_limit``.
        :param per_replica_gpu_limit: forwarded to the parent class as ``gpu_limit``.
        :param per_replica_memory_limit: forwarded to the parent class as ``memory_limit``.
        :param environment: forwarded unchanged to the parent class.
        """
        # Describe the job with the model class, then turn its IDL message into a plain dict so it can be
        # handed to the parent as the task's ``custom`` payload.
        job_model = _task_models.TensorFlowJob(
            workers_count=workers_count,
            ps_replicas_count=ps_replicas_count,
            chief_replicas_count=chief_replicas_count,
        )
        custom_payload = _MessageToDict(job_model.to_flyte_idl())

        super(SdkTensorFlowTask, self).__init__(
            task_function=task_function,
            task_type=task_type,
            # Caching / lifecycle settings (note the renamed keywords expected by the parent).
            discovery_version=cache_version,
            discoverable=cache,
            retries=retries,
            interruptible=interruptible,
            deprecated=deprecated,
            timeout=timeout,
            # Resource requests.
            storage_request=per_replica_storage_request,
            cpu_request=per_replica_cpu_request,
            gpu_request=per_replica_gpu_request,
            memory_request=per_replica_memory_request,
            # Resource limits.
            storage_limit=per_replica_storage_limit,
            cpu_limit=per_replica_cpu_limit,
            gpu_limit=per_replica_gpu_limit,
            memory_limit=per_replica_memory_limit,
            environment=environment,
            custom=custom_payload,
        )

    def _get_container_definition(self, **kwargs):
        """
        Delegate to the parent implementation, asking it to use the TensorFlow container class.

        :rtype: SdkRunnableTensorflowContainer
        """
        parent = super(SdkTensorFlowTask, self)
        return parent._get_container_definition(cls=SdkRunnableTensorflowContainer, **kwargs)
33 changes: 33 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flyteidl.core import tasks_pb2 as _core_task
from flyteidl.plugins import pytorch_pb2 as _pytorch_task
from flyteidl.plugins import spark_pb2 as _spark_task
from flyteidl.plugins import tensorflow_pb2 as _tensorflow_task
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct

Expand Down Expand Up @@ -881,3 +882,35 @@ def to_flyte_idl(self):
@classmethod
def from_flyte_idl(cls, pb2_object):
return cls(workers_count=pb2_object.workers,)


class TensorFlowJob(_common.FlyteIdlEntity):
    """
    Model holding the replica counts of a TensorFlow job; converts to and from the
    ``DistributedTensorflowTrainingTask`` IDL message.
    """

    def __init__(self, workers_count, ps_replicas_count, chief_replicas_count):
        """
        :param workers_count: number of worker replicas.
        :param ps_replicas_count: number of parameter server replicas.
        :param chief_replicas_count: number of chief replicas.
        """
        self._workers_count = workers_count
        self._ps_replicas_count = ps_replicas_count
        self._chief_replicas_count = chief_replicas_count

    @property
    def workers_count(self):
        """Number of worker replicas, as given at construction."""
        return self._workers_count

    @property
    def ps_replicas_count(self):
        """Number of parameter server replicas, as given at construction."""
        return self._ps_replicas_count

    @property
    def chief_replicas_count(self):
        """Number of chief replicas, as given at construction."""
        return self._chief_replicas_count

    def to_flyte_idl(self):
        """
        Build the IDL message from the three replica counts.

        :rtype: flyteidl.plugins.tensorflow_pb2.DistributedTensorflowTrainingTask
        """
        message_fields = {
            "workers": self.workers_count,
            "ps_replicas": self.ps_replicas_count,
            "chief_replicas": self.chief_replicas_count,
        }
        return _tensorflow_task.DistributedTensorflowTrainingTask(**message_fields)

    @classmethod
    def from_flyte_idl(cls, pb2_object):
        """
        Create a model instance from the IDL message.

        :param pb2_object: message exposing ``workers``, ``ps_replicas`` and ``chief_replicas``.
        :rtype: TensorFlowJob
        """
        workers = pb2_object.workers
        ps_replicas = pb2_object.ps_replicas
        chief_replicas = pb2_object.chief_replicas
        return cls(
            workers_count=workers,
            ps_replicas_count=ps_replicas,
            chief_replicas_count=chief_replicas,
        )
2 changes: 2 additions & 0 deletions flytekit/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

torch = _lazy_loader.lazy_load_module("torch") # type: _lazy_loader._LazyLoadModule

# NOTE(review): only the lazy module handle is created here; no matching
# `_lazy_loader.LazyLoadPlugin("tensorflow", ...)` call is visible next to the spark registrations below.
# Confirm the `tensorflow` plugin/extra is registered elsewhere in this module.
tensorflow = _lazy_loader.lazy_load_module("tensorflow") # type: _lazy_loader._LazyLoadModule

_lazy_loader.LazyLoadPlugin("spark", ["pyspark>=2.4.0,<3.0.0"], [pyspark])

_lazy_loader.LazyLoadPlugin("spark3", ["pyspark>=3.0.0"], [pyspark])
Expand Down
165 changes: 165 additions & 0 deletions flytekit/sdk/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from flytekit.common.tasks import sidecar_task as _sdk_sidecar_tasks
from flytekit.common.tasks import spark_task as _sdk_spark_tasks
from flytekit.common.tasks import task as _task
from flytekit.common.tasks import tensorflow_task as _sdk_tensorflow_tasks
from flytekit.common.types import helpers as _type_helpers
from flytekit.contrib.notebook import tasks as _nb_tasks
from flytekit.models import interface as _interface_model
Expand Down Expand Up @@ -1336,3 +1337,167 @@ def wrapper(fn):
return wrapper(_task_function)
else:
return wrapper


def tensorflow_task(
    _task_function=None,
    cache_version="",
    retries=0,
    interruptible=False,
    deprecated="",
    cache=False,
    timeout=None,
    workers_count=1,
    ps_replicas_count=None,
    chief_replicas_count=None,
    per_replica_storage_request="",
    per_replica_cpu_request="",
    per_replica_gpu_request="",
    per_replica_memory_request="",
    per_replica_storage_limit="",
    per_replica_cpu_limit="",
    per_replica_gpu_limit="",
    per_replica_memory_limit="",
    environment=None,
    cls=None,
):
    """
    Decorator that turns a function into a TensorFlow task definition. The resulting task submits a TFJob
    (see https://github.com/kubeflow/tf-operator), defined by the code inside the decorated function, to the
    k8s cluster.

    .. code-block:: python

        @inputs(int_list=[Types.Integer])
        @outputs(result=Types.Integer)
        @tensorflow_task(
            workers_count=2,
            ps_replicas_count=1,
            chief_replicas_count=1,
            per_replica_cpu_request="500m",
            per_replica_memory_request="4Gi",
            per_replica_memory_limit="8Gi",
            per_replica_gpu_limit="1",
        )
        def my_tensorflow_job(wf_params, int_list, result):
            pass

    All ``per_replica_*`` settings apply to every replica spawned for the job: parameter servers, chief and
    workers alike.

    :param _task_function: the decorated function; it should not be passed explicitly. The function must take a
        first argument, followed by named arguments matching those declared in @inputs and @outputs. Keyword
        arguments are not allowed on wrapped task functions.
    :param Text cache_version: [optional] logical version string used for discovery. Update it whenever the
        underlying algorithm changes.

        .. note::

            Must be a non-empty string if `cache` is True.

    :param int retries: [optional] number of times the task can be retried on
        :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults to 0.

        .. note::

            If retries > 0, the task must be able to recover from any remote state created within the user
            code. It is strongly recommended that tasks are written to be idempotent.

    :param bool interruptible: [optional] whether the task is interruptible.
    :param Text deprecated: [optional] message to provide if this task is deprecated. It is logged as a
        warning, so it should explain how to move to a newer task.
    :param bool cache: [optional] whether the outputs of this task should be cached and re-used.
    :param datetime.timedelta timeout: [optional] maximum time the task may run before a retry is triggered
        (if retries are enabled). By default tasks may run indefinitely. A null timedelta
        (i.e. timedelta(seconds=0)) means the task will not time out.
    :param int workers_count: number of worker replicas spawned in the cluster for this job.
    :param int ps_replicas_count: number of parameter server replicas spawned in the cluster for this job.
    :param int chief_replicas_count: number of chief replicas spawned in the cluster for this job.
    :param Text per_replica_storage_request: [optional] Kubernetes resource string for the lower bound of disk
        storage space per replica. Default is set by platform-level configuration.

        .. note::

            This is currently not supported by the platform.

    :param Text per_replica_cpu_request: [optional] Kubernetes resource string for the lower bound of cores per
        replica. May be a fractional portion of a CPU. Default is set by platform-level configuration.
        TODO: Add links to resource string documentation for Kubernetes
    :param Text per_replica_gpu_request: [optional] Kubernetes resource string for the lower bound of desired
        GPUs per replica. Default is set by platform-level configuration.
        TODO: Add links to resource string documentation for Kubernetes
    :param Text per_replica_memory_request: [optional] Kubernetes resource string for the lower bound of
        physical memory per replica. Default is set by platform-level configuration.
        TODO: Add links to resource string documentation for Kubernetes
    :param Text per_replica_storage_limit: [optional] Kubernetes resource string for the upper bound of disk
        storage space per replica. This amount is not guaranteed! If not specified, it is set equal to
        storage_request.

        .. note::

            This is currently not supported by the platform.

    :param Text per_replica_cpu_limit: [optional] Kubernetes resource string for the upper bound of cores per
        replica. May be a fractional portion of a CPU. This amount is not guaranteed! If not specified, it is
        set equal to cpu_request.
    :param Text per_replica_gpu_limit: [optional] Kubernetes resource string for the upper bound of desired
        GPUs per replica. This amount is not guaranteed! If not specified, it is set equal to gpu_request.
    :param Text per_replica_memory_limit: [optional] Kubernetes resource string for the upper bound of physical
        memory per replica. This amount is not guaranteed! If not specified, it is set equal to memory_request.
    :param dict[Text,Text] environment: [optional] environment variables to set when executing this task.
    :param cls: overrides the task implementation with a user-defined extension. The class provided must be a
        subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. It can be used to inject bespoke logic
        into the base Flyte programming model.
    :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask
    """

    def wrapper(fn):
        # A user-supplied class takes precedence over the built-in TensorFlow task implementation.
        task_cls = cls or _sdk_tensorflow_tasks.SdkTensorFlowTask
        return task_cls(
            task_function=fn,
            task_type=_common_constants.SdkTaskType.TENSORFLOW_TASK,
            cache_version=cache_version,
            retries=retries,
            interruptible=interruptible,
            deprecated=deprecated,
            cache=cache,
            # A missing timeout becomes a zero timedelta.
            timeout=timeout or _datetime.timedelta(seconds=0),
            workers_count=workers_count,
            ps_replicas_count=ps_replicas_count,
            chief_replicas_count=chief_replicas_count,
            per_replica_storage_request=per_replica_storage_request,
            per_replica_cpu_request=per_replica_cpu_request,
            per_replica_gpu_request=per_replica_gpu_request,
            per_replica_memory_request=per_replica_memory_request,
            per_replica_storage_limit=per_replica_storage_limit,
            per_replica_cpu_limit=per_replica_cpu_limit,
            per_replica_gpu_limit=per_replica_gpu_limit,
            per_replica_memory_limit=per_replica_memory_limit,
            # A missing environment becomes an empty dict.
            environment=environment or {},
        )

    # Used with arguments (``@tensorflow_task(...)``): hand back the decorator itself.
    if not _task_function:
        return wrapper
    # Used bare (``@tensorflow_task``): the function is already here, so wrap it right away.
    return wrapper(_task_function)
55 changes: 55 additions & 0 deletions tests/flytekit/unit/sdk/tasks/test_tensorflow_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import datetime as _datetime

from flytekit.common import constants as _common_constants
from flytekit.common.tasks import sdk_runnable as _sdk_runnable
from flytekit.common.tasks import tensorflow_task as _tensorflow_task
from flytekit.models import types as _type_models
from flytekit.models.core import identifier as _identifier
from flytekit.sdk.tasks import inputs, outputs, tensorflow_task
from flytekit.sdk.types import Types


@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@tensorflow_task(workers_count=2, ps_replicas_count=1, chief_replicas_count=1)
def simple_tensorflow_task(wf_params, in1, out1):
    # Fixture task for the assertions in this module; the body is intentionally empty because the tests only
    # inspect the task definition and never run the user code.
    # The signature follows the contract documented on ``tensorflow_task``: one leading argument, then only the
    # names declared in @inputs and @outputs (the previous extra ``sc`` argument matched neither).
    pass


# Give the fixture task a fixed identifier by setting its private ``_id`` directly.
# NOTE(review): presumably needed so that ``to_flyte_idl()`` in the test below has an id to serialize — confirm.
simple_tensorflow_task._id = _identifier.Identifier(
_identifier.ResourceType.TASK, "project", "domain", "name", "version"
)


def test_simple_tensorflow_task():
    """Check the definition produced by ``@tensorflow_task`` for the module-level fixture task."""
    task = simple_tensorflow_task

    # Class hierarchy.
    assert isinstance(task, _tensorflow_task.SdkTensorFlowTask)
    assert isinstance(task, _sdk_runnable.SdkRunnableTask)

    # Typed interface declared through @inputs / @outputs.
    in1 = task.interface.inputs["in1"]
    assert in1.description == ""
    assert in1.type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
    out1 = task.interface.outputs["out1"]
    assert out1.description == ""
    assert out1.type == _type_models.LiteralType(simple=_type_models.SimpleType.STRING)

    # Task identity.
    assert task.type == _common_constants.SdkTaskType.TENSORFLOW_TASK
    assert task.task_function_name == "simple_tensorflow_task"
    assert task.task_module == __name__

    # Metadata defaults coming from the decorator.
    metadata = task.metadata
    assert metadata.timeout == _datetime.timedelta(seconds=0)
    assert metadata.deprecated_error_message == ""
    assert metadata.discoverable is False
    assert metadata.discovery_version == ""
    assert metadata.retries.retries == 0

    # No per-replica resources were requested.
    resources = task.container.resources
    assert len(resources.limits) == 0
    assert len(resources.requests) == 0

    # Replica counts as they appear in the custom payload.
    expected_custom = (("workers", 2), ("psReplicas", 1), ("chiefReplicas", 1))
    for key, count in expected_custom:
        assert task.custom[key] == count

    # Should strip out the venv component of the args.
    assert task._get_container_definition().args[0] == "pyflyte-execute"

    # The same replica counts must survive conversion to the IDL message.
    pb2 = task.to_flyte_idl()
    for key, count in expected_custom:
        assert pb2.custom[key] == count

0 comments on commit 8d5dd75

Please sign in to comment.