From 581a5c66b6dec1d105f8f655918c405665ceebf6 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Tue, 10 Jan 2023 18:26:55 -0500 Subject: [PATCH 01/13] Update default config to work out-of-the-box with flytectl demo (#1384) Signed-off-by: Niels Bantilan --- flytekit/clis/sdk_in_container/helpers.py | 15 ++++++---- flytekit/configuration/__init__.py | 10 +++---- flytekit/configuration/file.py | 12 ++++++-- .../unit/cli/pyflyte/test_register.py | 2 ++ .../unit/configuration/configs/good.config | 2 +- .../unit/configuration/configs/nossl.yaml | 4 +++ .../flytekit/unit/configuration/test_file.py | 28 ++++++++++++++++++- .../unit/configuration/test_internal.py | 6 ++++ .../unit/configuration/test_yaml_file.py | 9 ++++++ tests/flytekit/unit/remote/test_remote.py | 2 +- 10 files changed, 74 insertions(+), 16 deletions(-) create mode 100644 tests/flytekit/unit/configuration/configs/nossl.yaml diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 72246bcba4e..6ac451be92a 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -4,7 +4,7 @@ import click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.configuration import Config, ImageConfig +from flytekit.configuration import Config, ImageConfig, get_config_file from flytekit.loggers import cli_logger from flytekit.remote.remote import FlyteRemote @@ -25,10 +25,15 @@ def get_and_save_remote_with_click_context( :return: FlyteRemote instance """ cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - cfg_obj = Config.auto(cfg_file_location) - cli_logger.info( - f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") - ) + cfg_file = get_config_file(cfg_file_location) + if cfg_file is None: + cfg_obj = Config.for_sandbox() + cli_logger.info("No config files found, creating remote with sandbox config") + else: + cfg_obj = Config.auto(cfg_file_location) + cli_logger.info( + f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") + ) r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain) if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 5d65f8c3ca1..220f9209ea5 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -300,7 +300,7 @@ class PlatformConfig(object): :param endpoint: DNS for Flyte backend :param insecure: Whether or not to use SSL :param insecure_skip_verify: Wether to skip SSL certificate verification - :param console_endpoint: endpoint for console if differenet than Flyte backend + :param console_endpoint: endpoint for console if different than Flyte backend :param command: This command is executed to return a token using an external process. :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. @@ -311,7 +311,7 @@ class PlatformConfig(object): :param auth_mode: The OAuth mode to use. Defaults to pkce flow. 
""" - endpoint: str = "localhost:30081" + endpoint: str = "localhost:30080" insecure: bool = False insecure_skip_verify: bool = False console_endpoint: typing.Optional[str] = None @@ -529,7 +529,7 @@ def with_params( ) @classmethod - def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> Config: + def auto(cls, config_file: typing.Union[str, ConfigFile, None] = None) -> Config: """ Automatically constructs the Config Object. The order of precedence is as follows 1. first try to find any env vars that match the config vars specified in the FLYTE_CONFIG format. @@ -558,9 +558,9 @@ def for_sandbox(cls) -> Config: :return: Config """ return Config( - platform=PlatformConfig(endpoint="localhost:30081", auth_mode="Pkce", insecure=True), + platform=PlatformConfig(endpoint="localhost:30080", auth_mode="Pkce", insecure=True), data_config=DataConfig( - s3=S3Config(endpoint="http://localhost:30084", access_key_id="minio", secret_access_key="miniostorage") + s3=S3Config(endpoint="http://localhost:30002", access_key_id="minio", secret_access_key="miniostorage") ), ) diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index 793917cffe7..23210e95f1a 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -18,6 +18,11 @@ FLYTECTL_CONFIG_ENV_VAR = "FLYTECTL_CONFIG" +def _exists(val: typing.Any) -> bool: + """Check if a value is defined.""" + return isinstance(val, bool) or bool(val is not None and val) + + @dataclass class LegacyConfigEntry(object): """ @@ -63,7 +68,7 @@ def read_from_file( @dataclass class YamlConfigEntry(object): """ - Creates a record for the config entry. contains + Creates a record for the config entry. Args: switch: dot-delimited string that should match flytectl args. Leaving it as dot-delimited instead of a list of strings because it's easier to maintain alignment with flytectl. @@ -80,10 +85,11 @@ def read_from_file( return None try: v = cfg.get(self) - if v: + if _exists(v): return transform(v) if transform else v except Exception: ... + return None @@ -273,7 +279,7 @@ def set_if_exists(d: dict, k: str, v: typing.Any) -> dict: The input dictionary ``d`` will be mutated. 
""" - if v: + if _exists(v): d[k] = v return d diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index 4951d4be461..a6c0bb91d85 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -8,6 +8,7 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.configuration import Config from flytekit.core import context_manager from flytekit.remote.remote import FlyteRemote @@ -34,6 +35,7 @@ def test_saving_remote(mock_remote): mock_context.obj = {} get_and_save_remote_with_click_context(mock_context, "p", "d") assert mock_context.obj["flyte_remote"] is not None + mock_remote.assert_called_once_with(Config.for_sandbox(), default_project="p", default_domain="d") def test_register_with_no_package_or_module_argument(): diff --git a/tests/flytekit/unit/configuration/configs/good.config b/tests/flytekit/unit/configuration/configs/good.config index 56bb837b009..06c2579d42b 100644 --- a/tests/flytekit/unit/configuration/configs/good.config +++ b/tests/flytekit/unit/configuration/configs/good.config @@ -7,8 +7,8 @@ assumable_iam_role=some_role [platform] - url=fakeflyte.com +insecure=false [madeup] diff --git a/tests/flytekit/unit/configuration/configs/nossl.yaml b/tests/flytekit/unit/configuration/configs/nossl.yaml new file mode 100644 index 00000000000..f7acdde5a58 --- /dev/null +++ b/tests/flytekit/unit/configuration/configs/nossl.yaml @@ -0,0 +1,4 @@ +admin: + endpoint: dns:///flyte.mycorp.io + authType: Pkce + insecure: false diff --git a/tests/flytekit/unit/configuration/test_file.py b/tests/flytekit/unit/configuration/test_file.py index cb10bf42c01..3ce03f9c506 100644 --- a/tests/flytekit/unit/configuration/test_file.py +++ b/tests/flytekit/unit/configuration/test_file.py @@ -7,7 +7,8 @@ from pytimeparse.timeparse import timeparse from flytekit.configuration import ConfigEntry, get_config_file, set_if_exists -from flytekit.configuration.file import LegacyConfigEntry +from flytekit.configuration.file import LegacyConfigEntry, _exists +from flytekit.configuration.internal import Platform def test_set_if_exists(): @@ -21,6 +22,25 @@ def test_set_if_exists(): assert d["k"] == "x" +@pytest.mark.parametrize( + "data, expected", + [ + [1, True], + [1.0, True], + ["foo", True], + [True, True], + [False, True], + [[1], True], + [{"k": "v"}, True], + [None, False], + [[], False], + [{}, False], + ], +) +def test_exists(data, expected): + assert _exists(data) is expected + + def test_get_config_file(): c = get_config_file(None) assert c is None @@ -118,3 +138,9 @@ def test_env_var_bool_transformer(mock_file_read): # The last read should've triggered the file read since now the env var is no longer set. 
assert mock_file_read.call_count == 1 + + +def test_use_ssl(): + config_file = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + res = Platform.INSECURE.read(config_file) + assert res is False diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 7f6be53a55a..97e30b5612c 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -77,3 +77,9 @@ def test_some_int(mocked): res = AWS.RETRIES.read(cfg) assert type(res) is int assert res == 5 + + +def test_default_platform_config_endpoint_insecure(): + platform_config = PlatformConfig() + assert platform_config.endpoint == "localhost:30080" + assert platform_config.insecure is False diff --git a/tests/flytekit/unit/configuration/test_yaml_file.py b/tests/flytekit/unit/configuration/test_yaml_file.py index 7e1c3eee98f..ba2c61e1582 100644 --- a/tests/flytekit/unit/configuration/test_yaml_file.py +++ b/tests/flytekit/unit/configuration/test_yaml_file.py @@ -14,6 +14,7 @@ def test_config_entry_file(): assert c.read() is None cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/sample.yaml")) + assert cfg.yaml_config is not None assert c.read(cfg) == "flyte.mycorp.io" c = ConfigEntry(LegacyConfigEntry("platform", "url2", str)) # Does not exist @@ -26,6 +27,7 @@ def test_config_entry_file_normal(): cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/no_images.yaml")) images_dict = Images.get_specified_images(cfg) assert images_dict == {} + assert cfg.yaml_config is not None @mock.patch("flytekit.configuration.file.getenv") @@ -43,6 +45,7 @@ def test_config_entry_file_2(mock_get): cfg = get_config_file(sample_yaml_file_name) assert c.read(cfg) == "flyte.mycorp.io" + assert cfg.yaml_config is not None c = ConfigEntry(LegacyConfigEntry("platform", "url2", str)) # Does not exist assert c.read(cfg) is None @@ -67,3 +70,9 @@ def test_real_config(): res = Credentials.SCOPES.read(config_file) assert res == ["all"] + + +def test_use_ssl(): + config_file = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/nossl.yaml")) + res = Platform.INSECURE.read(config_file) + assert res is False diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 01688ea8255..d7fde3215f7 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -231,7 +231,7 @@ def test_generate_console_http_domain_sandbox_rewrite(mock_client): remote = FlyteRemote( config=Config.auto(config_file=temp_filename), default_project="project", default_domain="domain" ) - assert remote.generate_console_http_domain() == "http://localhost:30080" + assert remote.generate_console_http_domain() == "http://localhost:30081" with open(temp_filename, "w") as f: # This string is similar to the relevant configuration emitted by flytectl in the cases of both demo and sandbox. 
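For readers following along, a minimal sketch of what this first patch means in practice: with no config file on disk (and `FLYTECTL_CONFIG` unset), a remote built through the CLI helpers now targets the `flytectl demo` cluster instead of erroring out. The project and domain names below are placeholders, not part of the patch:

```python
from flytekit.configuration import Config
from flytekit.remote.remote import FlyteRemote

# Equivalent of the new fallback in helpers.py: no config file was found,
# so fall back to the sandbox defaults (Flyte at localhost:30080,
# minio at localhost:30002).
remote = FlyteRemote(
    Config.for_sandbox(),
    default_project="flytesnacks",  # placeholder project name
    default_domain="development",   # placeholder domain name
)
```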
From 531db3cbadda1e19a42b307a989106db35a836c0 Mon Sep 17 00:00:00 2001 From: bstadlbauer <11799671+bstadlbauer@users.noreply.github.com> Date: Thu, 12 Jan 2023 03:16:50 +0100 Subject: [PATCH 02/13] Add dask plugin #patch (#1366) * Add dummy task type to test backend plugin Signed-off-by: Bernhard Stadlbauer * Add docs page Signed-off-by: Bernhard Stadlbauer * Add dask models Signed-off-by: Bernhard Stadlbauer * Add function to convert resources Signed-off-by: Bernhard Stadlbauer * Add tests to `dask` task Signed-off-by: Bernhard Stadlbauer * Remove namespace Signed-off-by: Bernhard Stadlbauer * Update setup.py Signed-off-by: Bernhard Stadlbauer * Add dask to `plugin/README.md` Signed-off-by: Bernhard Stadlbauer * Add README.md for `dask` Signed-off-by: Bernhard Stadlbauer * Top level export of `JopPodSpec` and `DaskCluster` Signed-off-by: Bernhard Stadlbauer * Update docs for images Signed-off-by: Bernhard Stadlbauer * Update README.md Signed-off-by: Bernhard Stadlbauer * Update models after `flyteidl` change Signed-off-by: Bernhard Stadlbauer * Update task after `flyteidl` change Signed-off-by: Bernhard Stadlbauer * Raise error when less than 1 worker Signed-off-by: Bernhard Stadlbauer * Update flyteidl to >= 1.3.2 Signed-off-by: Bernhard Stadlbauer * Update doc requirements Signed-off-by: Bernhard Stadlbauer * Update doc-requirements.txt Signed-off-by: Bernhard Stadlbauer * Re-lock dependencies on linux Signed-off-by: Bernhard Stadlbauer * Update dask API docs Signed-off-by: Bernhard Stadlbauer * Fix documentation links Signed-off-by: Bernhard Stadlbauer * Default optional model constructor arguments to `None` Signed-off-by: Bernhard Stadlbauer * Refactor `convert_resources_to_resource_model` to `core.resources` Signed-off-by: Bernhard Stadlbauer * Use `convert_resources_to_resource_model` in `core.node` Signed-off-by: Bernhard Stadlbauer * Incorporate review feedback Signed-off-by: Eduardo Apolinario * Lint Signed-off-by: Eduardo Apolinario Signed-off-by: Bernhard Stadlbauer Signed-off-by: Bernhard Stadlbauer Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Co-authored-by: Eduardo Apolinario --- .github/workflows/pythonbuild.yml | 1 + doc-requirements.in | 1 + docs/source/plugins/dask.rst | 12 + docs/source/plugins/index.rst | 2 + flytekit/core/node.py | 16 +- flytekit/core/resources.py | 43 ++- flytekit/core/utils.py | 37 ++- plugins/README.md | 1 + plugins/flytekit-dask/README.md | 21 ++ .../flytekitplugins/dask/__init__.py | 15 ++ .../flytekitplugins/dask/models.py | 134 ++++++++++ .../flytekitplugins/dask/task.py | 108 ++++++++ plugins/flytekit-dask/requirements.in | 2 + plugins/flytekit-dask/requirements.txt | 247 ++++++++++++++++++ plugins/flytekit-dask/setup.py | 42 +++ plugins/flytekit-dask/tests/__init__.py | 0 plugins/flytekit-dask/tests/test_models.py | 96 +++++++ plugins/flytekit-dask/tests/test_task.py | 86 ++++++ tests/flytekit/unit/core/test_resources.py | 68 +++++ 19 files changed, 906 insertions(+), 26 deletions(-) create mode 100644 docs/source/plugins/dask.rst create mode 100644 plugins/flytekit-dask/README.md create mode 100644 plugins/flytekit-dask/flytekitplugins/dask/__init__.py create mode 100644 plugins/flytekit-dask/flytekitplugins/dask/models.py create mode 100644 plugins/flytekit-dask/flytekitplugins/dask/task.py create mode 100644 plugins/flytekit-dask/requirements.in create mode 100644 plugins/flytekit-dask/requirements.txt create mode 100644 plugins/flytekit-dask/setup.py create mode 100644 
plugins/flytekit-dask/tests/__init__.py create mode 100644 plugins/flytekit-dask/tests/test_models.py create mode 100644 plugins/flytekit-dask/tests/test_task.py create mode 100644 tests/flytekit/unit/core/test_resources.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 4c16153be9e..cb7cf8aa9e4 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -67,6 +67,7 @@ jobs: - flytekit-aws-batch - flytekit-aws-sagemaker - flytekit-bigquery + - flytekit-dask - flytekit-data-fsspec - flytekit-dbt - flytekit-deck-standard diff --git a/doc-requirements.in b/doc-requirements.in index 17862869cf9..a5b921481c2 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -45,5 +45,6 @@ whylogs # whylogs whylabs-client # whylogs ray # ray scikit-learn # scikit-learn +dask[distributed] # dask vaex # vaex mlflow # mlflow diff --git a/docs/source/plugins/dask.rst b/docs/source/plugins/dask.rst new file mode 100644 index 00000000000..53e9f11fcb9 --- /dev/null +++ b/docs/source/plugins/dask.rst @@ -0,0 +1,12 @@ +.. _dask: + +################################################### +Dask API reference +################################################### + +.. tags:: Integration, DistributedComputing, KubernetesOperator + +.. automodule:: flytekitplugins.dask + :no-members: + :no-inherited-members: + :no-special-members: diff --git a/docs/source/plugins/index.rst b/docs/source/plugins/index.rst index 008f2b4bbe1..693587192e0 100644 --- a/docs/source/plugins/index.rst +++ b/docs/source/plugins/index.rst @@ -9,6 +9,7 @@ Plugin API reference * :ref:`AWS Sagemaker ` - AWS Sagemaker plugin reference * :ref:`Google Bigquery ` - Google Bigquery plugin reference * :ref:`FS Spec ` - FS Spec API reference +* :ref:`Dask ` - Dask standard API reference * :ref:`Deck standard ` - Deck standard API reference * :ref:`Dolt standard ` - Dolt standard API reference * :ref:`Great expectations ` - Great expectations API reference @@ -40,6 +41,7 @@ Plugin API reference AWS Sagemaker Google Bigquery FS Spec + Dask Deck standard Dolt standard Great expectations diff --git a/flytekit/core/node.py b/flytekit/core/node.py index d8b43f27284..52487e6e482 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -4,7 +4,7 @@ import typing from typing import Any, List -from flytekit.core.resources import Resources +from flytekit.core.resources import Resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model @@ -92,9 +92,14 @@ def with_overrides(self, *args, **kwargs): for k, v in alias_dict.items(): self._aliases.append(_workflow_model.Alias(var=k, alias=v)) if "requests" in kwargs or "limits" in kwargs: - requests = _convert_resource_overrides(kwargs.get("requests"), "requests") - limits = _convert_resource_overrides(kwargs.get("limits"), "limits") - self._resources = _resources_model(requests=requests, limits=limits) + requests = kwargs.get("requests") + if requests and not isinstance(requests, Resources): + raise AssertionError("requests should be specified as flytekit.Resources") + limits = kwargs.get("limits") + if limits and not isinstance(limits, Resources): + raise AssertionError("limits should be specified as flytekit.Resources") + + self._resources = convert_resources_to_resource_model(requests=requests, limits=limits) if "timeout" in kwargs: timeout = kwargs["timeout"] if timeout is None: @@ -122,8 +127,7 @@ 
def _convert_resource_overrides( ) -> [_resources_model.ResourceEntry]: if resources is None: return [] - if not isinstance(resources, Resources): - raise AssertionError(f"{resource_name} should be specified as flytekit.Resources") + resource_entries = [] if resources.cpu is not None: resource_entries.append(_resources_model.ResourceEntry(_resources_model.ResourceName.CPU, resources.cpu)) diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 7b46cbe05c6..62806042466 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import Optional +from typing import List, Optional + +from flytekit.models import task as task_models @dataclass @@ -35,3 +37,42 @@ class Resources(object): class ResourceSpec(object): requests: Optional[Resources] = None limits: Optional[Resources] = None + + +_ResouceName = task_models.Resources.ResourceName +_ResourceEntry = task_models.Resources.ResourceEntry + + +def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: + resource_entries = [] + if resources.cpu is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.CPU, value=resources.cpu)) + if resources.mem is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.MEMORY, value=resources.mem)) + if resources.gpu is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.GPU, value=resources.gpu)) + if resources.storage is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.STORAGE, value=resources.storage)) + if resources.ephemeral_storage is not None: + resource_entries.append(_ResourceEntry(name=_ResouceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage)) + return resource_entries + + +def convert_resources_to_resource_model( + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, +) -> task_models.Resources: + """ + Convert flytekit ``Resources`` objects to a Resources model + + :param requests: Resource requests. Optional, defaults to ``None`` + :param limits: Resource limits. 
Optional, defaults to ``None`` + :return: The given resources as requests and limits + """ + request_entries = [] + limit_entries = [] + if requests is not None: + request_entries = _convert_resources_to_resource_entries(requests) + if limits is not None: + limit_entries = _convert_resources_to_resource_entries(limits) + return task_models.Resources(requests=request_entries, limits=limit_entries) diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index d23aae3fbbb..ae8b89a1096 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -7,7 +7,7 @@ from typing import Dict, List, Optional from flytekit.loggers import logger -from flytekit.models import task as _task_models +from flytekit.models import task as task_models def _dnsify(value: str) -> str: @@ -52,7 +52,7 @@ def _get_container_definition( image: str, command: List[str], args: List[str], - data_loading_config: Optional[_task_models.DataLoadingConfig] = None, + data_loading_config: Optional[task_models.DataLoadingConfig] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, cpu_request: Optional[str] = None, @@ -64,7 +64,7 @@ def _get_container_definition( gpu_limit: Optional[str] = None, memory_limit: Optional[str] = None, environment: Optional[Dict[str, str]] = None, -) -> _task_models.Container: +) -> task_models.Container: storage_limit = storage_limit storage_request = storage_request ephemeral_storage_limit = ephemeral_storage_limit @@ -76,50 +76,49 @@ def _get_container_definition( memory_limit = memory_limit memory_request = memory_request + # TODO: Use convert_resources_to_resource_model instead of manually fixing the resources. requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) + task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_request) ) if ephemeral_storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_request ) ) if cpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_request)) if gpu_request: - requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_request)) if memory_request: - requests.append( - _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) - ) + requests.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_request)) limits = [] if storage_limit: - limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) + limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.STORAGE, storage_limit)) if ephemeral_storage_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit + task_models.Resources.ResourceEntry( + task_models.Resources.ResourceName.EPHEMERAL_STORAGE, ephemeral_storage_limit ) ) if cpu_limit: - 
limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit))
+        limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.CPU, cpu_limit))
     if gpu_limit:
-        limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit))
+        limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.GPU, gpu_limit))
     if memory_limit:
-        limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit))
+        limits.append(task_models.Resources.ResourceEntry(task_models.Resources.ResourceName.MEMORY, memory_limit))

     if environment is None:
         environment = {}

-    return _task_models.Container(
+    return task_models.Container(
         image=image,
         command=command,
         args=args,
-        resources=_task_models.Resources(limits=limits, requests=requests),
+        resources=task_models.Resources(limits=limits, requests=requests),
         env=environment,
         config={},
         data_loading_config=data_loading_config,
diff --git a/plugins/README.md b/plugins/README.md
index 447b91a37ca..495ce910196 100644
--- a/plugins/README.md
+++ b/plugins/README.md
@@ -7,6 +7,7 @@ All the Flytekit plugins maintained by the core team are added here. It is not n
 | Plugin | Installation | Description | Version | Type |
 |------------------------------|-------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------|---------------|
 | AWS Sagemaker Training | ```bash pip install flytekitplugins-awssagemaker ``` | Installs SDK to author Sagemaker built-in and custom training jobs in python | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-awssagemaker.svg)](https://pypi.python.org/pypi/flytekitplugins-awssagemaker/) | Backend |
+| dask | ```bash pip install flytekitplugins-dask ``` | Installs SDK to author dask jobs that can be executed natively on Kubernetes using the Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-dask.svg)](https://pypi.python.org/pypi/flytekitplugins-dask/) | Backend |
 | Hive Queries | ```bash pip install flytekitplugins-hive ``` | Installs SDK to author Hive Queries that can be executed on a configured hive backend using Flyte backend plugin | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-hive.svg)](https://pypi.python.org/pypi/flytekitplugins-hive/) | Backend |
 | K8s distributed PyTorch Jobs | ```bash pip install flytekitplugins-kfpytorch ``` | Installs SDK to author Distributed pyTorch Jobs in python using Kubeflow PyTorch Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kfpytorch.svg)](https://pypi.python.org/pypi/flytekitplugins-kfpytorch/) | Backend |
 | K8s native tensorflow Jobs | ```bash pip install flytekitplugins-kftensorflow ``` | Installs SDK to author Distributed tensorflow Jobs in python using Kubeflow Tensorflow Operator | [![PyPI version fury.io](https://badge.fury.io/py/flytekitplugins-kftensorflow.svg)](https://pypi.python.org/pypi/flytekitplugins-kftensorflow/) | Backend |
diff --git a/plugins/flytekit-dask/README.md b/plugins/flytekit-dask/README.md
new file mode 100644
index 00000000000..9d645bcd276
--- /dev/null
+++ b/plugins/flytekit-dask/README.md
@@ -0,0 +1,21 @@
+# Flytekit Dask Plugin
+
+Flyte can execute `dask` jobs natively on a Kubernetes cluster, which manages the virtual `dask` cluster's lifecycle
+(spin-up and tear-down). It leverages the open-source Kubernetes Dask Operator and can be enabled without signing up
+for any service. This is like running a transient (ephemeral) `dask` cluster - a type of cluster spun up for a specific
+task and torn down after completion. This helps in making sure that the Python environment is the same on the job-runner
+(driver), the scheduler, and the workers.
+
+To install the plugin, run the following command:
+
+```bash
+pip install flytekitplugins-dask
+```
+
+To configure Dask in the Flyte deployment's backend, follow
+[step 1](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/k8s_dask/index.html#step-1-deploy-the-dask-plugin-in-the-flyte-backend)
+and
+[step 2](https://docs.flyte.org/projects/cookbook/en/latest/auto/auto/integrations/kubernetes/k8s_dask/index.html#step-2-environment-setup).
+
+An [example](https://docs.flyte.org/projects/cookbook/en/latest/auto/integrations/kubernetes/k8s_dask/index.html)
+can be found in the documentation.
diff --git a/plugins/flytekit-dask/flytekitplugins/dask/__init__.py b/plugins/flytekit-dask/flytekitplugins/dask/__init__.py
new file mode 100644
index 00000000000..ccadf385fc0
--- /dev/null
+++ b/plugins/flytekit-dask/flytekitplugins/dask/__init__.py
@@ -0,0 +1,15 @@
+"""
+.. currentmodule:: flytekitplugins.dask
+
+This package contains the Python-side implementation of the Dask plugin.
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+
+   Dask
+   Scheduler
+   WorkerGroup
+"""
+
+from flytekitplugins.dask.task import Dask, Scheduler, WorkerGroup
diff --git a/plugins/flytekit-dask/flytekitplugins/dask/models.py b/plugins/flytekit-dask/flytekitplugins/dask/models.py
new file mode 100644
index 00000000000..b833ab660a2
--- /dev/null
+++ b/plugins/flytekit-dask/flytekitplugins/dask/models.py
@@ -0,0 +1,134 @@
+from typing import Optional
+
+from flyteidl.plugins import dask_pb2 as dask_task
+
+from flytekit.models import common as common
+from flytekit.models import task as task
+
+
+class Scheduler(common.FlyteIdlEntity):
+    """
+    Configuration for the scheduler pod
+
+    :param image: Optional image to use.
+    :param resources: Optional resources to use.
+ """ + + def __init__(self, image: Optional[str] = None, resources: Optional[task.Resources] = None): + self._image = image + self._resources = resources + + @property + def image(self) -> Optional[str]: + """ + :return: The optional image for the scheduler pod + """ + return self._image + + @property + def resources(self) -> Optional[task.Resources]: + """ + :return: Optional resources for the scheduler pod + """ + return self._resources + + def to_flyte_idl(self) -> dask_task.DaskScheduler: + """ + :return: The scheduler spec serialized to protobuf + """ + return dask_task.DaskScheduler( + image=self.image, + resources=self.resources.to_flyte_idl() if self.resources else None, + ) + + +class WorkerGroup(common.FlyteIdlEntity): + """ + Configuration for a dask worker group + + :param number_of_workers:Number of workers in the group + :param image: Optional image to use for the pods of the worker group + :param resources: Optional resources to use for the pods of the worker group + """ + + def __init__( + self, + number_of_workers: int, + image: Optional[str] = None, + resources: Optional[task.Resources] = None, + ): + if number_of_workers < 1: + raise ValueError( + f"Each worker group needs to have at least one worker, but {number_of_workers} have been specified." + ) + + self._number_of_workers = number_of_workers + self._image = image + self._resources = resources + + @property + def number_of_workers(self) -> Optional[int]: + """ + :return: Optional number of workers for the worker group + """ + return self._number_of_workers + + @property + def image(self) -> Optional[str]: + """ + :return: The optional image to use for the worker pods + """ + return self._image + + @property + def resources(self) -> Optional[task.Resources]: + """ + :return: Optional resources to use for the worker pods + """ + return self._resources + + def to_flyte_idl(self) -> dask_task.DaskWorkerGroup: + """ + :return: The dask cluster serialized to protobuf + """ + return dask_task.DaskWorkerGroup( + number_of_workers=self.number_of_workers, + image=self.image, + resources=self.resources.to_flyte_idl() if self.resources else None, + ) + + +class DaskJob(common.FlyteIdlEntity): + """ + Configuration for the custom dask job to run + + :param scheduler: Configuration for the scheduler + :param workers: Configuration of the default worker group + """ + + def __init__(self, scheduler: Scheduler, workers: WorkerGroup): + self._scheduler = scheduler + self._workers = workers + + @property + def scheduler(self) -> Scheduler: + """ + :return: Configuration for the scheduler pod + """ + return self._scheduler + + @property + def workers(self) -> WorkerGroup: + """ + :return: Configuration of the default worker group + """ + return self._workers + + def to_flyte_idl(self) -> dask_task.DaskJob: + """ + :return: The dask job serialized to protobuf + """ + return dask_task.DaskJob( + scheduler=self.scheduler.to_flyte_idl(), + workers=self.workers.to_flyte_idl(), + ) diff --git a/plugins/flytekit-dask/flytekitplugins/dask/task.py b/plugins/flytekit-dask/flytekitplugins/dask/task.py new file mode 100644 index 00000000000..830ede98ef8 --- /dev/null +++ b/plugins/flytekit-dask/flytekitplugins/dask/task.py @@ -0,0 +1,108 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional + +from flytekitplugins.dask import models +from google.protobuf.json_format import MessageToDict + +from flytekit import PythonFunctionTask, Resources +from flytekit.configuration import SerializationSettings +from 
flytekit.core.resources import convert_resources_to_resource_model +from flytekit.core.task import TaskPlugins + + +@dataclass +class Scheduler: + """ + Configuration for the scheduler pod + + :param image: Custom image to use. If ``None``, will use the same image the task was registered with. Optional, + defaults to ``None``. The image must have ``dask[distributed]`` installed and should have the same Python + environment as the rest of the cluster (job runner pod + worker pods). + :param requests: Resources to request for the scheduler pod. If ``None``, the requests passed into the task will be + used. Optional, defaults to ``None``. + :param limits: Resource limits for the scheduler pod. If ``None``, the limits passed into the task will be used. + Optional, defaults to ``None``. + """ + + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + + +@dataclass +class WorkerGroup: + """ + Configuration for a group of dask worker pods + + :param number_of_workers: Number of workers to use. Optional, defaults to 1. + :param image: Custom image to use. If ``None``, will use the same image the task was registered with. Optional, + defaults to ``None``. The image must have ``dask[distributed]`` installed. The provided image should have the + same Python environment as the job runner/driver as well as the scheduler. + :param requests: Resources to request for the worker pods. If ``None``, the requests passed into the task will be + used. Optional, defaults to ``None``. + :param limits: Resource limits for the worker pods. If ``None``, the limits passed into the task will be used. + Optional, defaults to ``None``. + """ + + number_of_workers: Optional[int] = 1 + image: Optional[str] = None + requests: Optional[Resources] = None + limits: Optional[Resources] = None + + +@dataclass +class Dask: + """ + Configuration for the dask task + + :param scheduler: Configuration for the scheduler pod. Optional, defaults to ``Scheduler()``. + :param workers: Configuration for the pods of the default worker group. Optional, defaults to ``WorkerGroup()``. + """ + + scheduler: Scheduler = Scheduler() + workers: WorkerGroup = WorkerGroup() + + +class DaskTask(PythonFunctionTask[Dask]): + """ + Actual Plugin that transforms the local python code for execution within a dask cluster + """ + + _DASK_TASK_TYPE = "dask" + + def __init__(self, task_config: Dask, task_function: Callable, **kwargs): + super(DaskTask, self).__init__( + task_config=task_config, + task_type=self._DASK_TASK_TYPE, + task_function=task_function, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + """ + Serialize the `dask` task config into a dict. + + :param settings: Current serialization settings + :return: Dictionary representation of the dask task config. 
+ """ + scheduler = models.Scheduler( + image=self.task_config.scheduler.image, + resources=convert_resources_to_resource_model( + requests=self.task_config.scheduler.requests, + limits=self.task_config.scheduler.limits, + ), + ) + workers = models.WorkerGroup( + number_of_workers=self.task_config.workers.number_of_workers, + image=self.task_config.workers.image, + resources=convert_resources_to_resource_model( + requests=self.task_config.workers.requests, + limits=self.task_config.workers.limits, + ), + ) + job = models.DaskJob(scheduler=scheduler, workers=workers) + return MessageToDict(job.to_flyte_idl()) + + +# Inject the `dask` plugin into flytekits dynamic plugin loading system +TaskPlugins.register_pythontask_plugin(Dask, DaskTask) diff --git a/plugins/flytekit-dask/requirements.in b/plugins/flytekit-dask/requirements.in new file mode 100644 index 00000000000..310ade8617d --- /dev/null +++ b/plugins/flytekit-dask/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-dask diff --git a/plugins/flytekit-dask/requirements.txt b/plugins/flytekit-dask/requirements.txt new file mode 100644 index 00000000000..2ec017e46dc --- /dev/null +++ b/plugins/flytekit-dask/requirements.txt @@ -0,0 +1,247 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# pip-compile --output-file=requirements.txt requirements.in setup.py +# +-e file:.#egg=flytekitplugins-dask + # via -r requirements.in +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +certifi==2022.9.24 + # via requests +cffi==1.15.1 + # via cryptography +chardet==5.0.0 + # via binaryornot +charset-normalizer==2.1.1 + # via requests +click==8.1.3 + # via + # cookiecutter + # dask + # distributed + # flytekit +cloudpickle==2.2.0 + # via + # dask + # distributed + # flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.7 + # via flytekit +cryptography==38.0.3 + # via + # pyopenssl + # secretstorage +dask[distributed]==2022.10.2 + # via + # distributed + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via retry +deprecated==1.2.13 + # via flytekit +diskcache==5.4.0 + # via flytekit +distributed==2022.10.2 + # via dask +docker==6.0.1 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +flyteidl==1.3.2 + # via + # flytekit + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +flytekit==1.3.0b2 + # via + # flytekitplugins-dask + # flytekitplugins-dask (setup.py) +fsspec==2022.10.0 + # via dask +googleapis-common-protos==1.56.4 + # via + # flyteidl + # grpcio-status +grpcio==1.51.1 + # via + # flytekit + # grpcio-status +grpcio-status==1.51.1 + # via flytekit +heapdict==1.0.1 + # via zict +idna==3.4 + # via requests +importlib-metadata==5.0.0 + # via + # flytekit + # keyring +jaraco-classes==3.2.3 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 + # via + # cookiecutter + # distributed + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via flytekit +keyring==23.11.0 + # via flytekit +locket==1.0.0 + # via + # distributed + # partd +markupsafe==2.1.1 + # via jinja2 +marshmallow==3.18.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +more-itertools==9.0.0 + # via jaraco-classes +msgpack==1.0.4 + # via distributed +mypy-extensions==0.4.3 + # via 
typing-inspect +natsort==8.2.0 + # via flytekit +numpy==1.23.4 + # via + # pandas + # pyarrow +packaging==21.3 + # via + # dask + # distributed + # docker + # marshmallow +pandas==1.5.1 + # via flytekit +partd==1.3.0 + # via dask +protobuf==4.21.11 + # via + # flyteidl + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +psutil==5.9.3 + # via distributed +py==1.11.0 + # via retry +pyarrow==6.0.1 + # via flytekit +pycparser==2.21 + # via cffi +pyopenssl==22.1.0 + # via flytekit +pyparsing==3.0.9 + # via packaging +python-dateutil==2.8.2 + # via + # arrow + # croniter + # flytekit + # pandas +python-json-logger==2.0.4 + # via flytekit +python-slugify==6.1.2 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2022.6 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # dask + # distributed + # flytekit +regex==2022.10.31 + # via docker-image-py +requests==2.28.1 + # via + # cookiecutter + # docker + # flytekit + # responses +responses==0.22.0 + # via flytekit +retry==0.9.2 + # via flytekit +secretstorage==3.3.3 + # via keyring +six==1.16.0 + # via python-dateutil +sortedcontainers==2.4.0 + # via + # distributed + # flytekit +statsd==3.3.0 + # via flytekit +tblib==1.7.0 + # via distributed +text-unidecode==1.3 + # via python-slugify +toml==0.10.2 + # via responses +toolz==0.12.0 + # via + # dask + # distributed + # partd +tornado==6.1 + # via distributed +types-toml==0.10.8 + # via responses +typing-extensions==4.4.0 + # via + # flytekit + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.12 + # via + # distributed + # docker + # flytekit + # requests + # responses +websocket-client==1.4.2 + # via docker +wheel==0.38.2 + # via flytekit +wrapt==1.14.1 + # via + # deprecated + # flytekit +zict==2.2.0 + # via distributed +zipp==3.10.0 + # via importlib-metadata diff --git a/plugins/flytekit-dask/setup.py b/plugins/flytekit-dask/setup.py new file mode 100644 index 00000000000..440d7b47db5 --- /dev/null +++ b/plugins/flytekit-dask/setup.py @@ -0,0 +1,42 @@ +from setuptools import setup + +PLUGIN_NAME = "dask" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = [ + "flyteidl>=1.3.2", + "flytekit>=1.3.0b2,<2.0.0", + "dask[distributed]>=2022.10.2", +] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Dask plugin for flytekit", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-dask", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", # dask requires >= 3.8 + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-dask/tests/__init__.py b/plugins/flytekit-dask/tests/__init__.py new file mode 100644 index 
00000000000..e69de29bb2d diff --git a/plugins/flytekit-dask/tests/test_models.py b/plugins/flytekit-dask/tests/test_models.py new file mode 100644 index 00000000000..801a110fb13 --- /dev/null +++ b/plugins/flytekit-dask/tests/test_models.py @@ -0,0 +1,96 @@ +import pytest +from flytekitplugins.dask import models + +from flytekit.models import task as _task + + +@pytest.fixture +def image() -> str: + return "foo:latest" + + +@pytest.fixture +def resources() -> _task.Resources: + return _task.Resources( + requests=[ + _task.Resources.ResourceEntry(name=_task.Resources.ResourceName.CPU, value="3"), + ], + limits=[], + ) + + +@pytest.fixture +def default_resources() -> _task.Resources: + return _task.Resources(requests=[], limits=[]) + + +@pytest.fixture +def scheduler(image: str, resources: _task.Resources) -> models.Scheduler: + return models.Scheduler(image=image, resources=resources) + + +@pytest.fixture +def workers(image: str, resources: _task.Resources) -> models.WorkerGroup: + return models.WorkerGroup(number_of_workers=123, image=image, resources=resources) + + +def test_create_scheduler_to_flyte_idl_no_optional(image: str, resources: _task.Resources): + scheduler = models.Scheduler(image=image, resources=resources) + idl_object = scheduler.to_flyte_idl() + assert idl_object.image == image + assert idl_object.resources == resources.to_flyte_idl() + + +def test_create_scheduler_to_flyte_idl_all_optional(default_resources: _task.Resources): + scheduler = models.Scheduler(image=None, resources=None) + idl_object = scheduler.to_flyte_idl() + assert idl_object.image == "" + assert idl_object.resources == default_resources.to_flyte_idl() + + +def test_create_scheduler_spec_property_access(image: str, resources: _task.Resources): + scheduler = models.Scheduler(image=image, resources=resources) + assert scheduler.image == image + assert scheduler.resources == resources + + +def test_worker_group_to_flyte_idl_no_optional(image: str, resources: _task.Resources): + n_workers = 1234 + worker_group = models.WorkerGroup(number_of_workers=n_workers, image=image, resources=resources) + idl_object = worker_group.to_flyte_idl() + assert idl_object.number_of_workers == n_workers + assert idl_object.image == image + assert idl_object.resources == resources.to_flyte_idl() + + +def test_worker_group_to_flyte_idl_all_optional(default_resources: _task.Resources): + worker_group = models.WorkerGroup(number_of_workers=1, image=None, resources=None) + idl_object = worker_group.to_flyte_idl() + assert idl_object.image == "" + assert idl_object.resources == default_resources.to_flyte_idl() + + +def test_worker_group_property_access(image: str, resources: _task.Resources): + n_workers = 1234 + worker_group = models.WorkerGroup(number_of_workers=n_workers, image=image, resources=resources) + assert worker_group.image == image + assert worker_group.number_of_workers == n_workers + assert worker_group.resources == resources + + +def test_worker_group_fails_for_less_than_one_worker(): + with pytest.raises(ValueError, match=r"Each worker group needs to"): + models.WorkerGroup(number_of_workers=0, image=None, resources=None) + + +def test_dask_job_to_flyte_idl_no_optional(scheduler: models.Scheduler, workers: models.WorkerGroup): + job = models.DaskJob(scheduler=scheduler, workers=workers) + idl_object = job.to_flyte_idl() + assert idl_object.scheduler == scheduler.to_flyte_idl() + assert idl_object.workers == workers.to_flyte_idl() + + +def test_dask_job_property_access(scheduler: models.Scheduler, workers: 
models.WorkerGroup): + job = models.DaskJob(scheduler=scheduler, workers=workers) + assert job.scheduler == scheduler + assert job.workers == workers diff --git a/plugins/flytekit-dask/tests/test_task.py b/plugins/flytekit-dask/tests/test_task.py new file mode 100644 index 00000000000..76dbf9d048e --- /dev/null +++ b/plugins/flytekit-dask/tests/test_task.py @@ -0,0 +1,86 @@ +import pytest +from flytekitplugins.dask import Dask, Scheduler, WorkerGroup + +from flytekit import PythonFunctionTask, Resources, task +from flytekit.configuration import Image, ImageConfig, SerializationSettings + + +@pytest.fixture +def serialization_settings() -> SerializationSettings: + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + return settings + + +def test_dask_task_with_default_config(serialization_settings: SerializationSettings): + task_config = Dask() + + @task(task_config=task_config) + def dask_task(): + pass + + # Helping type completion in PyCharm + dask_task: PythonFunctionTask[Dask] + + assert dask_task.task_config == task_config + assert dask_task.task_type == "dask" + + expected_dict = { + "scheduler": { + "resources": {}, + }, + "workers": { + "numberOfWorkers": 1, + "resources": {}, + }, + } + assert dask_task.get_custom(serialization_settings) == expected_dict + + +def test_dask_task_get_custom(serialization_settings: SerializationSettings): + task_config = Dask( + scheduler=Scheduler( + image="scheduler:latest", + requests=Resources(cpu="1"), + limits=Resources(cpu="2"), + ), + workers=WorkerGroup( + number_of_workers=123, + image="dask_cluster:latest", + requests=Resources(cpu="3"), + limits=Resources(cpu="4"), + ), + ) + + @task(task_config=task_config) + def dask_task(): + pass + + # Helping type completion in PyCharm + dask_task: PythonFunctionTask[Dask] + + expected_custom_dict = { + "scheduler": { + "image": "scheduler:latest", + "resources": { + "requests": [{"name": "CPU", "value": "1"}], + "limits": [{"name": "CPU", "value": "2"}], + }, + }, + "workers": { + "numberOfWorkers": 123, + "image": "dask_cluster:latest", + "resources": { + "requests": [{"name": "CPU", "value": "3"}], + "limits": [{"name": "CPU", "value": "4"}], + }, + }, + } + custom_dict = dask_task.get_custom(serialization_settings) + assert custom_dict == expected_custom_dict diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py new file mode 100644 index 00000000000..1a3bf64deed --- /dev/null +++ b/tests/flytekit/unit/core/test_resources.py @@ -0,0 +1,68 @@ +from typing import Dict + +import pytest + +import flytekit.models.task as _task_models +from flytekit import Resources +from flytekit.core.resources import convert_resources_to_resource_model + +_ResourceName = _task_models.Resources.ResourceName + + +def test_convert_no_requests_no_limits(): + resource_model = convert_resources_to_resource_model(requests=None, limits=None) + assert isinstance(resource_model, _task_models.Resources) + assert resource_model.requests == [] + assert resource_model.limits == [] + + +@pytest.mark.parametrize( + argnames=("resource_dict", "expected_resource_name"), + argvalues=( + ({"cpu": "2"}, _ResourceName.CPU), + ({"mem": "1Gi"}, _ResourceName.MEMORY), + ({"gpu": "1"}, _ResourceName.GPU), + ({"storage": "100Mb"}, _ResourceName.STORAGE), + ({"ephemeral_storage": "123Mb"}, 
_ResourceName.EPHEMERAL_STORAGE), + ), + ids=("CPU", "MEMORY", "GPU", "STORAGE", "EPHEMERAL_STORAGE"), +) +def test_convert_requests(resource_dict: Dict[str, str], expected_resource_name: _task_models.Resources): + assert len(resource_dict) == 1 + expected_resource_value = list(resource_dict.values())[0] + + requests = Resources(**resource_dict) + resources_model = convert_resources_to_resource_model(requests=requests) + + assert len(resources_model.requests) == 1 + request = resources_model.requests[0] + assert isinstance(request, _task_models.Resources.ResourceEntry) + assert request.name == expected_resource_name + assert request.value == expected_resource_value + assert len(resources_model.limits) == 0 + + +@pytest.mark.parametrize( + argnames=("resource_dict", "expected_resource_name"), + argvalues=( + ({"cpu": "2"}, _ResourceName.CPU), + ({"mem": "1Gi"}, _ResourceName.MEMORY), + ({"gpu": "1"}, _ResourceName.GPU), + ({"storage": "100Mb"}, _ResourceName.STORAGE), + ({"ephemeral_storage": "123Mb"}, _ResourceName.EPHEMERAL_STORAGE), + ), + ids=("CPU", "MEMORY", "GPU", "STORAGE", "EPHEMERAL_STORAGE"), +) +def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _task_models.Resources): + assert len(resource_dict) == 1 + expected_resource_value = list(resource_dict.values())[0] + + requests = Resources(**resource_dict) + resources_model = convert_resources_to_resource_model(limits=requests) + + assert len(resources_model.limits) == 1 + limit = resources_model.limits[0] + assert isinstance(limit, _task_models.Resources.ResourceEntry) + assert limit.name == expected_resource_name + assert limit.value == expected_resource_value + assert len(resources_model.requests) == 0 From 4b1675ffb85648dc5742e9a6dea98b94714963e1 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 14 Jan 2023 06:27:12 +0800 Subject: [PATCH 03/13] Add support for overriding task configurations (#1410) Signed-off-by: Kevin Su --- flytekit/core/node.py | 7 ++++++ .../flytekit/unit/core/test_node_creation.py | 23 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 52487e6e482..220301c4020 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -6,6 +6,7 @@ from flytekit.core.resources import Resources, convert_resources_to_resource_model from flytekit.core.utils import _dnsify +from flytekit.loggers import logger from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.task import Resources as _resources_model @@ -119,6 +120,12 @@ def with_overrides(self, *args, **kwargs): self._metadata._interruptible = kwargs["interruptible"] if "name" in kwargs: self._metadata._name = kwargs["name"] + if "task_config" in kwargs: + logger.warning("This override is beta. 
We may want to revisit this in the future.") + new_task_config = kwargs["task_config"] + if not isinstance(new_task_config, type(self.flyte_entity._task_config)): + raise ValueError("can't change the type of the task config") + self.flyte_entity._task_config = new_task_config return self diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 47c8af98307..2813563fb94 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -1,6 +1,7 @@ import datetime import typing from collections import OrderedDict +from dataclasses import dataclass import pytest @@ -424,3 +425,25 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.name == "foo" + + +def test_config_override(): + @dataclass + class DummyConfig: + name: str + + @task(task_config=DummyConfig(name="hello")) + def t1(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(task_config=DummyConfig("flyte")) + + assert my_wf.nodes[0].flyte_entity.task_config.name == "flyte" + + with pytest.raises(ValueError): + + @workflow + def my_wf(a: str) -> str: + return t1(a=a).with_overrides(task_config=None) From 905cabcd8f35556e56b4a519d87eafefdd9bddf0 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 18 Jan 2023 16:08:32 +0800 Subject: [PATCH 04/13] Warning if git is not installed (#1414) * warning if git is not installed Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/remote/remote.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 23c9803b07a..d8263a88124 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -19,7 +19,6 @@ from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 -from git import Repo from flytekit import Literal from flytekit.clients.friendly import SynchronousFlyteClient @@ -127,9 +126,14 @@ def _get_git_repo_url(source_path): Get git repo URL from remote.origin.url """ try: + from git import Repo + return "github.com/" + Repo(source_path).remotes.origin.url.split(".git")[0].split(":")[-1] + except ImportError: + remote_logger.warning("Could not import git. 
is the git executable installed?") except Exception: # If the file isn't in the git repo, we can't get the url from git config + remote_logger.debug(f"{source_path} is not a git repo.") return "" From 5dd887cf5cbe817a35d682a32886f304a97fc910 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 19 Jan 2023 13:08:02 -0800 Subject: [PATCH 05/13] Flip the settings for channel and logger (#1415) Signed-off-by: Yee Hing Tong --- flytekit/loggers.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/flytekit/loggers.py b/flytekit/loggers.py index 0c8c2e035a5..f047348de07 100644 --- a/flytekit/loggers.py +++ b/flytekit/loggers.py @@ -13,12 +13,6 @@ # By default, the root flytekit logger to debug so everything is logged, but enable fine-tuning logger = logging.getLogger("flytekit") -# Root logger control -flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT" -if os.getenv(flytekit_root_env_var) is not None: - logger.setLevel(int(os.getenv(flytekit_root_env_var))) -else: - logger.setLevel(logging.DEBUG) # Stop propagation so that configuration is isolated to this file (so that it doesn't matter what the # global Python root logger is set to). @@ -40,22 +34,33 @@ # create console handler ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) +# Root logger control # Don't want to import the configuration library since that will cause all sorts of circular imports, let's # just use the environment variable if it's defined. Decide in the future when we implement better controls # if we should control with the channel or with the logger level. # The handler log level controls whether log statements will actually print to the screen +flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT" level_from_env = os.getenv(LOGGING_ENV_VAR) -if level_from_env is not None: - ch.setLevel(int(level_from_env)) +root_level_from_env = os.getenv(flytekit_root_env_var) +if root_level_from_env is not None: + logger.setLevel(int(root_level_from_env)) +elif level_from_env is not None: + logger.setLevel(int(level_from_env)) else: - ch.setLevel(logging.WARNING) + logger.setLevel(logging.WARNING) for log_name, child_logger in child_loggers.items(): env_var = f"{LOGGING_ENV_VAR}_{log_name.upper()}" level_from_env = os.getenv(env_var) if level_from_env is not None: child_logger.setLevel(int(level_from_env)) + else: + if child_logger is user_space_logger: + child_logger.setLevel(logging.INFO) + else: + child_logger.setLevel(logging.WARNING) # create formatter formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s") From 310aea3296170126433f3e7b2949c6a6ea1203c0 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Fri, 20 Jan 2023 11:48:30 -0800 Subject: [PATCH 06/13] Preserving Exception in the LazyEntity fetch (#1412) * Preserving Exception in the LazyEntity fetch Signed-off-by: Ketan Umare * updated lint error Signed-off-by: Ketan Umare * more tests Signed-off-by: Ketan Umare Signed-off-by: Ketan Umare --- flytekit/core/promise.py | 3 ++- flytekit/remote/lazy_entity.py | 7 ++++++- tests/flytekit/unit/remote/test_lazy_entity.py | 13 +++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 53048cb03f7..bef86cc9ed1 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -854,7 +854,8 @@ def create_and_link_node_from_remote( extra_inputs = used_inputs ^ set(kwargs.keys()) if len(extra_inputs) > 0: raise 
From 310aea3296170126433f3e7b2949c6a6ea1203c0 Mon Sep 17 00:00:00 2001
From: Ketan Umare <16888709+kumare3@users.noreply.github.com>
Date: Fri, 20 Jan 2023 11:48:30 -0800
Subject: [PATCH 06/13] Preserving Exception in the LazyEntity fetch (#1412)

* Preserving Exception in the LazyEntity fetch

Signed-off-by: Ketan Umare

* updated lint error

Signed-off-by: Ketan Umare

* more tests

Signed-off-by: Ketan Umare

Signed-off-by: Ketan Umare
---
 flytekit/core/promise.py                       |  3 ++-
 flytekit/remote/lazy_entity.py                 |  7 ++++++-
 tests/flytekit/unit/remote/test_lazy_entity.py | 13 +++++++++++++
 3 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py
index 53048cb03f7..bef86cc9ed1 100644
--- a/flytekit/core/promise.py
+++ b/flytekit/core/promise.py
@@ -854,7 +854,8 @@ def create_and_link_node_from_remote(
     extra_inputs = used_inputs ^ set(kwargs.keys())
     if len(extra_inputs) > 0:
         raise _user_exceptions.FlyteAssertion(
-            "Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
+            f"Too many inputs for [{entity.name}] Expected inputs: {typed_interface.inputs.keys()} "
+            f"- extra inputs: {extra_inputs}"
         )

     # Detect upstream nodes

diff --git a/flytekit/remote/lazy_entity.py b/flytekit/remote/lazy_entity.py
index b40c6e3ff72..4755aad99da 100644
--- a/flytekit/remote/lazy_entity.py
+++ b/flytekit/remote/lazy_entity.py
@@ -37,7 +37,12 @@ def entity(self) -> T:
         """
         with self._mutex:
             if self._entity is None:
-                self._entity = self._getter()
+                try:
+                    self._entity = self._getter()
+                except AttributeError as e:
+                    raise RuntimeError(
+                        f"Error downloading the entity {self._name}, (check original exception...)"
+                    ) from e
             return self._entity

     def __getattr__(self, item: str) -> typing.Any:

diff --git a/tests/flytekit/unit/remote/test_lazy_entity.py b/tests/flytekit/unit/remote/test_lazy_entity.py
index 1ed191aea41..5328a2caf0f 100644
--- a/tests/flytekit/unit/remote/test_lazy_entity.py
+++ b/tests/flytekit/unit/remote/test_lazy_entity.py
@@ -63,3 +63,16 @@ def _getter():
     e.compile(ctx)
     assert e._entity is not None
     assert e.entity == dummy_task
+
+
+def test_lazy_loading_exception():
+    def _getter():
+        raise AttributeError("Error")
+
+    e = LazyEntity("x", _getter)
+    assert e.name == "x"
+    assert e._entity is None
+    with pytest.raises(RuntimeError) as exc:
+        assert e.blah
+
+    assert isinstance(exc.value.__cause__, AttributeError)
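The test above hinges on exception chaining: `raise ... from e` keeps the original `AttributeError` reachable via `__cause__`, and re-raising it as a `RuntimeError` prevents Python's attribute protocol from misreading the failure as a missing attribute when access goes through `__getattr__`. A standalone sketch of the pattern:

```python
def getter():
    # Stand-in for the remote fetch that fails inside LazyEntity.entity.
    raise AttributeError("Error")


try:
    try:
        getter()
    except AttributeError as e:
        # Same pattern as the patch: wrap the error, but preserve the cause.
        raise RuntimeError("Error downloading the entity x, (check original exception...)") from e
except RuntimeError as err:
    # The original error survives on __cause__, which the new test asserts.
    assert isinstance(err.__cause__, AttributeError)
    print(repr(err.__cause__))  # AttributeError('Error')
```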
From b6605bc95bd6f03434aeb5ca8f69eba161924ef8 Mon Sep 17 00:00:00 2001
From: Peeter Piegaze
Date: Sat, 21 Jan 2023 00:59:36 +0100
Subject: [PATCH 07/13] [Docs] SynchronousFlyteClient API reference #3095 (#1416)

Signed-off-by: Peeter Piegaze
Signed-off-by: Peeter Piegaze
Co-authored-by: Peeter Piegaze
Co-authored-by: Haytham Abuelfutuh
---
 docs/source/clients.rst      |  4 ++++
 docs/source/index.rst        |  1 +
 flytekit/clients/__init__.py | 19 +++++++++++++++++++
 3 files changed, 24 insertions(+)
 create mode 100644 docs/source/clients.rst

diff --git a/docs/source/clients.rst b/docs/source/clients.rst
new file mode 100644
index 00000000000..f67ebf6a3af
--- /dev/null
+++ b/docs/source/clients.rst
@@ -0,0 +1,4 @@
+.. automodule:: flytekit.clients
+   :no-members:
+   :no-inherited-members:
+   :no-special-members:

diff --git a/docs/source/index.rst b/docs/source/index.rst
index b0d46866fac..db5902391be 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -76,6 +76,7 @@ Expected output:
    flytekit
    configuration
    remote
+   clients
    testing
    extend
    deck

diff --git a/flytekit/clients/__init__.py b/flytekit/clients/__init__.py
index e69de29bb2d..1b08e1c5670 100644
--- a/flytekit/clients/__init__.py
+++ b/flytekit/clients/__init__.py
@@ -0,0 +1,19 @@
+"""
+=====================
+Clients
+=====================
+
+.. currentmodule:: flytekit.clients
+
+This module provides lower level access to a Flyte backend.
+
+.. _clients_module:
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+   :nosignatures:
+
+   ~friendly.SynchronousFlyteClient
+   ~raw.RawSynchronousFlyteClient
+"""

From 99722d59349a574a73e4b77e77bfa54ceab5c012 Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Wed, 25 Jan 2023 02:11:26 +0800
Subject: [PATCH 08/13] Return error code on fail (#1408)

* AWS batch return error code once it fails

Signed-off-by: Kevin Su

* AWS batch return error code once it fails

Signed-off-by: Kevin Su

* update tests

Signed-off-by: Kevin Su

* Update tests

Signed-off-by: Kevin Su

Signed-off-by: Kevin Su
Signed-off-by: Kevin Su
---
 flytekit/bin/entrypoint.py                        |  6 ++++
 tests/flytekit/unit/bin/test_python_entrypoint.py | 32 +++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py
index 3d5017675ef..fe38f946f9b 100644
--- a/flytekit/bin/entrypoint.py
+++ b/flytekit/bin/entrypoint.py
@@ -161,6 +161,12 @@ def _dispatch_execute(
     logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")
     logger.debug("Finished _dispatch_execute")

+    if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict:
+        # This env var is set by flytepropeller.
+        # AWS batch jobs get their status from the exit code, so once we catch the error,
+        # we should return the error code here.
+        exit(1)
+

 def get_one_of(*args) -> str:
     """

diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py
index 479ad9e7bd3..6a8b8c430e2 100644
--- a/tests/flytekit/unit/bin/test_python_entrypoint.py
+++ b/tests/flytekit/unit/bin/test_python_entrypoint.py
@@ -3,6 +3,7 @@
 from collections import OrderedDict

 import mock
+import pytest
 from flyteidl.core.errors_pb2 import ErrorDocument

 from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution
@@ -110,6 +111,37 @@ def verify_output(*args, **kwargs):
     assert mock_write_to_file.call_count == 1


+@mock.patch.dict(os.environ, {"FLYTE_FAIL_ON_ERROR": "True"})
+@mock.patch("flytekit.core.utils.load_proto_from_file")
+@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data")
+@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
+@mock.patch("flytekit.core.utils.write_proto_to_file")
+def test_dispatch_execute_return_error_code(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
+    mock_get_data.return_value = True
+    mock_upload_dir.return_value = True
+
+    ctx = context_manager.FlyteContext.current_context()
+    with context_manager.FlyteContextManager.with_context(
+        ctx.with_execution_state(
+            ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
+        )
+    ) as ctx:
+        python_task = mock.MagicMock()
+        python_task.dispatch_execute.side_effect = Exception("random")
+
+        empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl()
+        mock_load_proto.return_value = empty_literal_map
+
+        def verify_output(*args, **kwargs):
+            assert isinstance(args[0], ErrorDocument)
+
+        mock_write_to_file.side_effect = verify_output
+
+        with pytest.raises(SystemExit) as cm:
+            _dispatch_execute(ctx, python_task, "inputs path", "outputs prefix")
+        assert cm.value.code == 1
+
+
 # This function collects outputs instead of writing them to a file.
 # See flytekit.core.utils.write_proto_to_file for the original
 def get_output_collector(results: OrderedDict):
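Since `exit(1)` raises `SystemExit`, the exit code can be asserted directly through `pytest.raises` — `pytest` has no `assertEqual` (that is a `unittest` method). A minimal standalone sketch of the idiom:

```python
import pytest


def dispatch():
    # Stand-in for _dispatch_execute when FLYTE_FAIL_ON_ERROR is "true"
    # and an error document was produced.
    raise SystemExit(1)  # what exit(1) raises under the hood


def test_exit_code():
    with pytest.raises(SystemExit) as cm:
        dispatch()
    # The captured exception carries the exit code on .code.
    assert cm.value.code == 1
```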
From 5126b2aa86c3b2b3154eeb782e71c97cde95ac9a Mon Sep 17 00:00:00 2001
From: Yee Hing Tong
Date: Tue, 24 Jan 2023 13:43:39 -0800
Subject: [PATCH 09/13] wrapping flyte entity in a task node in call to flyte node constructor, not sure if integration tests are actually running (#1422)

Signed-off-by: Yee Hing Tong
Signed-off-by: Yee Hing Tong
---
 flytekit/remote/remote.py                        | 4 ++--
 tests/flytekit/integration/remote/test_remote.py | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index d8263a88124..0a82d1fb65b 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -55,7 +55,7 @@
     NotificationList,
     WorkflowExecutionGetDataResponse,
 )
-from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteWorkflow
+from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
 from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
 from flytekit.remote.interface import TypedInterface
 from flytekit.remote.lazy_entity import LazyEntity
@@ -1460,7 +1460,7 @@ def sync_execution(
                     upstream_nodes=[],
                     bindings=[],
                     metadata=NodeMetadata(name=""),
-                    flyte_task=flyte_entity,
+                    task_node=FlyteTaskNode(flyte_entity),
                 )
             }
             if len(task_node_exec) >= 1

diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py
index dd021eb3be3..09b794775bc 100644
--- a/tests/flytekit/integration/remote/test_remote.py
+++ b/tests/flytekit/integration/remote/test_remote.py
@@ -221,6 +221,7 @@ def test_fetch_execute_task_convert_dict(flyteclient, flyte_workflows_register):
     flyte_task = remote.fetch_task(name="workflows.basic.dict_str_wf.convert_to_string", version=f"v{VERSION}")
     d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"}
     execution = remote.execute(flyte_task, {"d": d}, wait=True)
+    remote.sync_execution(execution, sync_nodes=True)
    assert json.loads(execution.outputs["o0"]) == {"key1": "value1", "key2": "value2"}
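From user code, the repaired path is exercised whenever node details are synced on a fetched execution. A hedged sketch against a local sandbox — it requires a running cluster, and the project, domain, task name, and version below are placeholders:

```python
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

# Assumes a running sandbox cluster; all names below are placeholders.
remote = FlyteRemote(Config.for_sandbox(), default_project="flytesnacks", default_domain="development")
flyte_task = remote.fetch_task(name="workflows.basic.hello_world.say_hello", version="v1")
execution = remote.execute(flyte_task, {}, wait=True)
# sync_nodes=True rebuilds FlyteNode objects, which now wrap the fetched task
# as task_node=FlyteTaskNode(...) rather than the removed flyte_task= kwarg.
execution = remote.sync_execution(execution, sync_nodes=True)
print(execution.outputs)
```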
From 7b48fff5cfe2fcf19e28fde8e35c20acdb9059e2 Mon Sep 17 00:00:00 2001
From: Niels Bantilan
Date: Wed, 25 Jan 2023 15:24:22 -0500
Subject: [PATCH 10/13] Sqlalchemy multiline query (#1421)

* SQLAlchemyTask should handle multiline strings for query template

Signed-off-by: Niels Bantilan

* sqlalchemy supports multi-line query

Signed-off-by: Niels Bantilan

* update base sql task

Signed-off-by: Niels Bantilan

* remove space

Signed-off-by: Niels Bantilan

* fix snowflake tests

Signed-off-by: Niels Bantilan

* fix lint

Signed-off-by: Niels Bantilan

* fix test

Signed-off-by: Niels Bantilan

Signed-off-by: Niels Bantilan
---
 flytekit/core/base_sql_task.py                     |  2 +-
 flytekit/extras/sqlite3/task.py                    |  5 ++---
 plugins/flytekit-snowflake/tests/test_snowflake.py |  4 ++--
 plugins/flytekit-sqlalchemy/tests/test_task.py     | 20 +++++++++++++++++--
 tests/flytekit/unit/extras/sqlite3/test_task.py    |  4 ++--
 5 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py
index d2e4838ed84..7fcdc15a50f 100644
--- a/flytekit/core/base_sql_task.py
+++ b/flytekit/core/base_sql_task.py
@@ -41,7 +41,7 @@ def __init__(
             task_config=task_config,
             **kwargs,
         )
-        self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t")
+        self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip()

     @property
     def query_template(self) -> str:

diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py
index 0284440da39..8e7d8b3b29d 100644
--- a/flytekit/extras/sqlite3/task.py
+++ b/flytekit/extras/sqlite3/task.py
@@ -92,14 +92,13 @@ def __init__(
             container_image=container_image or DefaultImages.default_image(),
             executor_type=SQLite3TaskExecutor,
             task_type=self._SQLITE_TASK_TYPE,
+            # Sanitize query by removing the newlines at the end of the query. Keep in mind
+            # that the query can be a multiline string.
             query_template=query_template,
             inputs=inputs,
             outputs=outputs,
             **kwargs,
         )
-        # Sanitize query by removing the newlines at the end of the query. Keep in mind
-        # that the query can be a multiline string.
-        self._query_template = query_template.replace("\n", " ")

     @property
     def output_columns(self) -> typing.Optional[typing.List[str]]:

diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py
index a012e38d99b..672f4a19ad0 100644
--- a/plugins/flytekit-snowflake/tests/test_snowflake.py
+++ b/plugins/flytekit-snowflake/tests/test_snowflake.py
@@ -70,7 +70,7 @@ def test_local_exec():
     )

     assert len(snowflake_task.interface.inputs) == 1
-    assert snowflake_task.query_template == "select 1\\n"
+    assert snowflake_task.query_template == "select 1"
     assert len(snowflake_task.interface.outputs) == 1

     # will not run locally
@@ -86,4 +86,4 @@ def test_sql_template():
         custom where column = 1""",
         output_schema_type=FlyteSchema,
     )
-    assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1"
+    assert snowflake_task.query_template == "select 1 from custom where column = 1"

diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py
index 6d20027b2a7..7537a3a1de5 100644
--- a/plugins/flytekit-sqlalchemy/tests/test_task.py
+++ b/plugins/flytekit-sqlalchemy/tests/test_task.py
@@ -70,7 +70,23 @@ def test_task_schema(sql_server):
     assert df is not None


-def test_workflow(sql_server):
+@pytest.mark.parametrize(
+    "query_template",
+    [
+        "select * from tracks limit {{.inputs.limit}}",
+        """
+        select * from tracks
+        limit {{.inputs.limit}}
+        """,
+        """select * from tracks
+        limit {{.inputs.limit}}
+        """,
+        """
+        select * from tracks
+        limit {{.inputs.limit}}""",
+    ],
+)
+def test_workflow(sql_server, query_template):
     @task
     def my_task(df: pandas.DataFrame) -> int:
         return len(df[df.columns[0]])
@@ -84,7 +100,7 @@ def my_task(df: pandas.DataFrame) -> int:

     sql_task = SQLAlchemyTask(
         "test",
-        query_template="select * from tracks limit {{.inputs.limit}}",
+        query_template=query_template,
         inputs=kwtypes(limit=int),
         task_config=SQLAlchemyConfig(uri=sql_server),
     )

diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py
index ef7ea491e64..40fc94a3d2f 100644
--- a/tests/flytekit/unit/extras/sqlite3/test_task.py
+++ b/tests/flytekit/unit/extras/sqlite3/test_task.py
@@ -119,14 +119,14 @@ def test_task_serialization():
             select * from tracks
             limit {{.inputs.limit}}""",
-            " select * from tracks limit {{.inputs.limit}}",
+            "select * from tracks limit {{.inputs.limit}}",
         ),
         (
             """ \
select * \
from tracks \
limit {{.inputs.limit}}""",
-            " select * from tracks limit {{.inputs.limit}}",
+            "select * from tracks limit {{.inputs.limit}}",
         ),
         ("select * from abc", "select * from abc"),
     ],
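The normalization that `BaseSQLTask.__init__` now applies can be checked in isolation — this is the same expression as in the diff above, wrapped in a helper for illustration:

```python
import re


def normalize_query(query_template: str) -> str:
    # Newlines and tabs become spaces, runs of whitespace collapse to a single
    # space, and leading/trailing whitespace is stripped, so a multiline
    # template serializes to a single-line query.
    return re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip()


multiline = """
    select * from tracks
    limit {{.inputs.limit}}
"""
assert normalize_query(multiline) == "select * from tracks limit {{.inputs.limit}}"
print(normalize_query(multiline))
```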
From 1513773eed118fc3ddacffddf71ee3d30cd0c5cf Mon Sep 17 00:00:00 2001
From: Niels Bantilan
Date: Fri, 27 Jan 2023 13:04:26 -0500
Subject: [PATCH 11/13] Sklearn type transformer should be automatically loaded with import flytekit (#1423)

* add flytekit.extras.sklearn to main __init__ import

Signed-off-by: Niels Bantilan

* fix docs

Signed-off-by: Niels Bantilan

* add temporary docs/requirements.txt for onnx plugins

Signed-off-by: Niels Bantilan

---------

Signed-off-by: Niels Bantilan
---
 .readthedocs.yml                       |  1 +
 docs/requirements.txt                  |  7 +++++++
 flytekit/__init__.py                   |  2 +-
 flytekit/extras/pytorch/__init__.py    |  5 +++--
 flytekit/extras/sklearn/__init__.py    |  4 ++--
 flytekit/extras/tensorflow/__init__.py | 11 +++++++++++
 6 files changed, 25 insertions(+), 5 deletions(-)
 create mode 100644 docs/requirements.txt

diff --git a/.readthedocs.yml b/.readthedocs.yml
index 86a85609d7a..19b1898e947 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -18,3 +18,4 @@ sphinx:
 python:
   install:
     - requirements: doc-requirements.txt
+    - requirements: docs/requirements.txt

diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 00000000000..1fb1b913592
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,7 @@
+# TODO: Remove after buf migration is done and packages updated, see doc-requirements.in
+# skl2onnx and tf2onnx added here so that the plugin API reference is rendered,
+# with the caveat that the docs build environment has the environment variable
+# PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set so that protobuf can be parsed
+# using Python, which is acceptable for docs building.
+skl2onnx
+tf2onnx

diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index 5ba43b78dc9..e028cbaab9d 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -184,7 +184,7 @@
 from flytekit.core.workflow import ImperativeWorkflow as Workflow
 from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
 from flytekit.deck import Deck
-from flytekit.extras import pytorch, tensorflow
+from flytekit.extras import pytorch, sklearn, tensorflow
 from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence
 from flytekit.loggers import logger
 from flytekit.models.common import Annotations, AuthRole, Labels

diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py
index 770fe11b735..a29d8e89e6d 100644
--- a/flytekit/extras/pytorch/__init__.py
+++ b/flytekit/extras/pytorch/__init__.py
@@ -1,6 +1,4 @@
 """
-Flytekit PyTorch
-=========================================
 .. currentmodule:: flytekit.extras.pytorch

 .. autosummary::
@@ -8,6 +6,9 @@
    :toctree: generated/

    PyTorchCheckpoint
+   PyTorchCheckpointTransformer
+   PyTorchModuleTransformer
+   PyTorchTensorTransformer
 """

 from flytekit.loggers import logger

diff --git a/flytekit/extras/sklearn/__init__.py b/flytekit/extras/sklearn/__init__.py
index 0a1bf2dda54..1d16f6080f9 100644
--- a/flytekit/extras/sklearn/__init__.py
+++ b/flytekit/extras/sklearn/__init__.py
@@ -1,11 +1,11 @@
 """
-Flytekit Sklearn
-=========================================
 .. currentmodule:: flytekit.extras.sklearn

 .. autosummary::
    :template: custom.rst
    :toctree: generated/
+
+   SklearnEstimatorTransformer
 """

 from flytekit.loggers import logger

diff --git a/flytekit/extras/tensorflow/__init__.py b/flytekit/extras/tensorflow/__init__.py
index 4db9b428ea5..fe10c9024b3 100644
--- a/flytekit/extras/tensorflow/__init__.py
+++ b/flytekit/extras/tensorflow/__init__.py
@@ -1,3 +1,14 @@
+"""
+.. currentmodule:: flytekit.extras.tensorflow
+
+.. autosummary::
+   :template: custom.rst
+   :toctree: generated/
+
+   TensorFlowRecordFileTransformer
+   TensorFlowRecordsDirTransformer
+"""
+
 from flytekit.loggers import logger

 # TODO: abstract this out so that there's an established pattern for registering plugins
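With `flytekit.extras.sklearn` imported at the top level, an estimator annotation should round-trip through the type engine without any explicit transformer registration. A minimal sketch (assumes `scikit-learn` is installed; runs locally):

```python
from sklearn.linear_model import LinearRegression

from flytekit import task, workflow


@task
def fit() -> LinearRegression:
    # The sklearn transformer, now loaded by `import flytekit`, handles this return type.
    model = LinearRegression()
    model.fit([[0.0], [1.0]], [0.0, 1.0])
    return model


@task
def predict(model: LinearRegression) -> float:
    return float(model.predict([[2.0]])[0])


@workflow
def wf() -> float:
    return predict(model=fit())


if __name__ == "__main__":
    print(wf())  # approximately 2.0
```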
From 88c6bba6c3b114e9f80fca1768cafd6920935fac Mon Sep 17 00:00:00 2001
From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
Date: Tue, 31 Jan 2023 10:50:49 -0800
Subject: [PATCH 12/13] Bump isort to 5.12.0 (#1427)

Signed-off-by: Eduardo Apolinario
Co-authored-by: Eduardo Apolinario
---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 39470b73703..1fd6e6b648f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ repos:
     hooks:
       - id: black
   - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile", "black"]

From ea39054223a9432692c03314b90c45240f2d8419 Mon Sep 17 00:00:00 2001
From: Ketan Umare <16888709+kumare3@users.noreply.github.com>
Date: Tue, 31 Jan 2023 13:17:08 -0800
Subject: [PATCH 13/13] Fixes guess type bug in UnionTransformer (#1426)

Signed-off-by: Ketan Umare
Co-authored-by: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com>
---
 flytekit/core/type_engine.py                 |  2 +-
 tests/flytekit/unit/core/test_type_engine.py | 14 +++++++++++++-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 6ddeb5c58cc..61c448b365c 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -1138,7 +1138,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

     def guess_python_type(self, literal_type: LiteralType) -> type:
         if literal_type.union_type is not None:
-            return typing.Union[tuple(TypeEngine.guess_python_type(v.type) for v in literal_type.union_type.variants)]  # type: ignore
+            return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)]

         raise ValueError(f"Union transformer cannot reverse {literal_type}")

diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py
index bbe46845fd3..eb38a8d80be 100644
--- a/tests/flytekit/unit/core/test_type_engine.py
+++ b/tests/flytekit/unit/core/test_type_engine.py
@@ -45,7 +45,7 @@
 from flytekit.models.annotation import TypeAnnotation
 from flytekit.models.core.types import BlobType
 from flytekit.models.literals import Blob, BlobMetadata, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Void
-from flytekit.models.types import LiteralType, SimpleType, TypeStructure
+from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType
 from flytekit.types.directory import TensorboardLogs
 from flytekit.types.directory.types import FlyteDirectory
 from flytekit.types.file import FileExt, JPEGImageFile
@@ -941,6 +941,18 @@ def test_union_transformer():
     assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int


+def test_union_guess_type():
+    ut = UnionTransformer()
+    t = ut.guess_python_type(
+        LiteralType(
+            union_type=UnionType(
+                variants=[LiteralType(simple=SimpleType.STRING), LiteralType(simple=SimpleType.INTEGER)]
+            )
+        )
+    )
+    assert t == typing.Union[str, int]
+
+
 def test_union_type_with_annotated():
     pt = typing.Union[
         Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})]