From e908a59677477fa831b55ffd1f8a274e0ede1aee Mon Sep 17 00:00:00 2001 From: Igor Valko Date: Fri, 29 May 2020 01:23:12 +0300 Subject: [PATCH] pytorch plugin implementation (#112) --- README.md | 8 + flytekit/common/constants.py | 1 + flytekit/common/tasks/pytorch_task.py | 80 +++++++++ flytekit/models/task.py | 22 +++ flytekit/plugins/__init__.py | 8 + flytekit/sdk/tasks.py | 157 +++++++++++++++++- setup.py | 2 +- .../unit/sdk/tasks/test_pytorch_task.py | 45 +++++ 8 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 flytekit/common/tasks/pytorch_task.py create mode 100644 tests/flytekit/unit/sdk/tasks/test_pytorch_task.py diff --git a/README.md b/README.md index a793b7eab46..9e4be08a880 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,14 @@ If `@sidecar_task` is to be used, one should install the `sidecar` plugin. pip install flytekit[sidecar] ``` +### Pytorch + +If `@pytorch_task` is to be used, one should install the `pytorch` plugin. + +```bash +pip install flytekit[pytorch] +``` + ### Full Installation To install all or multiple available plugins, one can specify them individually: diff --git a/flytekit/common/constants.py b/flytekit/common/constants.py index 2016d0b00a4..34c30e2cad1 100644 --- a/flytekit/common/constants.py +++ b/flytekit/common/constants.py @@ -21,6 +21,7 @@ class SdkTaskType(object): SIDECAR_TASK = "sidecar" SENSOR_TASK = "sensor-task" PRESTO_TASK = "presto" + PYTORCH_TASK = "pytorch" GLOBAL_INPUT_NODE_ID = '' diff --git a/flytekit/common/tasks/pytorch_task.py b/flytekit/common/tasks/pytorch_task.py new file mode 100644 index 00000000000..eabb88d2cf6 --- /dev/null +++ b/flytekit/common/tasks/pytorch_task.py @@ -0,0 +1,80 @@ +from __future__ import absolute_import + +try: + from inspect import getfullargspec as _getargspec +except ImportError: + from inspect import getargspec as _getargspec + +import six as _six +from flytekit.common import constants as _constants +from flytekit.common.exceptions import scopes as 
_exception_scopes +from flytekit.common.tasks import output as _task_output, sdk_runnable as _sdk_runnable +from flytekit.common.types import helpers as _type_helpers +from flytekit.models import literals as _literal_models, task as _task_models +from google.protobuf.json_format import MessageToDict as _MessageToDict + + +class SdkRunnablePytorchContainer(_sdk_runnable.SdkRunnableContainer): + + @property + def args(self): + """ + Override args to remove the injection of command prefixes + :rtype: list[Text] + """ + return self._args + +class SdkPyTorchTask(_sdk_runnable.SdkRunnableTask): + def __init__( + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + discoverable, + timeout, + workers_count, + per_replica_storage_request, + per_replica_cpu_request, + per_replica_gpu_request, + per_replica_memory_request, + per_replica_storage_limit, + per_replica_cpu_limit, + per_replica_gpu_limit, + per_replica_memory_limit, + environment + ): + pytorch_job = _task_models.PyTorchJob( + workers_count=workers_count + ).to_flyte_idl() + super(SdkPyTorchTask, self).__init__( + task_function=task_function, + task_type=task_type, + discovery_version=discovery_version, + retries=retries, + interruptible=interruptible, + deprecated=deprecated, + storage_request=per_replica_storage_request, + cpu_request=per_replica_cpu_request, + gpu_request=per_replica_gpu_request, + memory_request=per_replica_memory_request, + storage_limit=per_replica_storage_limit, + cpu_limit=per_replica_cpu_limit, + gpu_limit=per_replica_gpu_limit, + memory_limit=per_replica_memory_limit, + discoverable=discoverable, + timeout=timeout, + environment=environment, + custom=_MessageToDict(pytorch_job) + ) + + def _get_container_definition( + self, + **kwargs + ): + """ + :rtype: SdkRunnablePytorchContainer + """ + return super(SdkPyTorchTask, self)._get_container_definition(cls=SdkRunnablePytorchContainer, **kwargs) diff --git a/flytekit/models/task.py 
b/flytekit/models/task.py index 6783740a855..67187b50704 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -6,6 +6,7 @@ from flyteidl.admin import task_pb2 as _admin_task from flyteidl.core import tasks_pb2 as _core_task, literals_pb2 as _literals_pb2, compiler_pb2 as _compiler from flyteidl.plugins import spark_pb2 as _spark_task +from flyteidl.plugins import pytorch_pb2 as _pytorch_task from flytekit.plugins import flyteidl as _lazy_flyteidl from google.protobuf import json_format as _json_format, struct_pb2 as _struct from flytekit.sdk.spark_types import SparkType as _spark_type @@ -804,3 +805,24 @@ def from_flyte_idl(cls, pb2_object): pod_spec=pb2_object.pod_spec, primary_container_name=pb2_object.primary_container_name, ) + + +class PyTorchJob(_common.FlyteIdlEntity): + + def __init__(self, workers_count): + self._workers_count = workers_count + + @property + def workers_count(self): + return self._workers_count + + def to_flyte_idl(self): + return _pytorch_task.DistributedPyTorchTrainingTask( + workers=self.workers_count, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + return cls( + workers_count=pb2_object.workers, + ) diff --git a/flytekit/plugins/__init__.py b/flytekit/plugins/__init__.py index 56c2da7cc4a..3235244d79f 100644 --- a/flytekit/plugins/__init__.py +++ b/flytekit/plugins/__init__.py @@ -17,6 +17,8 @@ hmsclient = _lazy_loader.lazy_load_module("hmsclient") # type: types.ModuleType type(hmsclient).add_sub_module("genthrift.hive_metastore.ttypes") +torch = _lazy_loader.lazy_load_module("torch") # type: types.ModuleType + _lazy_loader.LazyLoadPlugin( "spark", ["pyspark>=2.4.0,<3.0.0"], @@ -46,3 +48,9 @@ ], [hmsclient] ) + +_lazy_loader.LazyLoadPlugin( + "pytorch", + ["torch>=1.0.0,<2.0.0"], + [torch] +) \ No newline at end of file diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py index bd3b8fb8a1f..fbd3beaa172 100644 --- a/flytekit/sdk/tasks.py +++ b/flytekit/sdk/tasks.py @@ -6,7 +6,8 @@ from 
flytekit.common import constants as _common_constants from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \ - spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, sidecar_task as _sdk_sidecar_tasks + spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, \ + sidecar_task as _sdk_sidecar_tasks, pytorch_task as _sdk_pytorch_tasks from flytekit.common.tasks import task as _task from flytekit.common.types import helpers as _type_helpers from flytekit.sdk.spark_types import SparkType as _spark_type @@ -1000,3 +1001,157 @@ def wrapper(fn): return wrapper(_task_function) else: return wrapper + + +def pytorch_task( + _task_function=None, + cache_version='', + retries=0, + interruptible=False, + deprecated='', + cache=False, + timeout=None, + workers_count=1, + per_replica_storage_request="", + per_replica_cpu_request="", + per_replica_gpu_request="", + per_replica_memory_request="", + per_replica_storage_limit="", + per_replica_cpu_limit="", + per_replica_gpu_limit="", + per_replica_memory_limit="", + environment=None, + cls=None +): + """ + Decorator to create a Pytorch Task definition. This task will submit PyTorchJob (see https://github.com/kubeflow/pytorch-operator) + defined by the code within the _task_function to k8s cluster. + + .. code-block:: python + + @inputs(int_list=[Types.Integer]) + @outputs(result=Types.Integer) + @pytorch_task( + workers_count=2, + per_replica_cpu_request="500m", + per_replica_memory_request="4Gi", + per_replica_memory_limit="8Gi", + per_replica_gpu_limit="1", + ) + def my_pytorch_job(wf_params, int_list, result): + pass + + :param _task_function: this is the decorated method and shouldn't be declared explicitly. 
The function must + take a first argument, and then named arguments matching those defined in @inputs and @outputs. No keyword + arguments are allowed for wrapped task functions. + + :param Text cache_version: [optional] string representing logical version for discovery. This field should be + updated whenever the underlying algorithm changes. + + .. note:: + + This argument is required to be a non-empty string if `cache` is True. + + :param int retries: [optional] integer determining number of times task can be retried on + :py:exc:`flytekit.sdk.exceptions.RecoverableException` or transient platform failures. Defaults + to 0. + + .. note:: + + If retries > 0, the task must be able to recover from any remote state created within the user code. It is + strongly recommended that tasks are written to be idempotent. + + :param bool interruptible: [optional] boolean describing if the task is interruptible. + + :param Text deprecated: [optional] string that should be provided if this task is deprecated. The string + will be logged as a warning so it should contain information regarding how to update to a newer task. + + :param bool cache: [optional] boolean describing if the outputs of this task should be cached and + re-usable. + + :param datetime.timedelta timeout: [optional] describes how long the task should be allowed to + run at max before triggering a retry (if retries are enabled). By default, tasks are allowed to run + indefinitely. If a null timedelta is passed (i.e. timedelta(seconds=0)), the task will not timeout. + + :param int workers_count: integer determining the number of worker replicas spawned in the cluster for this job + (in addition to 1 master). + + :param Text per_replica_storage_request: [optional] Kubernetes resource string for lower-bound of disk storage space + for each replica spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration. + + .. 
note:: + + This is currently not supported by the platform. + + :param Text per_replica_cpu_request: [optional] Kubernetes resource string for lower-bound of cores for each replica + spawned for this job (i.e. both for master and workers). + This can be set to a fractional portion of a CPU. Default is set by platform-level configuration. + + TODO: Add links to resource string documentation for Kubernetes + + :param Text per_replica_gpu_request: [optional] Kubernetes resource string for lower-bound of desired GPUs for each + replica spawned for this job (i.e. both for master and workers). + Default is set by platform-level configuration. + + TODO: Add links to resource string documentation for Kubernetes + + :param Text per_replica_memory_request: [optional] Kubernetes resource string for lower-bound of physical memory + necessary for each replica spawned for this job (i.e. both for master and workers). Default is set by platform-level configuration. + + TODO: Add links to resource string documentation for Kubernetes + + :param Text per_replica_storage_limit: [optional] Kubernetes resource string for upper-bound of disk storage space + for each replica spawned for this job (i.e. both for master and workers). + This amount is not guaranteed! If not specified, it is set equal to storage_request. + + .. note:: + + This is currently not supported by the platform. + + :param Text per_replica_cpu_limit: [optional] Kubernetes resource string for upper-bound of cores for each replica + spawned for this job (i.e. both for master and workers). + This can be set to a fractional portion of a CPU. This amount is not guaranteed! If not specified, + it is set equal to cpu_request. + + :param Text per_replica_gpu_limit: [optional] Kubernetes resource string for upper-bound of desired GPUs for each + replica spawned for this job (i.e. both for master and workers). + This amount is not guaranteed! If not specified, it is set equal to gpu_request. 
+ + :param Text per_replica_memory_limit: [optional] Kubernetes resource string for upper-bound of physical memory + necessary for each replica spawned for this job (i.e. both for master and workers). + This amount is not guaranteed! If not specified, it is set equal to memory_request. + + :param dict[Text,Text] environment: [optional] environment variables to set when executing this task. + + :param cls: This can be used to override the task implementation with a user-defined extension. The class + provided must be a subclass of flytekit.common.tasks.sdk_runnable.SdkRunnableTask. A user can use this to + inject bespoke logic into the base Flyte programming model. + + :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask + """ + def wrapper(fn): + return (cls or _sdk_pytorch_tasks.SdkPyTorchTask)( + task_function=fn, + task_type=_common_constants.SdkTaskType.PYTORCH_TASK, + discovery_version=cache_version, + retries=retries, + interruptible=interruptible, + deprecated=deprecated, + discoverable=cache, + timeout=timeout or _datetime.timedelta(seconds=0), + workers_count=workers_count, + per_replica_storage_request=per_replica_storage_request, + per_replica_cpu_request=per_replica_cpu_request, + per_replica_gpu_request=per_replica_gpu_request, + per_replica_memory_request=per_replica_memory_request, + per_replica_storage_limit=per_replica_storage_limit, + per_replica_cpu_limit=per_replica_cpu_limit, + per_replica_gpu_limit=per_replica_gpu_limit, + per_replica_memory_limit=per_replica_memory_limit, + environment=environment or {} + ) + + if _task_function: + return wrapper(_task_function) + else: + return wrapper diff --git a/setup.py b/setup.py index 3abcf015932..30edd543aae 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ ] }, install_requires=[ - "flyteidl>=0.17.27,<1.0.0", + "flyteidl>=0.17.32,<1.0.0", "click>=6.6,<8.0", "croniter>=0.3.20,<4.0.0", "deprecation>=2.0,<3.0", diff --git a/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py 
b/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py new file mode 100644 index 00000000000..62cba7cba5b --- /dev/null +++ b/tests/flytekit/unit/sdk/tasks/test_pytorch_task.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from flytekit.sdk.tasks import pytorch_task, inputs, outputs +from flytekit.sdk.types import Types +from flytekit.common import constants as _common_constants +from flytekit.common.tasks import sdk_runnable as _sdk_runnable, pytorch_task as _pytorch_task +from flytekit.models import types as _type_models +from flytekit.models.core import identifier as _identifier +import datetime as _datetime + + +@inputs(in1=Types.Integer) +@outputs(out1=Types.String) +@pytorch_task(workers_count=1) +def simple_pytorch_task(wf_params, sc, in1, out1): + pass + + +simple_pytorch_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") + + +def test_simple_pytorch_task(): + assert isinstance(simple_pytorch_task, _pytorch_task.SdkPyTorchTask) + assert isinstance(simple_pytorch_task, _sdk_runnable.SdkRunnableTask) + assert simple_pytorch_task.interface.inputs['in1'].description == '' + assert simple_pytorch_task.interface.inputs['in1'].type == \ + _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER) + assert simple_pytorch_task.interface.outputs['out1'].description == '' + assert simple_pytorch_task.interface.outputs['out1'].type == \ + _type_models.LiteralType(simple=_type_models.SimpleType.STRING) + assert simple_pytorch_task.type == _common_constants.SdkTaskType.PYTORCH_TASK + assert simple_pytorch_task.task_function_name == 'simple_pytorch_task' + assert simple_pytorch_task.task_module == __name__ + assert simple_pytorch_task.metadata.timeout == _datetime.timedelta(seconds=0) + assert simple_pytorch_task.metadata.deprecated_error_message == '' + assert simple_pytorch_task.metadata.discoverable is False + assert simple_pytorch_task.metadata.discovery_version == '' + assert 
simple_pytorch_task.metadata.retries.retries == 0 + assert len(simple_pytorch_task.container.resources.limits) == 0 + assert len(simple_pytorch_task.container.resources.requests) == 0 + assert simple_pytorch_task.custom['workers'] == 1 + # Should strip out the venv component of the args. + assert simple_pytorch_task._get_container_definition().args[0] == 'pyflyte-execute' + + pb2 = simple_pytorch_task.to_flyte_idl() + assert pb2.custom['workers'] == 1 \ No newline at end of file