From 41a9c7a80e61fc737ebfa765d47ab14f7e514fd4 Mon Sep 17 00:00:00 2001 From: bstadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Thu, 12 Jan 2023 03:16:50 +0100 Subject: [PATCH] Add dask plugin #patch (#1366) * Add dummy task type to test backend plugin Signed-off-by: Bernhard Stadlbauer * Add docs page Signed-off-by: Bernhard Stadlbauer * Add dask models Signed-off-by: Bernhard Stadlbauer * Add function to convert resources Signed-off-by: Bernhard Stadlbauer * Add tests to `dask` task Signed-off-by: Bernhard Stadlbauer * Remove namespace Signed-off-by: Bernhard Stadlbauer * Update setup.py Signed-off-by: Bernhard Stadlbauer * Add dask to `plugin/README.md` Signed-off-by: Bernhard Stadlbauer * Add README.md for `dask` Signed-off-by: Bernhard Stadlbauer * Top level export of `JopPodSpec` and `DaskCluster` Signed-off-by: Bernhard Stadlbauer * Update docs for images Signed-off-by: Bernhard Stadlbauer * Update README.md Signed-off-by: Bernhard Stadlbauer * Update models after `flyteidl` change Signed-off-by: Bernhard Stadlbauer * Update task after `flyteidl` change Signed-off-by: Bernhard Stadlbauer * Raise error when less than 1 worker Signed-off-by: Bernhard Stadlbauer * Update flyteidl to >= 1.3.2 Signed-off-by: Bernhard Stadlbauer * Update doc requirements Signed-off-by: Bernhard Stadlbauer * Update doc-requirements.txt Signed-off-by: Bernhard Stadlbauer * Re-lock dependencies on linux Signed-off-by: Bernhard Stadlbauer * Update dask API docs Signed-off-by: Bernhard Stadlbauer * Fix documentation links Signed-off-by: Bernhard Stadlbauer * Default optional model constructor arguments to `None` Signed-off-by: Bernhard Stadlbauer * Refactor `convert_resources_to_resource_model` to `core.resources` Signed-off-by: Bernhard Stadlbauer * Use `convert_resources_to_resource_model` in `core.node` Signed-off-by: Bernhard Stadlbauer * Incorporate review feedback Signed-off-by: Eduardo Apolinario * Lint Signed-off-by: Eduardo Apolinario Signed-off-by: Bernhard Stadlbauer Signed-off-by: Bernhard Stadlbauer Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Co-authored-by: Eduardo Apolinario --- .github/workflows/pythonbuild.yml | 1 + doc-requirements.in | 1 + docs/source/plugins/dask.rst | 12 + docs/source/plugins/index.rst | 2 + flytekit/core/node.py | 16 +- flytekit/core/resources.py | 43 ++- flytekit/core/utils.py | 37 ++- plugins/README.md | 1 + plugins/flytekit-dask/README.md | 21 ++ .../flytekitplugins/dask/__init__.py | 15 ++ .../flytekitplugins/dask/models.py | 134 ++++++++++ .../flytekitplugins/dask/task.py | 108 ++++++++ plugins/flytekit-dask/requirements.in | 2 + plugins/flytekit-dask/requirements.txt | 247 ++++++++++++++++++ plugins/flytekit-dask/setup.py | 42 +++ plugins/flytekit-dask/tests/__init__.py | 0 plugins/flytekit-dask/tests/test_models.py | 96 +++++++ plugins/flytekit-dask/tests/test_task.py | 86 ++++++ tests/flytekit/unit/core/test_resources.py | 68 +++++ 19 files changed, 906 insertions(+), 26 deletions(-) create mode 100644 docs/source/plugins/dask.rst create mode 100644 plugins/flytekit-dask/README.md create mode 100644 plugins/flytekit-dask/flytekitplugins/dask/__init__.py create mode 100644 plugins/flytekit-dask/flytekitplugins/dask/models.py create mode 100644 plugins/flytekit-dask/flytekitplugins/dask/task.py create mode 100644 plugins/flytekit-dask/requirements.in create mode 100644 plugins/flytekit-dask/requirements.txt create mode 100644 plugins/flytekit-dask/setup.py create mode 100644 plugins/flytekit-dask/tests/__init__.py create mode 100644 plugins/flytekit-dask/tests/test_models.py create mode 100644 plugins/flytekit-dask/tests/test_task.py create mode 100644 tests/flytekit/unit/core/test_resources.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index fc8a554cdd..f8b1c76f8e 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -67,6 +67,7 @@ jobs: - flytekit-aws-batch - flytekit-aws-sagemaker - flytekit-bigquery + - flytekit-dask - flytekit-data-fsspec - flytekit-dbt - flytekit-deck-standard diff --git a/doc-requirements.in b/doc-requirements.in index 9fa7c50a1b..713934df13 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -46,5 +46,6 @@ whylogs # whylogs whylabs-client # whylogs ray # ray scikit-learn # scikit-learn +dask[distributed] # dask vaex # vaex mlflow # mlflow diff --git a/docs/source/plugins/dask.rst b/docs/source/plugins/dask.rst new file mode 100644 index 0000000000..53e9f11fcb --- /dev/null +++ b/docs/source/plugins/dask.rst @@ -0,0 +1,12 @@ +.. _dask: + +################################################### +Dask API reference +################################################### + +.. tags:: Integration, DistributedComputing, KubernetesOperator + +.. automodule:: flytekitplugins.dask + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst index 008f2b4bbe..693587192e 100644 --- a/docs/source/plugins/index.rst +++ b/docs/source/plugins/index.rst @@ -9,6 +9,7 @@ Plugin API reference * :ref:`AWS Sagemaker ` - AWS Sagemaker plugin reference * :ref:`Google Bigquery ` - Google Bigquery plugin reference * :ref:`FS Spec ` - FS Spec API reference +* :ref:`Dask ` - Dask standard API reference * :ref:`Deck standard ` - Deck standard API reference * :ref:`Dolt standard ` - Dolt standard API reference * :ref:`Great expectations ` - Great expectations API reference @@ -40,6 +41,7 @@ Plugin API reference AWS Sagemaker Google Bigquery FS Spec + Dask Deck standard Dolt standard Great expectations diff --git a/flytekit/core/node.py b/flytekit/core/node.py index d8b43f2728..52487e6e48 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -4,7 +4,7 @@ import typing from typing import Any, List -from flytekit.core.resources import Resources +from flytekit.core.resources import Resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model @@ -92,9 +92,14 @@ def with_overrides(self, *args, **kwargs): for k, v in alias_dict.items(): self._aliases.append(_workflow_model.Alias(var=k, alias=v)) if "requests" in kwargs or "limits" in kwargs: - requests = _convert_resource_overrides(kwargs.get("requests"), "requests") - limits = _convert_resource_overrides(kwargs.get("limits"), "limits") - self._resources = _resources_model(requests=requests, limits=limits) + requests = kwargs.get("requests") + if requests and not isinstance(requests, Resources): + raise AssertionError("requests should be specified as flytekit.Resources") + limits = kwargs.get("limits") + if limits and not isinstance(limits, Resources): + raise AssertionError("limits should be specified as flytekit.Resources") + + self._resources = convert_resources_to_resource_model(requests=requests, limits=limits) if "timeout" in kwargs: timeout = kwargs["timeout"] if timeout is None: @@ -122,8 +127,7 @@ def _convert_resource_overrides( ) -> [_resources_model.ResourceEntry]: if resources is None: return [] - if not isinstance(resources, Resources): - raise AssertionError(f"{resource_name} should be specified as flytekit.Resources") + resource_entries = [] if resources.cpu is not None: resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.CPU, resources.cpu)) diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 7b46cbe05c..6280604246 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import Optional +from typing import List, Optional + +from flytekit.models import task as task_models @dataclass @@ -35,3 +37,42 @@ class Resources(object): class ResourceSpec(object): requests: Optional[Resources] = None limits: Optional[Resources] = None + + +_ResouceName = task_models.Resources.ResourceName +_ResourceEntry = task_models.Resources.ResourceEntry + + +def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: + resource_entries = [] + if resources.cpu is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.CPU, value=resources.cpu)) + if resources.mem is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.MEMORY, value=resources.mem)) + if resources.gpu is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.GPU, value=resources.gpu)) + if resources.storage is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.STORAGE, value=resources.storage)) + if resources.ephemeral_storage is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage)) + return resource_entries + + +def convert_resources_to_resource_model( + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, +) -> task_models.Resources: + """ + Convert flytekit ``Resources`` objects to a Resources model + + :param requests: Resource requests. Optional, defaults to ``None`` + :param limits: Resource limits. Optional, defaults to ``None`` + :return: The given resources as requests and limits + """ + request_entries = [] + limit_entries = [] + if requests is not None: + request_entries = _convert_resources_to_resource_entries(requests) + if limits is not None: + limit_entries = _convert_resources_to_resource_entries(limits) + return task_models.Resources(requests=request_entries, limits=limit_entries) diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index d23aae3fbb..ae8b89a109 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional from flytekit.loggers import logger -from flytekit.models import task as _task_models +from flytekit.models import task as task_models def _dnsify(value: str) -> str: @@ -52,7 +52,7 @@ def _get_container_definition( image: str, command: List[str], args: List[str], - data_loading_config: Optional[_task_models.DataLoadingConfig] = None, + data_loading_config: Optional[task_models.DataLoadingConfig] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, cpu_request: Optional[str] = None, @@ -64,7 +64,7 @@ def _get_container_definition( gpu_limit: Optional[str] = None, memory_limit: Optional[str] = None, environment: Optional[Dict[str, str]] = None, -) -> _task_models.Container: +) -> task_models.Container: storage_limit = storage_limit storage_request = storage_request ephemeral_storage_limit = ephemeral_storage_limit @@ -76,50 +76,49 @@ def _get_container_definition( memory_limit = memory_limit memory_request = memory_request + # TODO: Use convert_resources_to_resource_model instead of manually fixing the resources. requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) + task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_request) ) if ephemeral_storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request ) ) if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_request)) if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_request)) if memory_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) - ) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_request)) limits = [] if storage_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_limit)) if ephemeral_storage_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit ) ) if cpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_limit)) if gpu_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_limit)) if memory_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_limit)) if environment is None: environment = {} - return _task_models.Container( + return task_models.Container( image=image, command=command, args=args, - resources=_task_models.Resources(limits=limits, requests=requests), + resources=task_models.Resources(limits=limits, requests=requests), env=environment, config={}, data_loading_config=data_loading_config, diff --git a/plugins/README.md b/plugins/README.md index 447b91a37c..495ce91019 100644 --- a/plugins/README.md +++ b/plugins/README.md @@ -7,6 +7,7 @@ All the Flytekit plugins maintained by the core team are added here. It is not n | Plugin | Installation | Description | Version | Type | |------------------------------|-----------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------| | AWS Sagemaker Training | ```bash pip install flytekitplugins-awssagemaker ``` | Installs SDK to author Sagemaker built-in and custom training jobs in python | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Backend | +| dask | ```bash pip install flytekitplugins-dask ``` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend | | Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend | | K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend | | K8s native tensorflow Jobs | ```bash pip install flytekitplugins-kftensorflow ``` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kftensorflow.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend | diff --git a/plugins/flytekit-dask/README.md b/plugins/flytekit-dask/README.md new file mode 100644 index 0000000000..9d645bcd27 --- /dev/null +++ b/plugins/flytekit-dask/README.md @@ -0,0 +1,21 @@ +# Flytekit Dask Plugin + +Flyte can execute `dask` jobs natively on a Kubernetes Cluster, which manages the virtual `dask` cluster's lifecycle +(spin-up and tear down). It leverages the open-source Kubernetes Dask Operator and can be enabled without signing up +for any service. This is like running a transient (ephemeral) `dask` cluster - a type of cluster spun up for a specific +task and torn down after completion. This helps in making sure that the Python environment is the same on the job-runner +(driver), scheduler and the workers. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-dask +``` + +To configure Dask in the Flyte deployment's backed, follow +[step 1](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/k8s_dask/index.html#step-1-deploy-the-dask-plugin-in-the-flyte-backend) +and +[step 2](https://docs.flyte.org/projects/cookbook/en/latest/auto/auto/integrations/kubernetes/k8s_dask/index.html#step-2-environment-setup) + +An [example](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/k8s_dask/index.html) +can be found in the documentation. diff --git a/plugins/flytekit-dask/flytekitplugins/dask/__init__.py b/plugins/flytekit-dask/flytekitplugins/dask/__init__.py new file mode 100644 index 0000000000..ccadf385fc --- /dev/null +++ b/plugins/flytekit-dask/flytekitplugins/dask/__init__.py @@ -0,0 +1,15 @@ +""" +.. currentmodule:: flytekitplugins.dask + +This package contains the Python related side of the Dask Plugin + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + Dask + Scheduler + WorkerGroup +""" + +from flytekitplugins.dask.task import Dask, Scheduler, WorkerGroup diff --git a/plugins/flytekit-dask/flytekitplugins/dask/models.py b/plugins/flytekit-dask/flytekitplugins/dask/models.py new file mode 100644 index 0000000000..b833ab660a --- /dev/null +++ b/plugins/flytekit-dask/flytekitplugins/dask/models.py @@ -0,0 +1,134 @@ +from typing import Optional + +from flyteidl.plugins import dask_pb2 as dask_task + +from flytekit.models import common as common +from flytekit.models import task as task + + +class Scheduler(common.FlyteIdlEntity): + """ + Configuration for the scheduler pod + + :param image: Optional image to use. + :param resources: Optional resources to use. + """ + + def __init__(self, image: Optional[str] = None, resources: Optional[task.Resources] = None): + self._image = image + self._resources = resources + + @property + def image(self) -> Optional[str]: + """ + :return: The optional image for the scheduler pod + """ + return self._image + + @property + def resources(self) -> Optional[task.Resources]: + """ + :return: Optional resources for the scheduler pod + """ + return self._resources + + def to_flyte_idl(self) -> dask_task.DaskScheduler: + """ + :return: The scheduler spec serialized to protobuf + """ + return dask_task.DaskScheduler( + image=self.image, + resources=self.resources.to_flyte_idl() if self.resources else None, + ) + + +class WorkerGroup(common.FlyteIdlEntity): + """ + Configuration for a dask worker group + + :param number_of_workers:Number of workers in the group + :param image: Optional image to use for the pods of the worker group + :param resources: Optional resources to use for the pods of the worker group + """ + + def __init__( + self, + number_of_workers: int, + image: Optional[str] = None, + resources: Optional[task.Resources] = None, + ): + if number_of_workers < 1: + raise ValueError( + f"Each worker group needs to have at least one worker, but {number_of_workers} have been specified." + ) + + self._number_of_workers = number_of_workers + self._image = image + self._resources = resources + + @property + def number_of_workers(self) -> Optional[int]: + """ + :return: Optional number of workers for the worker group + """ + return self._number_of_workers + + @property + def image(self) -> Optional[str]: + """ + :return: The optional image to use for the worker pods + """ + return self._image + + @property + def resources(self) -> Optional[task.Resources]: + """ + :return: Optional resources to use for the worker pods + """ + return self._resources + + def to_flyte_idl(self) -> dask_task.DaskWorkerGroup: + """ + :return: The dask cluster serialized to protobuf + """ + return dask_task.DaskWorkerGroup( + number_of_workers=self.number_of_workers, + image=self.image, + resources=self.resources.to_flyte_idl() if self.resources else None, + ) + + +class DaskJob(common.FlyteIdlEntity): + """ + Configuration for the custom dask job to run + + :param scheduler: Configuration for the scheduler + :param workers: Configuration of the default worker group + """ + + def __init__(self, scheduler: Scheduler, workers: WorkerGroup): + self._scheduler = scheduler + self._workers = workers + + @property + def scheduler(self) -> Scheduler: + """ + :return: Configuration for the scheduler pod + """ + return self._scheduler + + @property + def workers(self) -> WorkerGroup: + """ + :return: Configuration of the default worker group + """ + return self._workers + + def to_flyte_idl(self) -> dask_task.DaskJob: + """ + :return: The dask job serialized to protobuf + """ + return dask_task.DaskJob( + scheduler=self.scheduler.to_flyte_idl(), + workers=self.workers.to_flyte_idl(), + ) diff --git a/plugins/flytekit-dask/flytekitplugins/dask/task.py b/plugins/flytekit-dask/flytekitplugins/dask/task.py new file mode 100644 index 0000000000..830ede98ef --- /dev/null +++ b/plugins/flytekit-dask/flytekitplugins/dask/task.py @@ -0,0 +1,108 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional + +from flytekitplugins.dask import models +from google.protobuf.json_format import MessageToDict + +from flytekit import PythonFunctionTask, Resources +from flytekit.configuration import SerializationSettings +from flytekit.core.resources import convert_resources_to_resource_model +from flytekit.core.task import TaskPlugins + + +@dataclass +class Scheduler: + """ + Configuration for the scheduler pod + + :param image: Custom image to use. If ``None``, will use the same image the task was registered with. Optional, + defaults to ``None``. The image must have ``dask[distributed]`` installed and should have the same Python + environment as the rest of the cluster (job runner pod + worker pods). + :param requests: Resources to request for the scheduler pod. If ``None``, the requests passed into the task will be + used. Optional, defaults to ``None``. + :param limits: Resource limits for the scheduler pod. If ``None``, the limits passed into the task will be used. + Optional, defaults to ``None``. + """ + + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + + +@dataclass +class WorkerGroup: + """ + Configuration for a group of dask worker pods + + :param number_of_workers: Number of workers to use. Optional, defaults to 1. + :param image: Custom image to use. If ``None``, will use the same image the task was registered with. Optional, + defaults to ``None``. The image must have ``dask[distributed]`` installed. The provided image should have the + same Python environment as the job runner/driver as well as the scheduler. + :param requests: Resources to request for the worker pods. If ``None``, the requests passed into the task will be + used. Optional, defaults to ``None``. + :param limits: Resource limits for the worker pods. If ``None``, the limits passed into the task will be used. + Optional, defaults to ``None``. + """ + + number_of_workers: Optional[int] = 1 + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + + +@dataclass +class Dask: + """ + Configuration for the dask task + + :param scheduler: Configuration for the scheduler pod. Optional, defaults to ``Scheduler()``. + :param workers: Configuration for the pods of the default worker group. Optional, defaults to ``WorkerGroup()``. + """ + + scheduler: Scheduler = Scheduler() + workers: WorkerGroup = WorkerGroup() + + +class DaskTask(PythonFunctionTask[Dask]): + """ + Actual Plugin that transforms the local python code for execution within a dask cluster + """ + + _DASK_TASK_TYPE = "dask" + + def __init__(self, task_config: Dask, task_function: Callable, **kwargs): + super(DaskTask, self).__init__( + task_config=task_config, + task_type=self._DASK_TASK_TYPE, + task_function=task_function, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + """ + Serialize the `dask` task config into a dict. + + :param settings: Current serialization settings + :return: Dictionary representation of the dask task config. + """ + scheduler = models.Scheduler( + image=self.task_config.scheduler.image, + resources=convert_resources_to_resource_model( + requests=self.task_config.scheduler.requests, + limits=self.task_config.scheduler.limits, + ), + ) + workers = models.WorkerGroup( + number_of_workers=self.task_config.workers.number_of_workers, + image=self.task_config.workers.image, + resources=convert_resources_to_resource_model( + requests=self.task_config.workers.requests, + limits=self.task_config.workers.limits, + ), + ) + job = models.DaskJob(scheduler=scheduler, workers=workers) + return MessageToDict(job.to_flyte_idl()) + + +# Inject the `dask` plugin into flytekits dynamic plugin loading system +TaskPlugins.register_pythontask_plugin(Dask, DaskTask) diff --git a/plugins/flytekit-dask/requirements.in b/plugins/flytekit-dask/requirements.in new file mode 100644 index 0000000000..310ade8617 --- /dev/null +++ b/plugins/flytekit-dask/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-dask diff --git a/plugins/flytekit-dask/requirements.txt b/plugins/flytekit-dask/requirements.txt new file mode 100644 index 0000000000..2ec017e46d --- /dev/null +++ b/plugins/flytekit-dask/requirements.txt @@ -0,0 +1,247 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# pip-compile --output-file=requirements.txt requirements.in setup.py +# +-e file:.#egg=flytekitplugins-dask + # via -r requirements.in +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.9.24 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.1 + # via requests +click==8.1.3 + # via + # cookiecutter + # dask + # distributed + # flytekit +cloudpickle==2.2.0 + # via + # dask + # distributed + # flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.7 + # via flytekit +cryptography==38.0.3 + # via + # pyopenssl + # secretstorage +dask[distributed]==2022.10.2 + # via + # distributed + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +distributed==2022.10.2 + # via dask +docker==6.0.1 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +flyteidl==1.3.2 + # via + # flytekit + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +flytekit==1.3.0b2 + # via + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +fsspec==2022.10.0 + # via dask +googleapis-common-protos==1.56.4 + # via + # flyteidl + # grpcio-status +grpcio==1.51.1 + # via + # flytekit + # grpcio-status +grpcio-status==1.51.1 + # via flytekit +heapdict==1.0.1 + # via zict +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via + # flytekit + # keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 + # via + # cookiecutter + # distributed + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via flytekit +keyring==23.11.0 + # via flytekit +locket==1.0.0 + # via + # distributed + # partd +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.18.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +more-itertools==9.0.0 + # via jaraco-classes +msgpack==1.0.4 + # via distributed +mypy-extensions==0.4.3 + # via typing-inspect +natsort==8.2.0 + # via flytekit +numpy==1.23.4 + # via + # pandas + # pyarrow +packaging==21.3 + # via + # dask + # distributed + # docker + # marshmallow +pandas==1.5.1 + # via flytekit +partd==1.3.0 + # via dask +protobuf==4.21.11 + # via + # flyteidl + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +psutil==5.9.3 + # via distributed +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.1.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.4 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.6 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # dask + # distributed + # flytekit +regex==2022.10.31 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses +responses==0.22.0 + # via flytekit +retry==0.9.2 + # via flytekit +secretstorage==3.3.3 + # via keyring +six==1.16.0 + # via python-dateutil +sortedcontainers==2.4.0 + # via + # distributed + # flytekit +statsd==3.3.0 + # via flytekit +tblib==1.7.0 + # via distributed +text-unidecode==1.3 + # via python-slugify +toml==0.10.2 + # via responses +toolz==0.12.0 + # via + # dask + # distributed + # partd +tornado==6.1 + # via distributed +types-toml==0.10.8 + # via responses +typing-extensions==4.4.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.12 + # via + # distributed + # docker + # flytekit + # requests + # responses +websocket-client==1.4.2 + # via docker +wheel==0.38.2 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zict==2.2.0 + # via distributed +zipp==3.10.0 + # via importlib-metadata diff --git a/plugins/flytekit-dask/setup.py b/plugins/flytekit-dask/setup.py new file mode 100644 index 0000000000..440d7b47db --- /dev/null +++ b/plugins/flytekit-dask/setup.py @@ -0,0 +1,42 @@ +from setuptools import setup + +PLUGIN_NAME = "dask" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flyteidl>=1.3.2", + "flytekit>=1.3.0b2,<2.0.0", + "dask[distributed]>=2022.10.2", +] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Dask plugin for flytekit", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-dask", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", # dask requires >= 3.8 + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-dask/tests/__init__.py b/plugins/flytekit-dask/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-dask/tests/test_models.py b/plugins/flytekit-dask/tests/test_models.py new file mode 100644 index 0000000000..801a110fb1 --- /dev/null +++ b/plugins/flytekit-dask/tests/test_models.py @@ -0,0 +1,96 @@ +import pytest +from flytekitplugins.dask import models + +from flytekit.models import task as _task + + +@pytest.fixture +def image() -> str: + return "foo:latest" + + +@pytest.fixture +def resources() -> _task.Resources: + return _task.Resources( + requests=[ + _task.Resources.ResourceEntry(name=_task.Resources.ResourceName.CPU, value="3"), + ], + limits=[], + ) + + +@pytest.fixture +def default_resources() -> _task.Resources: + return _task.Resources(requests=[], limits=[]) + + +@pytest.fixture +def scheduler(image: str, resources: _task.Resources) -> models.Scheduler: + return models.Scheduler(image=image, resources=resources) + + +@pytest.fixture +def workers(image: str, resources: _task.Resources) -> models.WorkerGroup: + return models.WorkerGroup(number_of_workers=123, image=image, resources=resources) + + +def test_create_scheduler_to_flyte_idl_no_optional(image: str, resources: _task.Resources): + scheduler = models.Scheduler(image=image, resources=resources) + idl_object = scheduler.to_flyte_idl() + assert idl_object.image == image + assert idl_object.resources == resources.to_flyte_idl() + + +def test_create_scheduler_to_flyte_idl_all_optional(default_resources: _task.Resources): + scheduler = models.Scheduler(image=None, resources=None) + idl_object = scheduler.to_flyte_idl() + assert idl_object.image == "" + assert idl_object.resources == default_resources.to_flyte_idl() + + +def test_create_scheduler_spec_property_access(image: str, resources: _task.Resources): + scheduler = models.Scheduler(image=image, resources=resources) + assert scheduler.image == image + assert scheduler.resources == resources + + +def test_worker_group_to_flyte_idl_no_optional(image: str, resources: _task.Resources): + n_workers = 1234 + worker_group = models.WorkerGroup(number_of_workers=n_workers, image=image, resources=resources) + idl_object = worker_group.to_flyte_idl() + assert idl_object.number_of_workers == n_workers + assert idl_object.image == image + assert idl_object.resources == resources.to_flyte_idl() + + +def test_worker_group_to_flyte_idl_all_optional(default_resources: _task.Resources): + worker_group = models.WorkerGroup(number_of_workers=1, image=None, resources=None) + idl_object = worker_group.to_flyte_idl() + assert idl_object.image == "" + assert idl_object.resources == default_resources.to_flyte_idl() + + +def test_worker_group_property_access(image: str, resources: _task.Resources): + n_workers = 1234 + worker_group = models.WorkerGroup(number_of_workers=n_workers, image=image, resources=resources) + assert worker_group.image == image + assert worker_group.number_of_workers == n_workers + assert worker_group.resources == resources + + +def test_worker_group_fails_for_less_than_one_worker(): + with pytest.raises(ValueError, match=r"Each worker group needs to"): + models.WorkerGroup(number_of_workers=0, image=None, resources=None) + + +def test_dask_job_to_flyte_idl_no_optional(scheduler: models.Scheduler, workers: models.WorkerGroup): + job = models.DaskJob(scheduler=scheduler, workers=workers) + idl_object = job.to_flyte_idl() + assert idl_object.scheduler == scheduler.to_flyte_idl() + assert idl_object.workers == workers.to_flyte_idl() + + +def test_dask_job_property_access(scheduler: models.Scheduler, workers: models.WorkerGroup): + job = models.DaskJob(scheduler=scheduler, workers=workers) + assert job.scheduler == scheduler + assert job.workers == workers diff --git a/plugins/flytekit-dask/tests/test_task.py b/plugins/flytekit-dask/tests/test_task.py new file mode 100644 index 0000000000..76dbf9d048 --- /dev/null +++ b/plugins/flytekit-dask/tests/test_task.py @@ -0,0 +1,86 @@ +import pytest +from flytekitplugins.dask import Dask, Scheduler, WorkerGroup + +from flytekit import PythonFunctionTask, Resources, task +from flytekit.configuration import Image, ImageConfig, SerializationSettings + + +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + return settings + + +def test_dask_task_with_default_config(serialization_settings: SerializationSettings): + task_config = Dask() + + @task(task_config=task_config) + def dask_task(): + pass + + # Helping type completion in PyCharm + dask_task: PythonFunctionTask[Dask] + + assert dask_task.task_config == task_config + assert dask_task.task_type == "dask" + + expected_dict = { + "scheduler": { + "resources": {}, + }, + "workers": { + "numberOfWorkers": 1, + "resources": {}, + }, + } + assert dask_task.get_custom(serialization_settings) == expected_dict + + +def test_dask_task_get_custom(serialization_settings: SerializationSettings): + task_config = Dask( + scheduler=Scheduler( + image="scheduler:latest", + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + workers=WorkerGroup( + number_of_workers=123, + image="dask_cluster:latest", + requests=Resources(cpu="3"), + limits=Resources(cpu="4"), + ), + ) + + @task(task_config=task_config) + def dask_task(): + pass + + # Helping type completion in PyCharm + dask_task: PythonFunctionTask[Dask] + + expected_custom_dict = { + "scheduler": { + "image": "scheduler:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workers": { + "numberOfWorkers": 123, + "image": "dask_cluster:latest", + "resources": { + "requests": [{"name": "CPU", "value": "3"}], + "limits": [{"name": "CPU", "value": "4"}], + }, + }, + } + custom_dict = dask_task.get_custom(serialization_settings) + assert custom_dict == expected_custom_dict diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py new file mode 100644 index 0000000000..1a3bf64dee --- /dev/null +++ b/tests/flytekit/unit/core/test_resources.py @@ -0,0 +1,68 @@ +from typing import Dict + +import pytest + +import flytekit.models.task as _task_models +from flytekit import Resources +from flytekit.core.resources import convert_resources_to_resource_model + +_ResourceName = _task_models.Resources.ResourceName + + +def test_convert_no_requests_no_limits(): + resource_model = convert_resources_to_resource_model(requests=None, limits=None) + assert isinstance(resource_model, _task_models.Resources) + assert resource_model.requests == [] + assert resource_model.limits == [] + + +@pytest.mark.parametrize( + argnames=("resource_dict", "expected_resource_name"), + argvalues=( + ({"cpu": "2"}, _ResourceName.CPU), + ({"mem": "1Gi"}, _ResourceName.MEMORY), + ({"gpu": "1"}, _ResourceName.GPU), + ({"storage": "100Mb"}, _ResourceName.STORAGE), + ({"ephemeral_storage": "123Mb"}, _ResourceName.EPHEMERAL_STORAGE), + ), + ids=("CPU", "MEMORY", "GPU", "STORAGE", "EPHEMERAL_STORAGE"), +) +def test_convert_requests(resource_dict: Dict[str, str], expected_resource_name: _task_models.Resources): + assert len(resource_dict) == 1 + expected_resource_value = list(resource_dict.values())[0] + + requests = Resources(**resource_dict) + resources_model = convert_resources_to_resource_model(requests=requests) + + assert len(resources_model.requests) == 1 + request = resources_model.requests[0] + assert isinstance(request, _task_models.Resources.ResourceEntry) + assert request.name == expected_resource_name + assert request.value == expected_resource_value + assert len(resources_model.limits) == 0 + + +@pytest.mark.parametrize( + argnames=("resource_dict", "expected_resource_name"), + argvalues=( + ({"cpu": "2"}, _ResourceName.CPU), + ({"mem": "1Gi"}, _ResourceName.MEMORY), + ({"gpu": "1"}, _ResourceName.GPU), + ({"storage": "100Mb"}, _ResourceName.STORAGE), + ({"ephemeral_storage": "123Mb"}, _ResourceName.EPHEMERAL_STORAGE), + ), + ids=("CPU", "MEMORY", "GPU", "STORAGE", "EPHEMERAL_STORAGE"), +) +def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _task_models.Resources): + assert len(resource_dict) == 1 + expected_resource_value = list(resource_dict.values())[0] + + requests = Resources(**resource_dict) + resources_model = convert_resources_to_resource_model(limits=requests) + + assert len(resources_model.limits) == 1 + limit = resources_model.limits[0] + assert isinstance(limit, _task_models.Resources.ResourceEntry) + assert limit.name == expected_resource_name + assert limit.value == expected_resource_value + assert len(resources_model.requests) == 0