From f3d0a54a0c537240dd6d6503d0fabbc775cddad1 Mon Sep 17 00:00:00 2001 From: Calvin Leather Date: Thu, 4 Aug 2022 14:28:05 -0400 Subject: [PATCH 01/27] Add deck to papermill plugin task (#1111) Signed-off-by: Calvin Leather --- .../flytekitplugins/papermill/task.py | 16 +++++++++++++++- plugins/flytekit-papermill/tests/test_task.py | 14 ++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index a58b01d482..304932a828 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -11,6 +11,7 @@ from flytekit import FlyteContext, PythonInstanceTask from flytekit.core.context_manager import ExecutionParameters +from flytekit.deck.deck import Deck from flytekit.extend import Interface, TaskPlugins, TypeEngine from flytekit.loggers import logger from flytekit.models.literals import LiteralMap @@ -63,6 +64,7 @@ class NotebookTask(PythonInstanceTask[T]): name="modulename.my_notebook_task", # the name should be unique within all your tasks, usually it is a good # idea to use the modulename notebook_path="../path/to/my_notebook", + render_deck=True, inputs=kwtypes(v=int), outputs=kwtypes(x=int, y=str), metadata=TaskMetadata(retries=3, cache=True, cache_version="1.0"), @@ -76,7 +78,7 @@ class NotebookTask(PythonInstanceTask[T]): #. It captures the executed notebook in its entirety and is available from Flyte with the name ``out_nb``. #. It also converts the captured notebook into an ``html`` page, which the FlyteConsole will render called - - ``out_rendered_nb`` + ``out_rendered_nb``. If ``render_deck=True`` is passed, this html content will be inserted into a deck. .. note: @@ -109,6 +111,7 @@ def __init__( self, name: str, notebook_path: str, + render_deck: bool = False, task_config: T = None, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, outputs: typing.Optional[typing.Dict[str, typing.Type]] = None, @@ -128,6 +131,8 @@ def __init__( task_type = f"nb-{self._config_task_instance.task_type}" self._notebook_path = os.path.abspath(notebook_path) + self._render_deck = render_deck + if not os.path.exists(self._notebook_path): raise ValueError(f"Illegal notebook path passed in {self._notebook_path}") @@ -225,6 +230,15 @@ def execute(self, **kwargs) -> Any: return tuple(output_list) def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + if self._render_deck: + nb_deck = Deck(self._IMPLICIT_RENDERED_NOTEBOOK) + with open(self.rendered_output_path, "r") as f: + notebook_html = f.read() + nb_deck.append(notebook_html) + # Since user_params is passed by reference, this modifies the object in the outside scope + # which then causes the deck to be rendered later during the dispatch_execute function. 
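+            # A rough sketch of the flow being relied on here (comment only; the exact
+            # hook names are flytekit internals and are not spelled out in this patch):
+            #     nb_deck = Deck(...); nb_deck.append(notebook_html)
+            #     user_params.decks.append(nb_deck)  # mutates the framework-owned list
+            #     # ...dispatch_execute later walks user_params.decks and emits the HTML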
+            user_params.decks.append(nb_deck)
+
         return self._config_task_instance.post_execute(user_params, rval)


diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py
index ca25eea028..d60e68cdb0 100644
--- a/plugins/flytekit-papermill/tests/test_task.py
+++ b/plugins/flytekit-papermill/tests/test_task.py
@@ -69,3 +69,17 @@ def test_notebook_task_complex():
     assert nb.python_interface.outputs.keys() == {"h", "w", "x", "out_nb", "out_rendered_nb"}
     assert nb.output_notebook_path == out == _get_nb_path(nb_name, suffix="-out")
     assert nb.rendered_output_path == render == _get_nb_path(nb_name, suffix="-out", ext=".html")
+
+
+def test_notebook_deck_local_execution_doesnt_fail():
+    nb_name = "nb-simple"
+    nb = NotebookTask(
+        name="test",
+        notebook_path=_get_nb_path(nb_name, abs=False),
+        render_deck=True,
+        inputs=kwtypes(pi=float),
+        outputs=kwtypes(square=float),
+    )
+    sqr, out, render = nb.execute(pi=4)
+    # This is largely a no-assert test to ensure render_deck never inhibits local execution.
+    assert nb._render_deck, "Passing render_deck to init should result in the private attribute being set"

From 79dcc9553a658062b7a06006266b883b135d943d Mon Sep 17 00:00:00 2001
From: Yee Hing Tong
Date: Fri, 5 Aug 2022 11:23:17 -0700
Subject: [PATCH 02/27] Run compilation even in local execution for dynamic
 tasks to early detect errors (#1121)

Signed-off-by: Yee Hing Tong
---
 flytekit/core/python_function_task.py        | 11 ++++-
 tests/flytekit/unit/core/test_dynamic.py     | 48 ++++++++++++++------
 tests/flytekit/unit/core/test_local_cache.py | 10 ++--
 tests/flytekit/unit/core/test_type_engine.py |  6 +--
 4 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py
index 84a8eaedef..e3a10afdf3 100644
--- a/flytekit/core/python_function_task.py
+++ b/flytekit/core/python_function_task.py
@@ -19,6 +19,8 @@ from enum import Enum
 from typing import Any, Callable, List, Optional, TypeVar, Union

+from flytekit.configuration import SerializationSettings
+from flytekit.configuration.default_images import DefaultImages
 from flytekit.core.base_task import Task, TaskResolverMixin
 from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.docstring import Docstring
@@ -257,10 +259,17 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
         representing that newly generated workflow, instead of executing it.
         """
         ctx = FlyteContextManager.current_context()
+        # This is a placeholder SerializationSettings and is only used to test compilation for dynamic tasks
+        # when run locally. The output of the compilation should never actually be used anywhere.
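+        # The for_image() arguments below are deliberately meaningless: after the default
+        # image, the positional strings presumably stand in for a version ("v"), project
+        # ("p"), and domain ("d"). Any values would do, since the compiled output is discarded.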
+ _LOCAL_ONLY_SS = SerializationSettings.for_image(DefaultImages.default_image(), "v", "p", "d") if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: updated_exec_state = ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION) - with FlyteContextManager.with_context(ctx.with_execution_state(updated_exec_state)): + with FlyteContextManager.with_context( + ctx.with_execution_state(updated_exec_state).with_serialization_settings(_LOCAL_ONLY_SS) + ) as ctx: + logger.debug(f"Running compilation for {self} as part of local run as check") + self.compile_into_workflow(ctx, task_function, **kwargs) logger.info("Executing Dynamic workflow, using raw inputs") return exception_scopes.user_entry_point(task_function)(**kwargs) diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 365ce4c25f..668ca97dfd 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -1,5 +1,7 @@ import typing +import pytest + import flytekit.configuration from flytekit import dynamic from flytekit.configuration import FastSerializationSettings, Image, ImageConfig @@ -10,6 +12,19 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + fast_serialization_settings=FastSerializationSettings( + enabled=True, + destination_dir="/User/flyte/workflows", + distribution_location="s3://my-s3-bucket/fast/123", + ), +) + def test_wf1_with_fast_dynamic(): @task @@ -30,20 +45,7 @@ def my_wf(a: int) -> typing.List[str]: return v with context_manager.FlyteContextManager.with_context( - context_manager.FlyteContextManager.current_context().with_serialization_settings( - flytekit.configuration.SerializationSettings( - project="test_proj", - domain="test_domain", - version="abc", - image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), - env={}, - fast_serialization_settings=FastSerializationSettings( - enabled=True, - destination_dir="/User/flyte/workflows", - distribution_location="s3://my-s3-bucket/fast/123", - ), - ) - ) + context_manager.FlyteContextManager.current_context().with_serialization_settings(settings) ) as ctx: with context_manager.FlyteContextManager.with_context( ctx.with_execution_state( @@ -111,6 +113,24 @@ def wf(a: int, b: int) -> typing.List[str]: assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"] +def test_dynamic_local_use(): + @task + def t1(a: int) -> str: + a = a + 2 + return "fast-" + str(a) + + @dynamic + def use_result(a: int) -> int: + x = t1(a=a) + if len(x) > 6: + return 5 + else: + return 0 + + with pytest.raises(TypeError): + use_result(a=6) + + def test_create_node_dynamic_local(): @task def task1(s: str) -> str: diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 3f3e56de88..fe09fac830 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -8,7 +8,9 @@ from pytest import fixture from typing_extensions import Annotated -from flytekit import SQLTask, dynamic, kwtypes +from flytekit.core.base_sql_task import SQLTask +from flytekit.core.base_task import kwtypes +from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.hash import HashMethod from 
flytekit.core.local_cache import LocalTaskCache from flytekit.core.task import TaskMetadata, task @@ -309,13 +311,13 @@ def t1(a: int) -> int: # We should have a cache miss in the first call to downstream_t and have a cache hit # on the second call. - v_1 = downstream_t(a=v) + downstream_t(a=v) v_2 = downstream_t(a=v) - return v_1 + v_2 + return v_2 assert n_cached_task_calls == 0 - assert t1(a=3) == (6 + 6) + assert t1(a=3) == 6 assert n_cached_task_calls == 1 diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index df8e14d7cb..4b1c02134c 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1197,11 +1197,11 @@ def t1(a: int) -> int: # We should have a cache miss in the first call to downstream_t v_1 = downstream_t(a=v, df=df) - v_2 = downstream_t(a=v, df=df) + downstream_t(a=v, df=df) - return v_1 + v_2 + return v_1 - assert t1(a=3) == (6 + 6 + 6) + assert t1(a=3) == 9 def test_literal_hash_int_not_set(): From 4a07642a5a713a86fe31d11dc2e140715187cf12 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 8 Aug 2022 09:31:31 -0700 Subject: [PATCH 03/27] Set to pyflyte run blob object remote when dealing with remote files (#1128) Signed-off-by: Yee Hing Tong Signed-off-by: Eduardo Apolinario --- flytekit/clis/sdk_in_container/run.py | 2 +- tests/flytekit/unit/cli/pyflyte/test_run.py | 86 ++++++++++++++++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 6fece1e7d2..95533fb4d5 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -88,7 +88,7 @@ def convert( self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] ) -> typing.Any: if FileAccessProvider.is_remote(value): - return FileParam(filepath=value) + return FileParam(filepath=value, local=False) p = pathlib.Path(value) if p.exists() and p.is_file(): return FileParam(filepath=str(p.resolve())) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index ec35f5362d..9d09d58ee8 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,11 +1,22 @@ import os import pathlib +import mock import pytest from click.testing import CliRunner from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.run import get_entities_in_file +from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE +from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY +from flytekit.clis.sdk_in_container.run import ( + REMOTE_FLAG_KEY, + RUN_LEVEL_PARAMS_KEY, + FileParamType, + get_entities_in_file, + run_command, +) +from flytekit.configuration import Image, ImageConfig +from flytekit.core.task import task WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") DIR_NAME = os.path.dirname(os.path.realpath(__file__)) @@ -172,3 +183,76 @@ def test_list_default_arguments(wf_path): ) print(result.stdout) assert result.exit_code == 0 + + +# default case, what comes from click if no image is specified, the click param is configured to use the default. 
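+# The expected results below are keyed off the image-string syntax exercised by
+# ImageConfig.validate_image in the test: a bare "fqn:tag" becomes the default image,
+# while "name=fqn:tag" registers a named, non-default image.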
+ic_result_1 = ImageConfig( + default_image=Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest"), + images=[Image(name="default", fqn="ghcr.io/flyteorg/mydefault", tag="py3.9-latest")], +) +# test that command line args are merged with the file +ic_result_2 = ImageConfig( + default_image=None, + images=[ + Image(name="asdf", fqn="ghcr.io/asdf/asdf", tag="latest"), + Image(name="xyz", fqn="docker.io/xyz", tag="latest"), + Image(name="abc", fqn="docker.io/abc", tag=None), + ], +) +# test that command line args override the file +ic_result_3 = ImageConfig( + default_image=None, + images=[Image(name="xyz", fqn="ghcr.io/asdf/asdf", tag="latest"), Image(name="abc", fqn="docker.io/abc", tag=None)], +) + + +@pytest.mark.parametrize( + "image_string, leaf_configuration_file_name, final_image_config", + [ + ("ghcr.io/flyteorg/mydefault:py3.9-latest", "no_images.yaml", ic_result_1), + ("asdf=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_2), + ("xyz=ghcr.io/asdf/asdf:latest", "sample.yaml", ic_result_3), + ], +) +def test_pyflyte_run_run(image_string, leaf_configuration_file_name, final_image_config): + @task + def a(): + ... + + mock_click_ctx = mock.MagicMock() + mock_remote = mock.MagicMock() + image_tuple = (image_string,) + image_config = ImageConfig.validate_image(None, "", image_tuple) + + run_level_params = { + "project": "p", + "domain": "d", + "image_config": image_config, + } + + pp = pathlib.Path.joinpath( + pathlib.Path(__file__).parent.parent.parent, "configuration/configs/", leaf_configuration_file_name + ) + + obj = { + RUN_LEVEL_PARAMS_KEY: run_level_params, + REMOTE_FLAG_KEY: True, + FLYTE_REMOTE_INSTANCE_KEY: mock_remote, + CTX_CONFIG_FILE: str(pp), + } + mock_click_ctx.obj = obj + + def check_image(*args, **kwargs): + assert kwargs["image_config"] == final_image_config + + mock_remote.register_script.side_effect = check_image + + run_command(mock_click_ctx, a)() + + +def test_file_param(): + m = mock.MagicMock() + l = FileParamType().convert(__file__, m, m) + assert l.local + r = FileParamType().convert("https://tmp/file", m, m) + assert r.local is False From c2dbd8098de8470e553d0095988c6b6af0ec1dc6 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 9 Aug 2022 01:08:07 -0700 Subject: [PATCH 04/27] Override voidPromise resource (#1127) * override void promise resource Signed-off-by: Kevin Su * override void promise resource Signed-off-by: Kevin Su --- flytekit/core/promise.py | 5 ++++ .../flytekit/unit/core/test_node_creation.py | 24 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 4c9150881d..4fe8e669ab 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -679,6 +679,11 @@ def __rshift__(self, other: typing.Union[Promise, VoidPromise]): if self.ref: self.ref.node.runs_before(other.ref.node) + def with_overrides(self, *args, **kwargs): + if self.ref: + self.ref.node.with_overrides(*args, **kwargs) + return self + @property def task_name(self): return self._task_name diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index a303230386..f6dc9c9ba5 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -372,3 +372,27 @@ def my_wf(a: str) -> str: wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) assert len(wf_spec.template.nodes) == 1 assert wf_spec.template.nodes[0].metadata.interruptible == interruptible + + +def 
test_void_promise_override(): + @task + def t1(a: str): + print(f"*~*~*~{a}*~*~*~") + + @workflow + def my_wf(a: str): + t1(a=a).with_overrides(requests=Resources(cpu="1", mem="100")) + + serialization_settings = flytekit.configuration.SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) + assert len(wf_spec.template.nodes) == 1 + assert wf_spec.template.nodes[0].task_node.overrides.resources.requests == [ + _resources_models.ResourceEntry(_resources_models.ResourceName.CPU, "1"), + _resources_models.ResourceEntry(_resources_models.ResourceName.MEMORY, "100"), + ] From 1dc660919c1a72e6dcd81a70a40ef3c149346f68 Mon Sep 17 00:00:00 2001 From: Matheus Moreno Date: Tue, 16 Aug 2022 10:28:50 -0300 Subject: [PATCH 05/27] Fix how ShellTask retrieves the Pod class name (#1132) * Fix how ShellTask retrieves the Pod class name Signed-off-by: Matheus Moreno * Set Pod class name as a constant Signed-off-by: Matheus Moreno * Revert last commit Signed-off-by: Matheus Moreno * Execute automatic linting Signed-off-by: Matheus Moreno Signed-off-by: Matheus Moreno --- flytekit/extras/tasks/shell.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 812c0a3749..be7cda0a17 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -132,7 +132,8 @@ def __init__( script_file = os.path.abspath(script_file) if task_config is not None: - if str(type(task_config)) != "flytekitplugins.pod.task.Pod": + fully_qualified_class_name = task_config.__module__ + "." + task_config.__class__.__name__ + if not fully_qualified_class_name == "flytekitplugins.pod.task.Pod": raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.") # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used From bdd023ff52e59ee1d62e718c8d08ab55d3870f43 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 17 Aug 2022 13:05:49 -0700 Subject: [PATCH 06/27] Add restriction for pandas to be >=1.2 for fsspec plugin (#1136) Signed-off-by: Yee Hing Tong --- plugins/flytekit-data-fsspec/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-data-fsspec/setup.py b/plugins/flytekit-data-fsspec/setup.py index 3678a0b518..3756b9228b 100644 --- a/plugins/flytekit-data-fsspec/setup.py +++ b/plugins/flytekit-data-fsspec/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-data-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "fsspec>=2021.7.0", "botocore>=1.7.48"] +plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "fsspec>=2021.7.0", "botocore>=1.7.48", "pandas>=1.2.0"] __version__ = "0.0.0+develop" From f51b15d2f0970bca2ca24ff4baf4155fdf8c49f0 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 17 Aug 2022 15:34:07 -0700 Subject: [PATCH 07/27] Use joblib hashing to generate cache key to ensure repeatability (#1126) * cherry pick 97b454b1 Signed-off-by: Yee Hing Tong * requirements Signed-off-by: Yee Hing Tong * Fix usage of save in ProtoJoblibHasher Signed-off-by: Eduardo Apolinario * Regenerate requirements using python 3.7 Signed-off-by: Eduardo Apolinario * Add test_stable_cache_key Signed-off-by: Eduardo Apolinario Signed-off-by: Yee Hing Tong Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo 
Apolinario --- dev-requirements.txt | 58 +++++----- doc-requirements.txt | 101 ++++++++++-------- flytekit/core/local_cache.py | 21 +++- requirements-spark2.txt | 18 ++-- requirements.txt | 18 ++-- setup.py | 1 + .../workflows/requirements.txt | 27 ++--- tests/flytekit/unit/core/test_local_cache.py | 57 +++++++++- 8 files changed, 190 insertions(+), 111 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 645ffbeb18..b477f2553b 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -65,7 +65,7 @@ cookiecutter==2.1.1 # via # -c requirements.txt # flytekit -coverage[toml]==6.4.1 +coverage[toml]==6.4.4 # via # -r dev-requirements.in # pytest-cov @@ -96,7 +96,7 @@ diskcache==5.4.0 # via # -c requirements.txt # flytekit -distlib==0.3.4 +distlib==0.3.5 # via virtualenv distro==1.7.0 # via docker-compose @@ -106,9 +106,7 @@ docker[ssh]==5.0.3 # docker-compose # flytekit docker-compose==1.29.2 - # via - # pytest-docker - # pytest-flyte + # via pytest-flyte docker-image-py==0.1.12 # via # -c requirements.txt @@ -121,9 +119,9 @@ docstring-parser==0.14.1 # via # -c requirements.txt # flytekit -filelock==3.7.1 +filelock==3.8.0 # via virtualenv -flyteidl==1.1.8 +flyteidl==1.1.12 # via # -c requirements.txt # flytekit @@ -132,23 +130,23 @@ google-api-core[grpc]==2.8.2 # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.9.0 +google-auth==2.10.0 # via # google-api-core # google-cloud-core -google-cloud-bigquery==3.2.0 +google-cloud-bigquery==3.3.1 # via -r dev-requirements.in -google-cloud-bigquery-storage==2.14.0 +google-cloud-bigquery-storage==2.14.2 # via # -r dev-requirements.in # google-cloud-bigquery -google-cloud-core==2.3.1 +google-cloud-core==2.3.2 # via google-cloud-bigquery google-crc32c==1.3.0 # via google-resumable-media google-resumable-media==2.3.3 # via google-cloud-bigquery -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # -c requirements.txt # flyteidl @@ -166,7 +164,7 @@ grpcio-status==1.47.0 # -c requirements.txt # flytekit # google-api-core -identify==2.5.1 +identify==2.5.3 # via pre-commit idna==3.3 # via @@ -205,12 +203,15 @@ jinja2-time==0.2.0 # -c requirements.txt # cookiecutter joblib==1.1.0 - # via -r dev-requirements.in + # via + # -c requirements.txt + # -r dev-requirements.in + # flytekit jsonschema==3.2.0 # via # -c requirements.txt # docker-compose -keyring==23.6.0 +keyring==23.8.2 # via # -c requirements.txt # flytekit @@ -232,11 +233,11 @@ marshmallow-jsonschema==0.13.0 # via # -c requirements.txt # flytekit -matplotlib-inline==0.1.3 +matplotlib-inline==0.1.5 # via ipython mock==4.0.3 # via -r dev-requirements.in -mypy==0.961 +mypy==0.971 # via -r dev-requirements.in mypy-extensions==0.4.3 # via @@ -281,7 +282,7 @@ pre-commit==2.20.0 # via -r dev-requirements.in prompt-toolkit==3.0.30 # via ipython -proto-plus==1.20.6 +proto-plus==1.22.0 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -323,7 +324,7 @@ pycparser==2.21 # via # -c requirements.txt # cffi -pygments==2.12.0 +pygments==2.13.0 # via ipython pynacl==1.5.0 # via paramiko @@ -347,7 +348,7 @@ pytest==7.1.2 # pytest-flyte pytest-cov==3.0.0 # via -r dev-requirements.in -pytest-docker==0.12.0 +pytest-docker==1.0.0 # via pytest-flyte pytest-flyte @ git+https://github.com/flyteorg/pytest-flyte@main # via -r dev-requirements.in @@ -373,7 +374,7 @@ pytimeparse==1.1.8 # via # -c requirements.txt # flytekit -pytz==2022.1 +pytz==2022.2.1 # via # -c requirements.txt # flytekit @@ -385,7 +386,7 @@ pyyaml==5.4.1 # 
docker-compose # flytekit # pre-commit -regex==2022.7.9 +regex==2022.7.25 # via # -c requirements.txt # docker-image-py @@ -407,9 +408,9 @@ retry==0.9.2 # via # -c requirements.txt # flytekit -rsa==4.8 +rsa==4.9 # via google-auth -secretstorage==3.3.2 +secretstorage==3.3.3 # via # -c requirements.txt # keyring @@ -426,7 +427,6 @@ six==1.16.0 # jsonschema # paramiko # python-dateutil - # virtualenv # websocket-client sortedcontainers==2.4.0 # via @@ -449,7 +449,7 @@ tomli==2.0.1 # coverage # mypy # pytest -torch==1.12.0 +torch==1.12.1 # via -r dev-requirements.in traitlets==5.3.0 # via @@ -471,13 +471,13 @@ typing-inspect==0.7.1 # via # -c requirements.txt # dataclasses-json -urllib3==1.26.10 +urllib3==1.26.11 # via # -c requirements.txt # flytekit # requests # responses -virtualenv==20.15.1 +virtualenv==20.16.3 # via pre-commit wcwidth==0.2.5 # via prompt-toolkit @@ -495,7 +495,7 @@ wrapt==1.14.1 # -c requirements.txt # deprecated # flytekit -zipp==3.8.0 +zipp==3.8.1 # via # -c requirements.txt # importlib-metadata diff --git a/doc-requirements.txt b/doc-requirements.txt index e9f16c89a0..5bfb8c5d31 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -18,9 +18,9 @@ argon2-cffi-bindings==21.2.0 # via argon2-cffi arrow==1.2.2 # via jinja2-time -astroid==2.11.6 +astroid==2.12.2 # via sphinx-autoapi -attrs==21.4.0 +attrs==22.1.0 # via # jsonschema # visions @@ -42,7 +42,7 @@ binaryornot==0.4.4 # via cookiecutter bleach==5.0.1 # via nbconvert -botocore==1.27.22 +botocore==1.27.53 # via -r doc-requirements.in cachetools==5.2.0 # via google-auth @@ -86,7 +86,7 @@ dataclasses-json==0.5.7 # via # dolt-integrations # flytekit -debugpy==1.6.0 +debugpy==1.6.3 # via ipykernel decorator==5.1.1 # via @@ -118,13 +118,13 @@ entrypoints==0.4 # jupyter-client # nbconvert # papermill -fastjsonschema==2.15.3 +fastjsonschema==2.16.1 # via nbformat -flyteidl==1.1.8 +flyteidl==1.1.12 # via flytekit -fonttools==4.33.3 +fonttools==4.35.0 # via matplotlib -fsspec==2022.5.0 +fsspec==2022.7.1 # via # -r doc-requirements.in # modin @@ -135,29 +135,29 @@ google-api-core[grpc]==2.8.2 # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.9.0 +google-auth==2.10.0 # via # google-api-core # google-cloud-core # kubernetes google-cloud==0.34.0 # via -r doc-requirements.in -google-cloud-bigquery==3.2.0 +google-cloud-bigquery==3.3.1 # via -r doc-requirements.in -google-cloud-bigquery-storage==2.13.2 +google-cloud-bigquery-storage==2.14.2 # via google-cloud-bigquery -google-cloud-core==2.3.1 +google-cloud-core==2.3.2 # via google-cloud-bigquery google-crc32c==1.3.0 # via google-resumable-media google-resumable-media==2.3.3 # via google-cloud-bigquery -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # flyteidl # google-api-core # grpcio-status -great-expectations==0.15.12 +great-expectations==0.15.18 # via -r doc-requirements.in greenlet==1.1.2 # via sqlalchemy @@ -190,9 +190,9 @@ importlib-metadata==4.12.0 # markdown # sphinx # sqlalchemy -importlib-resources==5.8.0 +importlib-resources==5.9.0 # via jsonschema -ipykernel==6.15.0 +ipykernel==6.15.1 # via # ipywidgets # jupyter @@ -211,7 +211,9 @@ ipython-genutils==0.2.0 # notebook # qtconsole ipywidgets==7.7.1 - # via jupyter + # via + # great-expectations + # jupyter jedi==0.18.1 # via ipython jeepney==0.8.0 @@ -235,13 +237,14 @@ jmespath==1.0.1 # via botocore joblib==1.1.0 # via + # flytekit # pandas-profiling # phik jsonpatch==1.32 # via great-expectations jsonpointer==2.3 # via jsonpatch 
-jsonschema==4.6.1 +jsonschema==4.10.0 # via # altair # great-expectations @@ -257,7 +260,7 @@ jupyter-client==7.3.4 # qtconsole jupyter-console==6.4.4 # via jupyter -jupyter-core==4.10.0 +jupyter-core==4.11.1 # via # jupyter-client # nbconvert @@ -268,17 +271,21 @@ jupyterlab-pygments==0.2.2 # via nbconvert jupyterlab-widgets==1.1.1 # via ipywidgets -keyring==23.6.0 +keyring==23.8.2 # via flytekit -kiwisolver==1.4.3 +kiwisolver==1.4.4 # via matplotlib kubernetes==24.2.0 # via -r doc-requirements.in lazy-object-proxy==1.7.1 # via astroid lxml==4.9.1 - # via sphinx-material -markdown==3.3.7 + # via + # nbconvert + # sphinx-material +makefun==1.14.0 + # via great-expectations +markdown==3.4.1 # via -r doc-requirements.in markupsafe==2.1.1 # via @@ -294,13 +301,13 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.2 +matplotlib==3.5.3 # via # missingno # pandas-profiling # phik # seaborn -matplotlib-inline==0.1.3 +matplotlib-inline==0.1.5 # via # ipykernel # ipython @@ -324,7 +331,7 @@ nbclient==0.6.6 # via # nbconvert # papermill -nbconvert==6.5.0 +nbconvert==6.5.3 # via # jupyter # notebook @@ -398,7 +405,7 @@ pandera==0.9.0 # via -r doc-requirements.in pandocfilters==1.5.0 # via nbconvert -papermill==2.3.4 +papermill==2.4.0 # via -r doc-requirements.in parso==0.8.3 # via jedi @@ -413,7 +420,9 @@ pillow==9.2.0 # imagehash # matplotlib # visions -plotly==5.9.0 +pkgutil-resolve-name==1.3.10 + # via jsonschema +plotly==5.10.0 # via -r doc-requirements.in prometheus-client==0.14.1 # via notebook @@ -421,7 +430,7 @@ prompt-toolkit==3.0.30 # via # ipython # jupyter-console -proto-plus==1.20.6 +proto-plus==1.22.0 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -461,11 +470,11 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pydantic==1.9.1 +pydantic==1.9.2 # via # pandas-profiling # pandera -pygments==2.12.0 +pygments==2.13.0 # via # furo # ipython @@ -497,7 +506,7 @@ python-dateutil==2.8.2 # kubernetes # matplotlib # pandas -python-json-logger==2.0.2 +python-json-logger==2.0.4 # via flytekit python-slugify[unidecode]==6.1.2 # via @@ -505,7 +514,7 @@ python-slugify[unidecode]==6.1.2 # sphinx-material pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.2.1 # via # babel # flytekit @@ -523,7 +532,7 @@ pyyaml==6.0 # pandas-profiling # papermill # sphinx-autoapi -pyzmq==23.2.0 +pyzmq==23.2.1 # via # ipykernel # jupyter-client @@ -531,9 +540,9 @@ pyzmq==23.2.0 # qtconsole qtconsole==5.3.1 # via jupyter -qtpy==2.1.0 +qtpy==2.2.0 # via qtconsole -regex==2022.6.2 +regex==2022.7.25 # via docker-image-py requests==2.28.1 # via @@ -555,7 +564,7 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -rsa==4.8 +rsa==4.9 # via google-auth ruamel-yaml==0.17.17 # via great-expectations @@ -573,7 +582,7 @@ seaborn==0.11.2 # via # missingno # pandas-profiling -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring send2trash==1.8.0 # via notebook @@ -608,7 +617,7 @@ sphinx==4.5.0 # sphinx-panels # sphinx-prompt # sphinxcontrib-yt -sphinx-autoapi==1.8.4 +sphinx-autoapi==1.9.0 # via -r doc-requirements.in sphinx-basic-ng==0.0.1a12 # via furo @@ -618,7 +627,7 @@ sphinx-copybutton==0.5.0 # via -r doc-requirements.in sphinx-fontawesome==0.0.6 # via -r doc-requirements.in -sphinx-gallery==0.10.1 +sphinx-gallery==0.11.0 # via -r doc-requirements.in sphinx-material==0.0.35 # via -r doc-requirements.in @@ -640,7 +649,7 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-yt==0.2.2 # via -r 
doc-requirements.in -sqlalchemy==1.4.39 +sqlalchemy==1.4.40 # via -r doc-requirements.in statsd==3.3.0 # via flytekit @@ -662,9 +671,9 @@ textwrap3==0.9.2 # via ansiwrap tinycss2==1.1.1 # via nbconvert -toolz==0.11.2 +toolz==0.12.0 # via altair -torch==1.11.0 +torch==1.12.1 # via -r doc-requirements.in tornado==6.2 # via @@ -711,7 +720,7 @@ typing-inspect==0.7.1 # via # dataclasses-json # pandera -tzdata==2022.1 +tzdata==2022.2 # via pytz-deprecation-shim tzlocal==4.2 # via great-expectations @@ -719,7 +728,7 @@ unidecode==1.3.4 # via # python-slugify # sphinx-autoapi -urllib3==1.26.9 +urllib3==1.26.11 # via # botocore # flytekit @@ -749,7 +758,7 @@ wrapt==1.14.1 # deprecated # flytekit # pandera -zipp==3.8.0 +zipp==3.8.1 # via # importlib-metadata # importlib-resources diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 4afab73955..48b6f9c7da 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -1,8 +1,8 @@ -import base64 from typing import Optional -import cloudpickle from diskcache import Cache +from google.protobuf.struct_pb2 import Struct +from joblib.hashing import NumpyHasher from flytekit.models.literals import Literal, LiteralCollection, LiteralMap @@ -28,15 +28,26 @@ def _recursive_hash_placement(literal: Literal) -> Literal: return literal +class ProtoJoblibHasher(NumpyHasher): + def save(self, obj): + if isinstance(obj, Struct): + obj = dict( + rewrite_rule="google.protobuf.struct_pb2.Struct", + cls=obj.__class__, + obj=dict(sorted(obj.fields.items())), + ) + NumpyHasher.save(self, obj) + + def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str: # Traverse the literals and replace the literal with a new literal that only contains the hash literal_map_overridden = {} for key, literal in input_literal_map.literals.items(): literal_map_overridden[key] = _recursive_hash_placement(literal) - # Pickle the literal map and use base64 encoding to generate a representation of it - b64_encoded = base64.b64encode(cloudpickle.dumps(LiteralMap(literal_map_overridden))) - return f"{task_name}-{cache_version}-{b64_encoded}" + # Generate a hash key of inputs with joblib + hashed_inputs = ProtoJoblibHasher().hash(literal_map_overridden) + return f"{task_name}-{cache_version}-{hashed_inputs}" class LocalTaskCache(object): diff --git a/requirements-spark2.txt b/requirements-spark2.txt index 8398d30d1e..ded88f1bb2 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -52,9 +52,9 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.14.1 # via flytekit -flyteidl==1.1.8 +flyteidl==1.1.12 # via flytekit -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status @@ -82,9 +82,11 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +joblib==1.1.0 + # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.6.0 +keyring==23.8.2 # via flytekit markupsafe==2.1.1 # via jinja2 @@ -146,7 +148,7 @@ python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.2.1 # via # flytekit # pandas @@ -155,7 +157,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.7.9 +regex==2022.7.25 # via docker-image-py requests==2.28.1 # via @@ -167,7 +169,7 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -192,7 +194,7 @@ typing-extensions==4.3.0 # typing-inspect 
typing-inspect==0.7.1 # via dataclasses-json -urllib3==1.26.10 +urllib3==1.26.11 # via # flytekit # requests @@ -207,7 +209,7 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.8.1 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements.txt b/requirements.txt index 153da8b4d2..17a6487f7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,9 +50,9 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.14.1 # via flytekit -flyteidl==1.1.8 +flyteidl==1.1.12 # via flytekit -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status @@ -80,9 +80,11 @@ jinja2==3.1.2 # jinja2-time jinja2-time==0.2.0 # via cookiecutter +joblib==1.1.0 + # via flytekit jsonschema==3.2.0 # via -r requirements.in -keyring==23.6.0 +keyring==23.8.2 # via flytekit markupsafe==2.1.1 # via jinja2 @@ -144,7 +146,7 @@ python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.2.1 # via # flytekit # pandas @@ -153,7 +155,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.7.9 +regex==2022.7.25 # via docker-image-py requests==2.28.1 # via @@ -165,7 +167,7 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -190,7 +192,7 @@ typing-extensions==4.3.0 # typing-inspect typing-inspect==0.7.1 # via dataclasses-json -urllib3==1.26.10 +urllib3==1.26.11 # via # flytekit # requests @@ -205,7 +207,7 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.8.1 # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: diff --git a/setup.py b/setup.py index 85af8691b6..56e2dd5624 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,7 @@ "grpcio-status>=1.43,!=1.45.0", "importlib-metadata", "pyopenssl", + "joblib", "protobuf>=3.6.1,<4", "python-json-logger>=2.0.0", "pytimeparse>=1.1.8,<2.0.0", diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index c93c56435c..0fc659135f 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -46,13 +46,13 @@ docker-image-py==0.1.12 # via flytekit docstring-parser==0.14.1 # via flytekit -flyteidl==1.1.8 +flyteidl==1.1.12 # via flytekit -flytekit==1.1.0 +flytekit==1.1.1 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -fonttools==4.33.3 +fonttools==4.35.0 # via matplotlib -googleapis-common-protos==1.56.3 +googleapis-common-protos==1.56.4 # via # flyteidl # grpcio-status @@ -81,9 +81,9 @@ jinja2-time==0.2.0 # via cookiecutter joblib==1.1.0 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -keyring==23.6.0 +keyring==23.8.2 # via flytekit -kiwisolver==1.4.3 +kiwisolver==1.4.4 # via matplotlib markupsafe==2.1.1 # via jinja2 @@ -96,7 +96,7 @@ marshmallow-enum==1.5.1 # via dataclasses-json marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.2 +matplotlib==3.5.3 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in mypy-extensions==0.4.3 # via typing-inspect @@ -142,17 +142,18 @@ pyparsing==3.0.9 # packaging python-dateutil==2.8.2 # via + # arrow # croniter # flytekit # matplotlib # pandas 
-python-json-logger==2.0.2 +python-json-logger==2.0.4 # via flytekit python-slugify==6.1.2 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.1 +pytz==2022.2.1 # via # flytekit # pandas @@ -160,7 +161,7 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.6.2 +regex==2022.7.25 # via docker-image-py requests==2.28.1 # via @@ -172,7 +173,7 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.2 +secretstorage==3.3.3 # via keyring singledispatchmethod==1.0 # via flytekit @@ -196,7 +197,7 @@ typing-extensions==4.3.0 # typing-inspect typing-inspect==0.7.1 # via dataclasses-json -urllib3==1.26.9 +urllib3==1.26.11 # via # flytekit # requests @@ -211,5 +212,5 @@ wrapt==1.14.1 # via # deprecated # flytekit -zipp==3.8.0 +zipp==3.8.1 # via importlib-metadata diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index fe09fac830..674f6176e1 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -1,7 +1,7 @@ import datetime import typing from dataclasses import dataclass -from typing import List +from typing import Dict, List import pandas from dataclasses_json import dataclass_json @@ -10,12 +10,16 @@ from flytekit.core.base_sql_task import SQLTask from flytekit.core.base_task import kwtypes +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.hash import HashMethod -from flytekit.core.local_cache import LocalTaskCache +from flytekit.core.local_cache import LocalTaskCache, _calculate_cache_key from flytekit.core.task import TaskMetadata, task from flytekit.core.testing import task_mock +from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.models.literals import LiteralMap +from flytekit.models.types import LiteralType, SimpleType from flytekit.types.schema import FlyteSchema # Global counter used to validate number of calls to cache @@ -385,3 +389,52 @@ def my_workflow(): # Confirm that we see a cache hit in the case of annotated dataframes. my_workflow() assert n_cached_task_calls == 1 + + +def test_cache_key_repetition(): + pt = Dict + lt = TypeEngine.to_literal_type(pt) + ctx = FlyteContextManager.current_context() + kwargs = { + "a": 0.41083513079747874, + "b": 0.7773927872515183, + "c": 17, + } + keys = set() + for i in range(0, 100): + lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt) + lm = LiteralMap( + literals={ + "d": lit, + } + ) + key = _calculate_cache_key("t1", "007", lm) + keys.add(key) + + assert len(keys) == 1 + + +def test_stable_cache_key(): + """ + The intent of this test is to ensure cache keys are stable across releases and python versions. 
+ """ + pt = Dict + lt = TypeEngine.to_literal_type(pt) + ctx = FlyteContextManager.current_context() + kwargs = { + "a": 42, + "b": "abcd", + "c": 0.12349, + "d": [1, 2, 3], + } + lit = TypeEngine.to_literal(ctx, kwargs, Dict, lt) + lm = LiteralMap( + literals={ + "lit_1": lit, + "lit_2": TypeEngine.to_literal(ctx, 99, int, LiteralType(simple=SimpleType.INTEGER)), + "lit_3": TypeEngine.to_literal(ctx, 3.14, float, LiteralType(simple=SimpleType.FLOAT)), + "lit_4": TypeEngine.to_literal(ctx, True, bool, LiteralType(simple=SimpleType.BOOLEAN)), + } + ) + key = _calculate_cache_key("task_name_1", "31415", lm) + assert key == "task_name_1-31415-a291dc6fe0be387c1cfd67b4c6b78259" From ad5b19bd75df9761676bcc803890d6ec3f57a9f8 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 22 Aug 2022 10:04:20 -0700 Subject: [PATCH 08/27] Allow None protocol to mean all data persistence supported storage options in Structured Dataset (#1134) Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 4 + flytekit/types/structured/basic_dfs.py | 33 ++---- flytekit/types/structured/bigquery.py | 11 +- .../types/structured/structured_dataset.py | 111 ++++++++++++------ .../flytekitplugins/fsspec/__init__.py | 6 +- .../flytekitplugins/fsspec/pandas.py | 3 +- .../flytekit-papermill/dev-requirements.in | 2 +- .../flytekit-papermill/dev-requirements.txt | 2 +- .../flytekitplugins/polars/sd_transformers.py | 21 +--- .../flytekitplugins/spark/sd_transformers.py | 17 +-- .../unit/core/test_data_persistence.py | 9 +- .../unit/core/test_structured_dataset.py | 51 +++++--- .../core/test_structured_dataset_handlers.py | 15 ++- .../test_structured_dataset_workflow.py | 9 +- 14 files changed, 176 insertions(+), 118 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index e69e3f6476..a2ad5311f1 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -171,6 +171,10 @@ def is_supported_protocol(cls, protocol: str) -> bool: """ return protocol in cls._PLUGINS + @classmethod + def supported_protocols(cls) -> typing.List[str]: + return [k for k in cls._PLUGINS.keys()] + class DiskPersistence(DataPersistence): """ diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 82680d2787..97964d0b63 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -7,16 +7,11 @@ import pyarrow.parquet as pq from flytekit import FlyteContext -from flytekit.core.data_persistence import DataPersistencePlugins from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - ABFS, - GCS, - LOCAL, PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -27,10 +22,8 @@ class PandasToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) - # todo: Use this somehow instead of relaying ont he ctx file_access - self._persistence = DataPersistencePlugins.find_plugin(protocol)() + def __init__(self): + super().__init__(pd.DataFrame, None, PARQUET) def encode( self, @@ -50,8 +43,8 @@ def encode( class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(pd.DataFrame, None, PARQUET) def 
decode( self, @@ -69,8 +62,8 @@ def decode( class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) + def __init__(self): + super().__init__(pa.Table, None, PARQUET) def encode( self, @@ -88,8 +81,8 @@ def encode( class ParquetToArrowDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) + def __init__(self): + super().__init__(pa.Table, None, PARQUET) def decode( self, @@ -106,9 +99,7 @@ def decode( return pq.read_table(local_dir) -# Don't override default protocol -for protocol in [LOCAL, S3, GCS, ABFS]: - StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False) - StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False) - StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False) - StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=False) +StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler()) +StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler()) diff --git a/flytekit/types/structured/bigquery.py b/flytekit/types/structured/bigquery.py index 92d203e25d..85cede1544 100644 --- a/flytekit/types/structured/bigquery.py +++ b/flytekit/types/structured/bigquery.py @@ -10,7 +10,6 @@ from flytekit.models import literals from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - BIGQUERY, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -18,6 +17,8 @@ StructuredDatasetTransformerEngine, ) +BIGQUERY = "bq" + def _write_to_bq(structured_dataset: StructuredDataset): table_id = typing.cast(str, structured_dataset.uri).split("://", 1)[1].replace(":", ".") @@ -111,7 +112,7 @@ def decode( return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata)) -StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers(), default_for_type=False) -StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler(), default_for_type=False) -StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers(), default_for_type=False) -StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler(), default_for_type=False) +StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers()) +StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler()) +StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers()) +StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index cdb26a87c2..bfbc494bbb 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -3,7 +3,6 @@ import collections import importlib import os -import re import types import typing from abc import ABC, abstractmethod @@ -12,15 +11,16 @@ import _datetime import numpy as _np -import pandas import pandas as pd -import pyarrow import pyarrow as pa +from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence + if 
importlib.util.find_spec("pyspark") is not None: import pyspark if importlib.util.find_spec("polars") is not None: import polars as pl + from dataclasses_json import config, dataclass_json from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin @@ -36,13 +36,6 @@ T = typing.TypeVar("T") # StructuredDataset type or a dataframe type DF = typing.TypeVar("DF") # Dataframe type -# Protocols -BIGQUERY = "bq" -S3 = "s3" -ABFS = "abfs" -GCS = "gs" -LOCAL = "/" - # For specifying the storage formats of StructuredDatasets. It's just a string, nothing fancy. StructuredDatasetFormat: TypeAlias = str @@ -156,7 +149,7 @@ def extract_cols_and_format( if ordered_dict_cols is not None: raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}") ordered_dict_cols = aa - elif isinstance(aa, pyarrow.Schema): + elif isinstance(aa, pa.Schema): if pa_schema is not None: raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}") pa_schema = aa @@ -168,7 +161,7 @@ def extract_cols_and_format( class StructuredDatasetEncoder(ABC): - def __init__(self, python_type: Type[T], protocol: str, supported_format: Optional[str] = None): + def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None): """ Extend this abstract class, implement the encode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle @@ -179,12 +172,14 @@ def __init__(self, python_type: Type[T], protocol: str, supported_format: Option :param python_type: The dataframe class in question that you want to register this encoder with :param protocol: A prefix representing the storage driver (e.g. 's3, 'gs', 'bq', etc.). You can use either "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor. + If None, this encoder will be registered with all protocols that flytekit's data persistence layer + is capable of handling. :param supported_format: Arbitrary string representing the format. If not supplied then an empty string will be used. An empty string implies that the encoder works with any format. If the format being asked for does not exist, the transformer enginer will look for the "" endcoder instead and write a warning. """ self._python_type = python_type - self._protocol = protocol.replace("://", "") + self._protocol = protocol.replace("://", "") if protocol else None self._supported_format = supported_format or "" @property @@ -192,7 +187,7 @@ def python_type(self) -> Type[T]: return self._python_type @property - def protocol(self) -> str: + def protocol(self) -> Optional[str]: return self._protocol @property @@ -228,7 +223,7 @@ def encode( class StructuredDatasetDecoder(ABC): - def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optional[str] = None): + def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None): """ Extend this abstract class, implement the decode function, and register your concrete class with the StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle @@ -238,12 +233,14 @@ def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optio :param python_type: The dataframe class in question that you want to register this decoder with :param protocol: A prefix representing the storage driver (e.g. 's3, 'gs', 'bq', etc.). 
You can use either "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor. + If None, this decoder will be registered with all protocols that flytekit's data persistence layer + is capable of handling. :param supported_format: Arbitrary string representing the format. If not supplied then an empty string will be used. An empty string implies that the decoder works with any format. If the format being asked for does not exist, the transformer enginer will look for the "" decoder instead and write a warning. """ self._python_type = python_type - self._protocol = protocol.replace("://", "") + self._protocol = protocol.replace("://", "") if protocol else None self._supported_format = supported_format or "" @property @@ -251,7 +248,7 @@ def python_type(self) -> Type[DF]: return self._python_type @property - def protocol(self) -> str: + def protocol(self) -> Optional[str]: return self._protocol @property @@ -281,10 +278,8 @@ def decode( def protocol_prefix(uri: str) -> str: - g = re.search(r"([\w]+)://.*", uri) - if g and g.groups(): - return g.groups()[0] - return LOCAL + p = DataPersistencePlugins.get_protocol(uri) + return p def convert_schema_type_to_structured_dataset_type( @@ -306,6 +301,10 @@ def convert_schema_type_to_structured_dataset_type( raise AssertionError(f"Unrecognized SchemaColumnType: {column_type}") +class DuplicateHandlerError(ValueError): + ... + + class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): """ Think of this transformer as a higher-level meta transformer that is used for all the dataframe types. @@ -366,8 +365,7 @@ def get_decoder(cls, df_type: Type, protocol: str, format: str): return cls._finder(StructuredDatasetTransformerEngine.DECODERS, df_type, protocol, format) @classmethod - def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]: - # Maybe think about default dict in the future, but is typing as nice? + def _handler_finder(cls, h: Handlers, protocol: str) -> Dict[str, Handlers]: if isinstance(h, StructuredDatasetEncoder): top_level = cls.ENCODERS elif isinstance(h, StructuredDatasetDecoder): @@ -376,9 +374,9 @@ def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]: raise TypeError(f"We don't support this type of handler {h}") if h.python_type not in top_level: top_level[h.python_type] = {} - if h.protocol not in top_level[h.python_type]: - top_level[h.python_type][h.protocol] = {} - return top_level[h.python_type][h.protocol] + if protocol not in top_level[h.python_type]: + top_level[h.python_type][protocol] = {} + return top_level[h.python_type][protocol] def __init__(self): super().__init__("StructuredDataset Transformer", StructuredDataset) @@ -388,22 +386,65 @@ def __init__(self): self._hash_overridable = True @classmethod - def register(cls, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False): + def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False): + """ + Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not + specify a protocol (e.g. s3, gs, etc.) field, then + + :param h: The StructuredDatasetEncoder or StructuredDatasetDecoder you wish to register with this transformer. + :param default_for_type: If set, when a user returns from a task an instance of the dataframe the handler + handles, e.g. 
``return pd.DataFrame(...)``, not wrapped around the ``StructuredDataset`` object, we will + use this handler's protocol and format as the default, effectively saying that this handler will be called. + Note that this shouldn't be set if your handler's protocol is None, because that implies that your handler + is capable of handling all the different storage protocols that flytekit's data persistence layer is aware of. + In these cases, the protocol is determined by the raw output data prefix set in the active context. + :param override: Override any previous registrations. If default_for_type is also set, this will also override + the default. """ - Call this with any handler to register it with this dataframe meta-transformer + if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)): + raise TypeError(f"We don't support this type of handler {h}") - The string "://" should not be present in any handler's protocol so we don't check for it. + if h.protocol is None: + if default_for_type: + raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.") + for persistence_protocol in DataPersistencePlugins.supported_protocols(): + # TODO: Clean this up when we get to replacing the persistence layer. + # The behavior of the protocols given in the supported_protocols and is_supported_protocol + # is not actually the same as the one returned in get_protocol. + stripped = DataPersistencePlugins.get_protocol(persistence_protocol) + logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}") + try: + cls.register_for_protocol(h, stripped, False, override) + except DuplicateHandlerError: + logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate") + + elif h.protocol == "": + raise ValueError(f"Use None instead of empty string for registering handler {h}") + else: + cls.register_for_protocol(h, h.protocol, default_for_type, override) + + @classmethod + def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool): + """ + See the main register function instead. """ - lowest_level = cls._handler_finder(h) + if protocol == "/": + # TODO: Special fix again, because get_protocol returns file, instead of file:// + protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL) + lowest_level = cls._handler_finder(h, protocol) if h.supported_format in lowest_level and override is False: - raise ValueError(f"Already registered a handler for {(h.python_type, h.protocol, h.supported_format)}") + raise DuplicateHandlerError( + f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}" + ) lowest_level[h.supported_format] = h - logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {h.protocol}, fmt {h.supported_format}") + logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}") if default_for_type: - # TODO: Add logging, think about better ux, maybe default False and warn if doesn't exist. 
+ logger.debug( + f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}" + ) cls.DEFAULT_FORMATS[h.python_type] = h.supported_format - cls.DEFAULT_PROTOCOLS[h.python_type] = h.protocol + cls.DEFAULT_PROTOCOLS[h.python_type] = protocol # Register with the type engine as well # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as @@ -657,7 +698,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ else: df = python_val - if isinstance(df, pandas.DataFrame): + if isinstance(df, pd.DataFrame): return df.describe().to_html() elif isinstance(df, pa.Table): return df.to_string() diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py index e8d88f26c3..68ee456ed6 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py @@ -25,13 +25,15 @@ import importlib from flytekit import StructuredDatasetTransformerEngine, logger -from flytekit.configuration import internal -from flytekit.types.structured.structured_dataset import ABFS, GCS, S3 from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler from .persist import FSSpecPersistence +S3 = "s3" +ABFS = "abfs" +GCS = "gs" + def _register(protocol: str): logger.info(f"Registering fsspec {protocol} implementations and overriding default structured encoder/decoder.") diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py index 65a440b785..e4986ed9f6 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py @@ -13,7 +13,6 @@ from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -22,7 +21,7 @@ def get_storage_options(cfg: DataConfig, uri: str) -> typing.Optional[typing.Dict]: protocol = FSSpecPersistence.get_protocol(uri) - if protocol == S3: + if protocol == "s3": kwargs = s3_setup_args(cfg.s3) if kwargs: return kwargs diff --git a/plugins/flytekit-papermill/dev-requirements.in b/plugins/flytekit-papermill/dev-requirements.in index 98b7896e22..a57b6365fe 100644 --- a/plugins/flytekit-papermill/dev-requirements.in +++ b/plugins/flytekit-papermill/dev-requirements.in @@ -1,3 +1,3 @@ flyteidl>=1.0.0 -git+https://github.com/flyteorg/flytekit@master#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark +git+https://github.com/flyteorg/flytekit@sd-data-persistence#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark # vcs+protocol://repo_url/#egg=pkg&subdirectory=flyte diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 4b5cde2509..ba2d48ab1c 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -44,7 +44,7 @@ flyteidl==1.0.0.post1 # flytekit flytekit==1.1.0b0 # via flytekitplugins-spark -flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@master#subdirectory=plugins/flytekit-spark +flytekitplugins-spark @ 
git+https://github.com/flyteorg/flytekit@sd-data-persistence#subdirectory=plugins/flytekit-spark # via -r dev-requirements.in googleapis-common-protos==1.55.0 # via diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 06b1127504..6388bc4c9e 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -7,11 +7,7 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - ABFS, - GCS, - LOCAL, PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -20,8 +16,8 @@ class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pl.DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(pl.DataFrame, None, PARQUET) def encode( self, @@ -45,8 +41,8 @@ def encode( class ParquetToPolarsDataFrameDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pl.DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(pl.DataFrame, None, PARQUET) def decode( self, @@ -63,10 +59,5 @@ def decode( return pl.read_parquet(path) -for protocol in [LOCAL, S3, GCS, ABFS]: - StructuredDatasetTransformerEngine.register( - PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=False - ) - StructuredDatasetTransformerEngine.register( - ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=False - ) +StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 9fef590bcc..1a89b7b331 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -7,11 +7,7 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - ABFS, - GCS, - LOCAL, PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -20,8 +16,8 @@ class SparkToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(DataFrame, None, PARQUET) def encode( self, @@ -36,8 +32,8 @@ def encode( class ParquetToSparkDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(DataFrame, protocol, PARQUET) + def __init__(self): + super().__init__(DataFrame, None, PARQUET) def decode( self, @@ -52,6 +48,5 @@ def decode( return user_ctx.spark_session.read.parquet(flyte_value.uri) -for protocol in [LOCAL, S3, GCS, ABFS]: - StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=False) - StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=False) +StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler()) diff --git 
a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index e61350a7ed..af39e9e852 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,4 +1,4 @@ -from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.data_persistence import DataPersistencePlugins, FileAccessProvider def test_get_random_remote_path(): @@ -14,3 +14,10 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True + + +def test_lister(): + x = DataPersistencePlugins.supported_protocols() + main_protocols = {"file", "/", "gs", "http", "https", "s3"} + all_protocols = set([y.replace("://", "") for y in x]) + assert main_protocols.issubset(all_protocols) diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index 773d30b5ae..d78d129309 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -1,9 +1,13 @@ import tempfile import typing +import pandas as pd +import pyarrow as pa import pytest +from typing_extensions import Annotated import flytekit.configuration +from flytekit import kwtypes, task from flytekit.configuration import Image, ImageConfig from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider @@ -11,18 +15,7 @@ from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - -import pandas as pd -import pyarrow as pa - -from flytekit import kwtypes, task from flytekit.types.structured.structured_dataset import ( - LOCAL, PARQUET, StructuredDataset, StructuredDatasetDecoder, @@ -49,7 +42,7 @@ def test_protocol(): assert protocol_prefix("s3://my-s3-bucket/file") == "s3" - assert protocol_prefix("/file") == "/" + assert protocol_prefix("/file") == "file" def generate_pandas() -> pd.DataFrame: @@ -121,10 +114,10 @@ def test_types_sd(): def test_retrieving(): - assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "/", PARQUET) is not None + assert StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", PARQUET) is not None with pytest.raises(ValueError): # We don't have a default "" format encoder - StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "/", "") + StructuredDatasetTransformerEngine.get_encoder(pd.DataFrame, "file", "") class TempEncoder(StructuredDatasetEncoder): def __init__(self, protocol): @@ -137,6 +130,11 @@ def encode(self): with pytest.raises(ValueError): StructuredDatasetTransformerEngine.register(TempEncoder("gs://"), default_for_type=False) + with pytest.raises(ValueError, match="Use None instead"): + e = TempEncoder("") + e._protocol = "" + StructuredDatasetTransformerEngine.register(e) + class TempEncoder: pass @@ -209,6 +207,24 @@ def encode( assert res is empty_format_temp_encoder +def test_slash_register(): + class TempEncoder(StructuredDatasetEncoder): + def __init__(self, fmt: str): + super().__init__(MyDF, None, supported_format=fmt) + + def encode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> 
literals.StructuredDataset: + return literals.StructuredDataset(uri="") + + # Check that registering with a / triggers the file protocol instead. + StructuredDatasetTransformerEngine.register(TempEncoder("/")) + assert StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("file") is not None + + def test_sd(): sd = StructuredDataset(dataframe="hi") sd.uri = "my uri" @@ -273,6 +289,9 @@ def test_convert_schema_type_to_structured_dataset_type(): with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"): convert_schema_type_to_structured_dataset_type(int) + with pytest.raises(AssertionError, match="Unrecognized SchemaColumnType"): + convert_schema_type_to_structured_dataset_type(20) + def test_to_python_value_with_incoming_columns(): # make a literal with a type that has two columns @@ -338,7 +357,7 @@ def test_to_python_value_without_incoming_columns(): def test_format_correct(): class TempEncoder(StructuredDatasetEncoder): def __init__(self): - super().__init__(pd.DataFrame, LOCAL, "avro") + super().__init__(pd.DataFrame, "/", "avro") def encode( self, @@ -385,7 +404,7 @@ def test_protocol_detection(): e = StructuredDatasetTransformerEngine() ctx = FlyteContextManager.current_context() protocol = e._protocol_from_type_or_prefix(ctx, pd.DataFrame) - assert protocol == "/" + assert protocol == "file" with tempfile.TemporaryDirectory() as tmp_dir: fs = FileAccessProvider(local_sandbox_dir=tmp_dir, raw_output_prefix="s3://fdsa") diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index ada7483a0f..c7aa5563f9 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -13,6 +13,7 @@ StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, + StructuredDatasetTransformerEngine, ) my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str) @@ -23,8 +24,8 @@ def test_pandas(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - encoder = basic_dfs.PandasToParquetEncodingHandler("/") - decoder = basic_dfs.ParquetToPandasDecodingHandler("/") + encoder = basic_dfs.PandasToParquetEncodingHandler() + decoder = basic_dfs.ParquetToPandasDecodingHandler() ctx = context_manager.FlyteContextManager.current_context() sd = StructuredDataset(dataframe=df) @@ -41,3 +42,13 @@ def test_base_isnt_instantiable(): with pytest.raises(TypeError): StructuredDatasetDecoder(pd.DataFrame, "", "") + + +def test_arrow(): + encoder = basic_dfs.ArrowToParquetEncodingHandler() + decoder = basic_dfs.ParquetToArrowDecodingHandler() + assert encoder.protocol is None + assert decoder.protocol is None + assert encoder.python_type is decoder.python_type + d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["s3"]["parquet"] + assert d is not None diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index f0d58eb36d..c2394d7a7a 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -18,11 +18,8 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - BIGQUERY, DF, - LOCAL, PARQUET, - S3, 
StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -41,7 +38,7 @@ class MockBQEncodingHandlers(StructuredDatasetEncoder): def __init__(self): - super().__init__(pd.DataFrame, BIGQUERY, "") + super().__init__(pd.DataFrame, "bq", "") def encode( self, @@ -56,7 +53,7 @@ def encode( class MockBQDecodingHandlers(StructuredDatasetDecoder): def __init__(self): - super().__init__(pd.DataFrame, BIGQUERY, "") + super().__init__(pd.DataFrame, "bq", "") def decode( self, @@ -104,7 +101,7 @@ def decode( table = pq.read_table(local_dir) return table.to_pandas().to_numpy() - for protocol in [LOCAL, S3]: + for protocol in ["/", "s3"]: StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET)) StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET)) From 6dffed9135f944a63844b1553e85fd46fe88089f Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Thu, 25 Aug 2022 01:22:05 -0400 Subject: [PATCH 09/27] handle ImportError and OSError in extras.pytorch (#1141) * handle ImportError and OSError in extras.pytorch Signed-off-by: Niels Bantilan * isolate exception to torch import Signed-off-by: Niels Bantilan Signed-off-by: Niels Bantilan --- flytekit/extras/pytorch/__init__.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/flytekit/extras/pytorch/__init__.py b/flytekit/extras/pytorch/__init__.py index ae077d9755..770fe11b73 100644 --- a/flytekit/extras/pytorch/__init__.py +++ b/flytekit/extras/pytorch/__init__.py @@ -11,10 +11,21 @@ """ from flytekit.loggers import logger +# TODO: abstract this out so that there's an established pattern for registering plugins +# that have soft dependencies try: + # isolate the exception to the torch import + import torch + + _torch_installed = True +except (ImportError, OSError): + _torch_installed = False + + +if _torch_installed: from .checkpoint import PyTorchCheckpoint, PyTorchCheckpointTransformer from .native import PyTorchModuleTransformer, PyTorchTensorTransformer -except ImportError: +else: logger.info( "We won't register PyTorchCheckpointTransformer, PyTorchTensorTransformer, and PyTorchModuleTransformer because torch is not installed." 
) From 4bfd4527aa167997b4dfcb1b3b6f79e5bd949128 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 26 Aug 2022 14:16:31 +0800 Subject: [PATCH 10/27] Register dataframe renderers in structured dataset (#1140) * Register dataframe renderers in structured dataset Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * fix test Signed-off-by: Kevin Su * more tests Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/deck/renderer.py | 11 +++++++ flytekit/types/structured/basic_dfs.py | 5 +++ .../types/structured/structured_dataset.py | 31 ++++++------------- .../flytekitplugins/polars/sd_transformers.py | 13 ++++++++ .../tests/test_polars_plugin_sd.py | 16 ++++++---- .../flytekitplugins/spark/sd_transformers.py | 12 +++++++ .../unit/core/test_structured_dataset.py | 15 +++++++++ tests/flytekit/unit/deck/test_renderer.py | 11 ++++--- .../test_structured_dataset_workflow.py | 15 +++++++-- 9 files changed, 95 insertions(+), 34 deletions(-) diff --git a/flytekit/deck/renderer.py b/flytekit/deck/renderer.py index 8617ae4d12..0cf781d3da 100644 --- a/flytekit/deck/renderer.py +++ b/flytekit/deck/renderer.py @@ -1,6 +1,7 @@ from typing import Any, Optional import pandas +import pyarrow from typing_extensions import Protocol, runtime_checkable @@ -24,3 +25,13 @@ def __init__(self, max_rows: Optional[int] = None): def to_html(self, df: pandas.DataFrame) -> str: assert isinstance(df, pandas.DataFrame) return df.to_html(max_rows=self._max_rows) + + +class ArrowRenderer: + """ + Render a Arrow dataframe as an HTML table. + """ + + def to_html(self, df: pyarrow.Table) -> str: + assert isinstance(df, pyarrow.Table) + return df.to_string() diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 97964d0b63..71dff61c5e 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -7,6 +7,8 @@ import pyarrow.parquet as pq from flytekit import FlyteContext +from flytekit.deck import TopFrameRenderer +from flytekit.deck.renderer import ArrowRenderer from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -103,3 +105,6 @@ def decode( StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler()) StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler()) StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler()) + +StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer()) +StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer()) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index bfbc494bbb..bdad752b16 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import importlib import os import types import typing @@ -13,20 +12,14 @@ import numpy as _np import pandas as pd import pyarrow as pa - -from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence - -if importlib.util.find_spec("pyspark") is not None: - import pyspark -if importlib.util.find_spec("polars") is not None: - import polars as pl - from dataclasses_json import config, dataclass_json from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin from 
flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence from flytekit.core.type_engine import TypeEngine, TypeTransformer +from flytekit.deck.renderer import Renderable from flytekit.loggers import logger from flytekit.models import literals from flytekit.models import types as type_models @@ -339,6 +332,7 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): DEFAULT_FORMATS: Dict[Type, str] = {} Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder] + Renderers: Dict[Type, Renderable] = {} @staticmethod def _finder(handler_map, df_type: Type, protocol: str, format: str): @@ -385,6 +379,10 @@ def __init__(self): # Instances of StructuredDataset opt-in to the ability of being cached. self._hash_overridable = True + @classmethod + def register_renderer(cls, python_type: Type, renderer: Renderable): + cls.Renderers[python_type] = renderer + @classmethod def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False): """ @@ -698,19 +696,10 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ else: df = python_val - if isinstance(df, pd.DataFrame): - return df.describe().to_html() - elif isinstance(df, pa.Table): - return df.to_string() - elif isinstance(df, _np.ndarray): - return pd.DataFrame(df).describe().to_html() - elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame): - return pd.DataFrame(df.schema, columns=["StructField"]).to_html() - elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame): - describe_df = df.describe() - return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + if type(df) in self.Renderers: + return self.Renderers[type(df)].to_html(df) else: - raise NotImplementedError("Conversion to html string should be implemented") + raise NotImplementedError(f"Could not find a renderer for {type(df)} in {self.Renderers}") def open_as( self, diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 6388bc4c9e..0dfd0c6516 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -1,5 +1,6 @@ import typing +import pandas as pd import polars as pl from flytekit import FlyteContext @@ -15,6 +16,17 @@ ) +class PolarsDataFrameRenderer: + """ + The Polars DataFrame summary statistics are rendered as an HTML table. 
+ """ + + def to_html(self, df: pl.DataFrame) -> str: + assert isinstance(df, pl.DataFrame) + describe_df = df.describe() + return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False) + + class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(pl.DataFrame, None, PARQUET) @@ -61,3 +73,4 @@ def decode( StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler()) StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(pl.DataFrame, PolarsDataFrameRenderer()) diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index 3c9c2613ae..b991cd5d13 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -1,10 +1,7 @@ -import flytekitplugins.polars # noqa F401 +import pandas as pd import polars as pl - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated +from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer +from typing_extensions import Annotated from flytekit import kwtypes, task, workflow from flytekit.types.structured.structured_dataset import PARQUET, StructuredDataset @@ -62,3 +59,10 @@ def wf() -> full_schema: result = wf() assert result is not None + + +def test_polars_renderer(): + df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")}) + assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame( + df.describe().transpose(), columns=df.describe().columns + ).to_html(index=False) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 1a89b7b331..46079f40dd 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -1,5 +1,6 @@ import typing +import pandas as pd from pyspark.sql.dataframe import DataFrame from flytekit import FlyteContext @@ -15,6 +16,16 @@ ) +class SparkDataFrameRenderer: + """ + Render a Spark dataframe schema as an HTML table. 
+    """

     def to_html(self, df: DataFrame) -> str:
         assert isinstance(df, DataFrame)
         return pd.DataFrame(df.schema, columns=["StructField"]).to_html()


 class SparkToParquetEncodingHandler(StructuredDatasetEncoder):
     def __init__(self):
         super().__init__(DataFrame, None, PARQUET)
@@ -50,3 +61,4 @@ def decode(

 StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
 StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
+StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py
index d78d129309..7793df430f 100644
--- a/tests/flytekit/unit/core/test_structured_dataset.py
+++ b/tests/flytekit/unit/core/test_structured_dataset.py
@@ -414,3 +414,18 @@ def test_protocol_detection():

     protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo")
     assert protocol == "bq"
+
+
+def test_register_renderers():
+    class DummyRenderer:
+        def to_html(self, input: str) -> str:
+            return "hello " + input
+
+    renderers = StructuredDatasetTransformerEngine.Renderers
+    StructuredDatasetTransformerEngine.register_renderer(str, DummyRenderer())
+    assert renderers[str].to_html("flyte") == "hello flyte"
+    assert pd.DataFrame in renderers
+    assert pa.Table in renderers
+
+    with pytest.raises(NotImplementedError, match="Could not find a renderer for <class 'int'> in"):
+        StructuredDatasetTransformerEngine().to_html(FlyteContextManager.current_context(), 3, int)
diff --git a/tests/flytekit/unit/deck/test_renderer.py b/tests/flytekit/unit/deck/test_renderer.py
index f1ebbcd873..3f597af416 100644
--- a/tests/flytekit/unit/deck/test_renderer.py
+++ b/tests/flytekit/unit/deck/test_renderer.py
@@ -1,9 +1,12 @@
 import pandas as pd
+import pyarrow as pa

-from flytekit.deck.renderer import TopFrameRenderer
+from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer


-def test_frame_profiling_renderer():
+def test_renderer():
     df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})
-    renderer = TopFrameRenderer()
-    assert renderer.to_html(df) == df.to_html()
+    pa_df = pa.Table.from_pandas(df)
+
+    assert TopFrameRenderer().to_html(df) == df.to_html()
+    assert ArrowRenderer().to_html(pa_df) == pa_df.to_string()
diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py
index c2394d7a7a..c849995b96 100644
--- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py
+++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py
@@ -68,6 +68,15 @@ def decode(
     StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True)


+class NumpyRenderer:
+    """
+    The Numpy array summary statistics are rendered as an HTML table.
+ """ + + def to_html(self, array: np.ndarray) -> str: + return pd.DataFrame(array).describe().to_html() + + @pytest.fixture(autouse=True) def numpy_type(): class NumpyEncodingHandlers(StructuredDatasetEncoder): @@ -101,9 +110,9 @@ def decode( table = pq.read_table(local_dir) return table.to_pandas().to_numpy() - for protocol in ["/", "s3"]: - StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET)) - StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET)) + StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray)) + StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray)) + StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer()) @task From d65ce81850801e21ce9a0f4f70a2adf50f7f59b4 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 29 Aug 2022 23:50:43 +0800 Subject: [PATCH 11/27] pyflyte run imperative workflows (#1131) Signed-off-by: Kevin Su --- flytekit/clis/sdk_in_container/constants.py | 2 + flytekit/clis/sdk_in_container/run.py | 29 +++++++++----- flytekit/core/tracker.py | 8 +++- flytekit/remote/remote.py | 7 +++- flytekit/tools/script_mode.py | 12 +++--- .../unit/cli/pyflyte/imperative_wf.py | 39 +++++++++++++++++++ tests/flytekit/unit/cli/pyflyte/test_run.py | 11 ++++++ 7 files changed, 90 insertions(+), 18 deletions(-) create mode 100644 tests/flytekit/unit/cli/pyflyte/imperative_wf.py diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index d0d7f7a229..46513553b9 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -7,6 +7,8 @@ CTX_PACKAGES = "pkgs" CTX_NOTIFICATIONS = "notifications" CTX_CONFIG_FILE = "config_file" +CTX_PROJECT_ROOT = "project_root" +CTX_MODULE = "module" project_option = _click.option( diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 95533fb4d5..935cfc1ad3 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -15,11 +15,17 @@ from typing_extensions import get_args from flytekit import BlobType, Literal, Scalar -from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE, CTX_DOMAIN, CTX_PROJECT +from flytekit.clis.sdk_in_container.constants import ( + CTX_CONFIG_FILE, + CTX_DOMAIN, + CTX_MODULE, + CTX_PROJECT, + CTX_PROJECT_ROOT, +) from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY, get_and_save_remote_with_click_context from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages -from flytekit.core import context_manager, tracker +from flytekit.core import context_manager from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContext from flytekit.core.data_persistence import FileAccessProvider @@ -480,14 +486,12 @@ def get_entities_in_file(filename: str) -> Entities: workflows = [] tasks = [] module = importlib.import_module(module_name) - for k in dir(module): - o = module.__dict__[k] - if isinstance(o, PythonFunctionWorkflow): - _, _, fn, _ = tracker.extract_task_module(o) - workflows.append(fn) + for name in dir(module): + o = module.__dict__[name] + if isinstance(o, WorkflowBase): + workflows.append(name) elif isinstance(o, PythonTask): - _, _, fn, _ = tracker.extract_task_module(o) - tasks.append(fn) + tasks.append(name) return Entities(workflows, tasks) @@ -542,6 +546,8 @@ def 
_run(*args, **kwargs): domain=domain, image_config=image_config, destination_dir=run_level_params.get("destination_dir"), + source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT), + module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE), ) options = None @@ -602,11 +608,16 @@ def get_command(self, ctx, exe_entity): ) project_root = _find_project_root(self._filename) + # Find the relative path for the filename relative to the root of the project. # N.B.: by construction project_root will necessarily be an ancestor of the filename passed in as # a parameter. rel_path = self._filename.relative_to(project_root) module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") + + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root + ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module + entity = load_naive_entity(module, exe_entity, project_root) # If this is a remote execution, which we should know at this point, then create the remote object diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 0fad8335c2..9851e2e98b 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -4,6 +4,7 @@ import inspect as _inspect import os import typing +from types import ModuleType from typing import Callable, Tuple, Union from flytekit.configuration.feature_flags import FeatureFlags @@ -239,6 +240,11 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, if mod_name == "__main__": return name, "", name, os.path.abspath(inspect.getfile(f)) + mod_name = get_full_module_path(mod, mod_name) + return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod)) + + +def get_full_module_path(mod: ModuleType, mod_name: str) -> str: if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != ".": package_root = ( FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT if FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT != "auto" else None @@ -247,4 +253,4 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, # We only replace the mod_name if it is more specific, else we already have a fully resolved path if len(new_mod_name) > len(mod_name): mod_name = new_mod_name - return f"{mod_name}.{name}", mod_name, name, os.path.abspath(inspect.getfile(mod)) + return mod_name diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 99f54b7933..f02226decc 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -577,6 +577,8 @@ def register_script( destination_dir: str = ".", default_launch_plan: typing.Optional[bool] = True, options: typing.Optional[Options] = None, + source_path: typing.Optional[str] = None, + module_name: typing.Optional[str] = None, ) -> typing.Union[FlyteWorkflow, FlyteTask]: """ Use this method to register a workflow via script mode. 
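# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): how a caller might use the two new
# parameters added below. The remote config, project layout, and module names
# are placeholders.
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

from my_project.workflows.example import wf  # hypothetical workflow module

remote = FlyteRemote(
    config=Config.auto(),
    default_project="flytesnacks",
    default_domain="development",
)
# source_path is the project root (found by walking up until there is no
# __init__.py); module_name is the dotted path of the file defining the entity.
remote.register_script(
    wf,
    source_path="/path/to/my_project",
    module_name="my_project.workflows.example",
)
# ---------------------------------------------------------------------------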
@@ -588,13 +590,16 @@ def register_script( :param entity: The workflow to be registered or the task to be registered :param default_launch_plan: This should be true if a default launch plan should be created for the workflow :param options: Additional execution options that can be configured for the default launchplan + :param source_path: The root of the project path + :param module_name: the name of the module :return: """ if image_config is None: image_config = ImageConfig.auto_default_image() upload_location, md5_bytes = fast_register_single_script( - entity, + source_path, + module_name, functools.partial( self.client.get_upload_signed_url, project=project or self.default_project, diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index f837447637..29b617824c 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -1,5 +1,6 @@ import gzip import hashlib +import importlib import os import shutil import tarfile @@ -10,8 +11,7 @@ from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2 from flytekit.core import context_manager -from flytekit.core.tracker import extract_task_module -from flytekit.core.workflow import WorkflowBase +from flytekit.core.tracker import get_full_module_path def compress_single_script(source_path: str, destination: str, full_module_name: str): @@ -97,16 +97,14 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo: def fast_register_single_script( - wf_entity: WorkflowBase, create_upload_location_fn: typing.Callable + source_path: str, module_name: str, create_upload_location_fn: typing.Callable ) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes): - _, mod_name, _, script_full_path = extract_task_module(wf_entity) - # Find project root by moving up the folder hierarchy until you cannot find a __init__.py file. - source_path = _find_project_root(script_full_path) # Open a temp directory and dump the contents of the digest. 
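    # Editor's note: the steps below are what make fast registration content-addressed.
    # The tarball is built from the source tree with volatile metadata stripped (see
    # tar_strip_file_attributes above), so identical trees produce identical bytes,
    # and the md5 of the archive is then used when requesting the upload location.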
with tempfile.TemporaryDirectory() as tmp_dir: archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz") - compress_single_script(source_path, archive_fname, mod_name) + mod = importlib.import_module(module_name) + compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__)) flyte_ctx = context_manager.FlyteContextManager.current_context() md5, _ = hash_file(archive_fname) diff --git a/tests/flytekit/unit/cli/pyflyte/imperative_wf.py b/tests/flytekit/unit/cli/pyflyte/imperative_wf.py new file mode 100644 index 0000000000..12d7f2e3a3 --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/imperative_wf.py @@ -0,0 +1,39 @@ +import typing + +from flytekit import Workflow, task + + +@task +def t1(a: str) -> str: + return a + " world" + + +@task +def t2(): + print("side effect") + + +@task +def t3(a: typing.List[str]) -> str: + return ",".join(a) + + +wf = Workflow(name="my.imperative.workflow.example") +wf.add_workflow_input("in1", str) +node_t1 = wf.add_entity(t1, a=wf.inputs["in1"]) +wf.add_workflow_output("output_from_t1", node_t1.outputs["o0"]) +wf.add_entity(t2) + +wf_in2 = wf.add_workflow_input("in2", str) +node_t3 = wf.add_entity(t3, a=[wf.inputs["in1"], wf_in2]) + +wf.add_workflow_output( + "output_list", + [node_t1.outputs["o0"], node_t3.outputs["o0"]], + python_type=typing.List[str], +) + + +if __name__ == "__main__": + print(wf) + print(wf(in1="hello", in2="foo")) diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 9d09d58ee8..4c652cbaab 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -19,6 +19,7 @@ from flytekit.core.task import task WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") +IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") DIR_NAME = os.path.dirname(os.path.realpath(__file__)) @@ -30,6 +31,16 @@ def test_pyflyte_run_wf(): assert result.exit_code == 0 +def test_imperative_wf(): + runner = CliRunner() + result = runner.invoke( + pyflyte.main, + ["run", IMPERATIVE_WORKFLOW_FILE, "wf", "--in1", "hello", "--in2", "world"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + def test_pyflyte_run_cli(): runner = CliRunner() result = runner.invoke( From 6330278cddf844cb44db10ec8dcf8147ca289c1d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 31 Aug 2022 02:12:30 +0800 Subject: [PATCH 12/27] Using sidecar handler to run Papermill task (#1143) * remove nb prefix Signed-off-by: Kevin Su * add tests Signed-off-by: Kevin Su * Update requirements.in Signed-off-by: Kevin Su * remove _ Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- .../flytekit-papermill/dev-requirements.in | 4 +- .../flytekit-papermill/dev-requirements.txt | 49 ++++++++++++++++--- .../flytekitplugins/papermill/task.py | 29 +++++++++-- plugins/flytekit-papermill/tests/test_task.py | 37 ++++++++++++++ 4 files changed, 108 insertions(+), 11 deletions(-) diff --git a/plugins/flytekit-papermill/dev-requirements.in b/plugins/flytekit-papermill/dev-requirements.in index a57b6365fe..15889bb4ce 100644 --- a/plugins/flytekit-papermill/dev-requirements.in +++ b/plugins/flytekit-papermill/dev-requirements.in @@ -1,3 +1,3 @@ flyteidl>=1.0.0 -git+https://github.com/flyteorg/flytekit@sd-data-persistence#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark -# vcs+protocol://repo_url/#egg=pkg&subdirectory=flyte +flytekitplugins-pod==v1.2.0b0 
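+# (Editor's note: the pinned line below follows pip's standard VCS requirement
+#  syntax, <vcs>+<protocol>://<repo_url>@<ref>#egg=<name>&subdirectory=<path>.)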
+git+https://github.com/flyteorg/flytekit@master#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index ba2d48ab1c..378ba8e17c 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with python 3.8 +# This file is autogenerated by pip-compile with python 3.9 # To update, run: # # pip-compile dev-requirements.in @@ -8,8 +8,12 @@ arrow==1.2.1 # via jinja2-time binaryornot==0.4.4 # via cookiecutter +cachetools==5.2.0 + # via google-auth certifi==2021.10.8 - # via requests + # via + # kubernetes + # requests chardet==4.0.0 # via binaryornot charset-normalizer==2.0.10 @@ -43,9 +47,15 @@ flyteidl==1.0.0.post1 # -r dev-requirements.in # flytekit flytekit==1.1.0b0 - # via flytekitplugins-spark -flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@sd-data-persistence#subdirectory=plugins/flytekit-spark + # via + # flytekitplugins-pod + # flytekitplugins-spark +flytekitplugins-pod==v1.2.0b0 + # via -r dev-requirements.in +flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@master#subdirectory=plugins/flytekit-spark # via -r dev-requirements.in +google-auth==2.11.0 + # via kubernetes googleapis-common-protos==1.55.0 # via # flyteidl @@ -68,6 +78,8 @@ jinja2-time==0.2.0 # via cookiecutter keyring==23.5.0 # via flytekit +kubernetes==24.2.0 + # via flytekitplugins-pod markupsafe==2.0.1 # via jinja2 marshmallow==3.14.1 @@ -87,6 +99,8 @@ numpy==1.22.1 # via # pandas # pyarrow +oauthlib==3.2.0 + # via requests-oauthlib pandas==1.3.5 # via flytekit poyo==0.5.0 @@ -106,6 +120,12 @@ py4j==0.10.9.3 # via pyspark pyarrow==6.0.1 # via flytekit +pyasn1==0.4.8 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.2.8 + # via google-auth pyspark==3.2.1 # via flytekitplugins-spark python-dateutil==2.8.1 @@ -113,6 +133,7 @@ python-dateutil==2.8.1 # arrow # croniter # flytekit + # kubernetes # pandas python-json-logger==2.0.2 # via flytekit @@ -125,7 +146,9 @@ pytz==2021.3 # flytekit # pandas pyyaml==6.0 - # via flytekit + # via + # flytekit + # kubernetes regex==2021.11.10 # via docker-image-py requests==2.27.1 @@ -133,15 +156,23 @@ requests==2.27.1 # cookiecutter # docker # flytekit + # kubernetes + # requests-oauthlib # responses +requests-oauthlib==1.3.1 + # via kubernetes responses==0.17.0 # via flytekit retry==0.9.2 # via flytekit +rsa==4.9 + # via google-auth six==1.16.0 # via # cookiecutter + # google-auth # grpcio + # kubernetes # python-dateutil # responses sortedcontainers==2.4.0 @@ -159,10 +190,13 @@ typing-inspect==0.7.1 urllib3==1.26.8 # via # flytekit + # kubernetes # requests # responses websocket-client==1.3.2 - # via docker + # via + # docker + # kubernetes wheel==0.37.1 # via flytekit wrapt==1.13.3 @@ -171,3 +205,6 @@ wrapt==1.13.3 # flytekit zipp==3.7.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 304932a828..0721c39a37 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -10,10 +10,12 @@ from nbconvert import HTMLExporter from flytekit import FlyteContext, PythonInstanceTask +from flytekit.configuration import 
SerializationSettings
 from flytekit.core.context_manager import ExecutionParameters
 from flytekit.deck.deck import Deck
 from flytekit.extend import Interface, TaskPlugins, TypeEngine
 from flytekit.loggers import logger
+from flytekit.models import task as task_models
 from flytekit.models.literals import LiteralMap
 from flytekit.types.file import HTMLPage, PythonNotebook

@@ -123,12 +125,13 @@ def __init__(
         # errors.
         # This seems like a hack. We should use a plugin_class that doesn't require a fake-function to make this work.
         plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config))
-        self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func)
+        self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func, **kwargs)
         # Rename the internal task so that there are no conflicts at serialization time. Technically these internal
         # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities
         # at serialization time.
         self._config_task_instance._name = f"{PAPERMILL_TASK_PREFIX}.{name}"
-        task_type = f"nb-{self._config_task_instance.task_type}"
+        task_type = f"{self._config_task_instance.task_type}"
+        task_type_version = self._config_task_instance.task_type_version
         self._notebook_path = os.path.abspath(notebook_path)

         self._render_deck = render_deck
@@ -144,7 +147,12 @@ def __init__(
             }
         )
         super().__init__(
-            name, task_config, task_type=task_type, interface=Interface(inputs=inputs, outputs=outputs), **kwargs
+            name,
+            task_config,
+            task_type=task_type,
+            task_type_version=task_type_version,
+            interface=Interface(inputs=inputs, outputs=outputs),
+            **kwargs,
         )

     @property
@@ -159,6 +167,21 @@ def output_notebook_path(self) -> str:
     def rendered_output_path(self) -> str:
         return self._notebook_path.split(".ipynb")[0] + "-out.html"

+    def get_container(self, settings: SerializationSettings) -> task_models.Container:
+        return self._config_task_instance.get_container(settings)
+
+    def get_k8s_pod(self, settings: SerializationSettings) -> task_models.K8sPod:
+        # The task name in the original command is incorrect because we use _dummy_task_func to construct
+        # the _config_task_instance. Therefore, here we replace the primary container's command with the
+        # NotebookTask's command.
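+        # (Editor's note: concretely, the pod spec still comes from the wrapped Pod
+        # plugin task, but the primary container's args are regenerated from this
+        # NotebookTask's get_command, so the papermill entrypoint -- not the dummy
+        # function -- is what actually runs in the container.)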
+ def fn(settings: SerializationSettings) -> typing.List[str]: + return self.get_command(settings) + + self._config_task_instance.set_command_fn(fn) + return self._config_task_instance.get_k8s_pod(settings) + + def get_config(self, settings: SerializationSettings) -> typing.Dict[str, str]: + return self._config_task_instance.get_config(settings) + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: return self._config_task_instance.pre_execute(user_params) diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index d60e68cdb0..4b456f7fdc 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -2,8 +2,12 @@ import os from flytekitplugins.papermill import NotebookTask +from flytekitplugins.pod import Pod +from kubernetes.client import V1Container, V1PodSpec +import flytekit from flytekit import kwtypes +from flytekit.configuration import Image, ImageConfig from flytekit.types.file import PythonNotebook from .testdata.datatype import X @@ -83,3 +87,36 @@ def test_notebook_deck_local_execution_doesnt_fail(): sqr, out, render = nb.execute(pi=4) # This is largely a no assert test to ensure render_deck never inhibits local execution. assert nb._render_deck, "Passing render deck to init should result in private attribute being set" + + +def generate_por_spec_for_task(): + primary_container = V1Container(name="primary") + pod_spec = V1PodSpec(containers=[primary_container]) + + return pod_spec + + +nb = NotebookTask( + name="test", + task_config=Pod(pod_spec=generate_por_spec_for_task(), primary_container_name="primary"), + notebook_path=_get_nb_path("nb-simple", abs=False), + inputs=kwtypes(h=str, n=int, w=str), + outputs=kwtypes(h=str, w=PythonNotebook, x=X), +) + + +def test_notebook_pod_task(): + serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + ) + + assert nb.get_container(serialization_settings) is None + assert nb.get_config(serialization_settings)["primary_container_name"] == "primary" + assert ( + nb.get_command(serialization_settings) + == nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"] + ) From a528754298ab9a1b8bc8ebde5671cad2eb739c06 Mon Sep 17 00:00:00 2001 From: Rahul Mehta <98349643+rahul-theorem@users.noreply.github.com> Date: Wed, 31 Aug 2022 12:37:45 -0400 Subject: [PATCH 13/27] Properly raise error in NumpyArrayTransformer (#1146) Signed-off-by: Rahul Mehta Signed-off-by: Rahul Mehta --- flytekit/types/numpy/ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index cb1cf2a900..b4f67b94f1 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -52,7 +52,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: try: uri = lv.scalar.blob.uri except AttributeError: - TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") local_path = ctx.file_access.get_random_local_path() ctx.file_access.get_data(uri, local_path, is_multipart=False) From cc46dda9ecd912704d79ae528491a857ea8ebe1b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 8 Sep 2022 14:15:16 +0800 Subject: [PATCH 14/27] Add assert_type in 
dataclass transformer (#1149) * Add assert_type in dataclassTransformer Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * more tests Signed-off-by: Kevin Su * fix lint Signed-off-by: Kevin Su * Add one more test Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 51 ++++++++++++++++++++ tests/flytekit/unit/core/test_type_engine.py | 38 +++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index b3851e77ce..e9ee44b44c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -270,6 +270,46 @@ class Test(): def __init__(self): super().__init__("Object-Dataclass-Transformer", object) + def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): + # Skip iterating all attributes in the dataclass if the type of v already matches the expected_type + if type(v) == expected_type: + return + + # @dataclass_json + # @dataclass + # class Foo(object): + # a: int = 0 + # + # @task + # def t1(a: Foo): + # ... + # + # In above example, the type of v may not equal to the expected_type in some cases + # For example, + # 1. The input of t1 is another dataclass (bar), then we should raise an error + # 2. when using flyte remote to execute the above task, the expected_type is guess_python_type (FooSchema) by default. + # However, FooSchema is created by flytekit and it's not equal to the user-defined dataclass (Foo). + # Therefore, we should iterate all attributes in the dataclass and check the type of value in dataclass matches the expected_type. + + expected_fields_dict = {} + for f in dataclasses.fields(expected_type): + expected_fields_dict[f.name] = f.type + + for f in dataclasses.fields(type(v)): + original_type = f.type + expected_type = expected_fields_dict[f.name] + + if UnionTransformer.is_optional_type(original_type): + original_type = UnionTransformer.get_sub_type_in_optional(original_type) + if UnionTransformer.is_optional_type(expected_type): + expected_type = UnionTransformer.get_sub_type_in_optional(expected_type) + + val = v.__getattribute__(f.name) + if dataclasses.is_dataclass(val): + self.assert_type(expected_type, val) + elif original_type != expected_type: + raise TypeTransformerFailedError(f"Type of Val '{original_type}' is not an instance of {expected_type}") + def get_literal_type(self, t: Type[T]) -> LiteralType: """ Extracts the Literal type definition for a Dataclass and returns a type Struct. 
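# ---------------------------------------------------------------------------
# Editor's sketch (not part of the patch): the behavior the new assert_type
# implements, mirroring the test added in this commit. Class names are
# illustrative.
from dataclasses import dataclass

from dataclasses_json import dataclass_json

from flytekit.core.type_engine import DataclassTransformer, TypeEngine


@dataclass_json
@dataclass
class Foo(object):
    a: int = 0


lt = TypeEngine.to_literal_type(Foo)
guessed = TypeEngine.guess_python_type(lt)  # a flytekit-generated stand-in for Foo

# guessed is not Foo itself, so the fast path (type(v) == expected_type) is
# skipped and the dataclass fields are compared structurally instead; this passes:
DataclassTransformer().assert_type(guessed, Foo(a=1))
# A dataclass whose field types do not line up raises TypeTransformerFailedError.
# ---------------------------------------------------------------------------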
@@ -975,6 +1015,17 @@ class UnionTransformer(TypeTransformer[T]): def __init__(self): super().__init__("Typed Union", typing.Union) + @staticmethod + def is_optional_type(t: Type[T]) -> bool: + return get_origin(t) is typing.Union and type(None) in get_args(t) + + @staticmethod + def get_sub_type_in_optional(t: Type[T]) -> Type[T]: + """ + Return the generic Type T of the Optional type + """ + return get_args(t)[0] + def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: if get_origin(t) is Annotated: t = get_args(t)[0] diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 4b1c02134c..960fcd05a6 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -35,6 +35,7 @@ TypeEngine, TypeTransformer, TypeTransformerFailedError, + UnionTransformer, convert_json_schema_to_python_class, dataclass_from_dict, ) @@ -782,6 +783,43 @@ def test_union_type(): assert v == "hello" +def test_assert_dataclass_type(): + @dataclass_json + @dataclass + class Args(object): + x: int + y: typing.Optional[str] + + @dataclass_json + @dataclass + class Schema(object): + x: typing.Optional[Args] = None + + pt = Schema + lt = TypeEngine.to_literal_type(pt) + gt = TypeEngine.guess_python_type(lt) + pv = Schema(x=Args(x=3, y="hello")) + DataclassTransformer().assert_type(gt, pv) + DataclassTransformer().assert_type(Schema, pv) + + @dataclass_json + @dataclass + class Bar(object): + x: int + + pv = Bar(x=3) + with pytest.raises( + TypeTransformerFailedError, match="Type of Val '' is not an instance of " + ): + DataclassTransformer().assert_type(gt, pv) + + +def test_union_transformer(): + assert UnionTransformer.is_optional_type(typing.Optional[int]) + assert not UnionTransformer.is_optional_type(str) + assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int + + def test_union_type_with_annotated(): pt = typing.Union[ Annotated[str, FlyteAnnotation({"hello": "world"})], Annotated[int, FlyteAnnotation({"test": 123})] From 9f71392b9edcea3f7bae56d2831d652b66ad7bcd Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 8 Sep 2022 14:20:49 +0800 Subject: [PATCH 15/27] Pickle in Union Type (#1147) * Pickel in Union type Signed-off-by: Kevin Su * Pickel in Union type Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * update tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * Address comment Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/core/interface.py | 20 +++++++++++++------ tests/flytekit/unit/core/test_flyte_pickle.py | 19 +++++++++++++++++- .../test_structured_dataset_workflow.py | 9 ++------- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 0f651410bf..c721e2e160 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -7,7 +7,7 @@ from collections import OrderedDict from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union -from typing_extensions import get_args, get_origin, get_type_hints +from typing_extensions import Annotated, get_args, get_origin, get_type_hints from flytekit.core 
import context_manager from flytekit.core.docstring import Docstring @@ -259,12 +259,16 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface: def _change_unrecognized_type_to_pickle(t: Type[T]) -> Type[T]: try: if hasattr(t, "__origin__") and hasattr(t, "__args__"): - if t.__origin__ == list: + if get_origin(t) is list: return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])] - elif t.__origin__ == dict and t.__args__[0] == str: + elif get_origin(t) is dict and t.__args__[0] == str: return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])] - else: - TypeEngine.get_transformer(t) + elif get_origin(t) is typing.Union: + return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] + elif get_origin(t) is Annotated: + base_type, *config = get_args(t) + return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] + TypeEngine.get_transformer(t) except ValueError: logger.warning( f"Unsupported Type {t} found, Flyte will default to use PickleFile as the transport. " @@ -329,7 +333,11 @@ def transform_variable_map( elif v.__origin__ is dict: sub_type = v.__args__[1] if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle: - res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} + if hasattr(sub_type.python_type(), "__name__"): + res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} + elif hasattr(sub_type.python_type(), "_name"): + # If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead. + res[k].type.metadata = {"python_class_name": sub_type.python_type()._name} return res diff --git a/tests/flytekit/unit/core/test_flyte_pickle.py b/tests/flytekit/unit/core/test_flyte_pickle.py index a3ec6d17ce..318a6b76f3 100644 --- a/tests/flytekit/unit/core/test_flyte_pickle.py +++ b/tests/flytekit/unit/core/test_flyte_pickle.py @@ -1,5 +1,10 @@ from collections import OrderedDict -from typing import Dict, List +from collections.abc import Sequence +from typing import Dict, List, Union + +import numpy as np +import pandas as pd +from typing_extensions import Annotated import flytekit.configuration from flytekit.configuration import Image, ImageConfig @@ -80,3 +85,15 @@ def t1(a: int) -> List[Dict[str, Foo]]: task_spec.template.interface.outputs["o0"].type.collection_type.map_value_type.blob.format is FlytePickleTransformer.PYTHON_PICKLE_FORMAT ) + + +def test_union(): + @task + def t1(data: Annotated[Union[np.ndarray, pd.DataFrame, Sequence], "some annotation"]): + print(data) + + task_spec = get_serializable(OrderedDict(), serialization_settings, t1) + variants = task_spec.template.interface.inputs["data"].type.union_type.variants + assert variants[0].blob.format == "NumpyArray" + assert variants[1].structured_dataset_type.format == "parquet" + assert variants[2].blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index c849995b96..5d04a12e7b 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -1,17 +1,12 @@ import os import typing -import pytest - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - import numpy as np import pandas as pd import pyarrow 
as pa import pyarrow.parquet as pq +import pytest +from typing_extensions import Annotated from flytekit import FlyteContext, FlyteContextManager, kwtypes, task, workflow from flytekit.models import literals From 13da3b9da34c736149269efdb5d67d59db69f141 Mon Sep 17 00:00:00 2001 From: Rahul Mehta <98349643+rahul-theorem@users.noreply.github.com> Date: Thu, 8 Sep 2022 16:46:55 -0400 Subject: [PATCH 16/27] Bump max docker version to 7.0.0 (#1138) Signed-off-by: Rahul Mehta Signed-off-by: Rahul Mehta --- dev-requirements.txt | 43 ++++++++----------- doc-requirements.txt | 23 +++++++--- requirements-spark2.txt | 27 +++++------- requirements.txt | 27 +++++------- setup.py | 2 +- .../workflows/requirements.txt | 24 ++++------- 6 files changed, 65 insertions(+), 81 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index b477f2553b..13f2006571 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -8,6 +8,8 @@ # via # -c requirements.txt # pytest-flyte +appnope==0.1.3 + # via ipython arrow==1.2.2 # via # -c requirements.txt @@ -20,7 +22,7 @@ attrs==20.3.0 # pytest-docker backcall==0.2.0 # via ipython -bcrypt==3.2.2 +bcrypt==4.0.0 # via paramiko binaryornot==0.4.4 # via @@ -37,7 +39,6 @@ certifi==2022.6.15 cffi==1.15.1 # via # -c requirements.txt - # bcrypt # cryptography # pynacl cfgv==3.3.1 @@ -46,7 +47,7 @@ chardet==5.0.0 # via # -c requirements.txt # binaryornot -charset-normalizer==2.1.0 +charset-normalizer==2.1.1 # via # -c requirements.txt # requests @@ -59,7 +60,7 @@ cloudpickle==2.1.0 # via # -c requirements.txt # flytekit -codespell==2.1.0 +codespell==2.2.1 # via -r dev-requirements.in cookiecutter==2.1.1 # via @@ -78,7 +79,6 @@ cryptography==37.0.4 # -c requirements.txt # paramiko # pyopenssl - # secretstorage dataclasses-json==0.5.7 # via # -c requirements.txt @@ -96,11 +96,11 @@ diskcache==5.4.0 # via # -c requirements.txt # flytekit -distlib==0.3.5 +distlib==0.3.6 # via virtualenv distro==1.7.0 # via docker-compose -docker[ssh]==5.0.3 +docker[ssh]==6.0.0 # via # -c requirements.txt # docker-compose @@ -130,11 +130,11 @@ google-api-core[grpc]==2.8.2 # google-cloud-bigquery # google-cloud-bigquery-storage # google-cloud-core -google-auth==2.10.0 +google-auth==2.11.0 # via # google-api-core # google-cloud-core -google-cloud-bigquery==3.3.1 +google-cloud-bigquery==3.3.2 # via -r dev-requirements.in google-cloud-bigquery-storage==2.14.2 # via @@ -187,11 +187,6 @@ ipython==7.34.0 # via -r dev-requirements.in jedi==0.18.1 # via ipython -jeepney==0.8.0 - # via - # -c requirements.txt - # keyring - # secretstorage jinja2==3.1.2 # via # -c requirements.txt @@ -219,7 +214,7 @@ markupsafe==2.1.1 # via # -c requirements.txt # jinja2 -marshmallow==3.17.0 +marshmallow==3.17.1 # via # -c requirements.txt # dataclasses-json @@ -233,7 +228,7 @@ marshmallow-jsonschema==0.13.0 # via # -c requirements.txt # flytekit -matplotlib-inline==0.1.5 +matplotlib-inline==0.1.6 # via ipython mock==4.0.3 # via -r dev-requirements.in @@ -259,6 +254,7 @@ numpy==1.21.6 packaging==21.3 # via # -c requirements.txt + # docker # google-cloud-bigquery # marshmallow # pytest @@ -282,7 +278,7 @@ pre-commit==2.20.0 # via -r dev-requirements.in prompt-toolkit==3.0.30 # via ipython -proto-plus==1.22.0 +proto-plus==1.22.1 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -386,7 +382,7 @@ pyyaml==5.4.1 # docker-compose # flytekit # pre-commit -regex==2022.7.25 +regex==2022.8.17 # via # -c requirements.txt # docker-image-py @@ -410,10 +406,6 @@ retry==0.9.2 # flytekit rsa==4.9 # via 
google-auth -secretstorage==3.3.3 - # via - # -c requirements.txt - # keyring singledispatchmethod==1.0 # via # -c requirements.txt @@ -467,17 +459,18 @@ typing-extensions==4.3.0 # responses # torch # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via # -c requirements.txt # dataclasses-json -urllib3==1.26.11 +urllib3==1.26.12 # via # -c requirements.txt + # docker # flytekit # requests # responses -virtualenv==20.16.3 +virtualenv==20.16.4 # via pre-commit wcwidth==0.2.5 # via prompt-toolkit diff --git a/doc-requirements.txt b/doc-requirements.txt index 5bfb8c5d31..6febe9d508 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -20,6 +20,8 @@ arrow==1.2.2 # via jinja2-time astroid==2.12.2 # via sphinx-autoapi +astunparse==1.6.3 + # via tensorflow attrs==22.1.0 # via # jsonschema @@ -42,7 +44,7 @@ binaryornot==0.4.4 # via cookiecutter bleach==5.0.1 # via nbconvert -botocore==1.27.53 +botocore==1.27.63 # via -r doc-requirements.in cachetools==5.2.0 # via google-auth @@ -98,7 +100,7 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.0 # via flytekit docker-image-py==0.1.12 # via flytekit @@ -124,7 +126,7 @@ flyteidl==1.1.12 # via flytekit fonttools==4.35.0 # via matplotlib -fsspec==2022.7.1 +fsspec==2022.8.0 # via # -r doc-requirements.in # modin @@ -192,7 +194,7 @@ importlib-metadata==4.12.0 # sqlalchemy importlib-resources==5.9.0 # via jsonschema -ipykernel==6.15.1 +ipykernel==6.15.2 # via # ipywidgets # jupyter @@ -354,7 +356,6 @@ notebook==6.4.12 # via # great-expectations # jupyter - # widgetsnbextension numpy==1.21.6 # via # altair @@ -377,6 +378,7 @@ oauthlib==3.2.0 # via requests-oauthlib packaging==21.3 # via + # docker # google-cloud-bigquery # great-expectations # ipykernel @@ -387,6 +389,7 @@ packaging==21.3 # pandera # qtpy # sphinx + # tensorflow pandas==1.3.5 # via # altair @@ -430,7 +433,7 @@ prompt-toolkit==3.0.30 # via # ipython # jupyter-console -proto-plus==1.22.0 +proto-plus==1.22.1 # via # google-cloud-bigquery # google-cloud-bigquery-storage @@ -470,7 +473,7 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pydantic==1.9.2 +pydantic==1.10.0 # via # pandas-profiling # pandera @@ -570,6 +573,8 @@ ruamel-yaml==0.17.17 # via great-expectations ruamel-yaml-clib==0.2.6 # via ruamel-yaml +scikit-learn==1.0.2 + # via skl2onnx scipy==1.7.3 # via # great-expectations @@ -590,6 +595,7 @@ singledispatchmethod==1.0 # via flytekit six==1.16.0 # via + # astunparse # bleach # google-auth # grpcio @@ -711,9 +717,11 @@ typing-extensions==4.3.0 # importlib-metadata # jsonschema # kiwisolver + # onnx # pandera # pydantic # responses + # tensorflow # torch # typing-inspect typing-inspect==0.7.1 @@ -731,6 +739,7 @@ unidecode==1.3.4 urllib3==1.26.11 # via # botocore + # docker # flytekit # great-expectations # kubernetes diff --git a/requirements-spark2.txt b/requirements-spark2.txt index ded88f1bb2..6543d204fd 100644 --- a/requirements-spark2.txt +++ b/requirements-spark2.txt @@ -22,7 +22,7 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==2.1.1 # via requests click==8.1.3 # via @@ -35,9 +35,7 @@ cookiecutter==2.1.1 croniter==1.3.5 # via flytekit cryptography==37.0.4 - # via - # pyopenssl - # secretstorage + # via pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -46,7 +44,7 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.0 # via flytekit docker-image-py==0.1.12 # via 
flytekit @@ -72,10 +70,6 @@ importlib-metadata==4.12.0 # flytekit # jsonschema # keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter @@ -90,7 +84,7 @@ keyring==23.8.2 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.17.1 # via # dataclasses-json # marshmallow-enum @@ -110,7 +104,9 @@ numpy==1.21.6 # pandas # pyarrow packaging==21.3 - # via marshmallow + # via + # docker + # marshmallow pandas==1.3.5 # via # -r requirements.in @@ -157,7 +153,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.7.25 +regex==2022.8.17 # via docker-image-py requests==2.28.1 # via @@ -169,8 +165,6 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.3 - # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -192,10 +186,11 @@ typing-extensions==4.3.0 # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.11 +urllib3==1.26.12 # via + # docker # flytekit # requests # responses diff --git a/requirements.txt b/requirements.txt index 17a6487f7c..32b4ae49a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==2.1.1 # via requests click==8.1.3 # via @@ -33,9 +33,7 @@ cookiecutter==2.1.1 croniter==1.3.5 # via flytekit cryptography==37.0.4 - # via - # pyopenssl - # secretstorage + # via pyopenssl dataclasses-json==0.5.7 # via flytekit decorator==5.1.1 @@ -44,7 +42,7 @@ deprecated==1.2.13 # via flytekit diskcache==5.4.0 # via flytekit -docker==5.0.3 +docker==6.0.0 # via flytekit docker-image-py==0.1.12 # via flytekit @@ -70,10 +68,6 @@ importlib-metadata==4.12.0 # flytekit # jsonschema # keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter @@ -88,7 +82,7 @@ keyring==23.8.2 # via flytekit markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.17.1 # via # dataclasses-json # marshmallow-enum @@ -108,7 +102,9 @@ numpy==1.21.6 # pandas # pyarrow packaging==21.3 - # via marshmallow + # via + # docker + # marshmallow pandas==1.3.5 # via # -r requirements.in @@ -155,7 +151,7 @@ pyyaml==5.4.1 # -r requirements.in # cookiecutter # flytekit -regex==2022.7.25 +regex==2022.8.17 # via docker-image-py requests==2.28.1 # via @@ -167,8 +163,6 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.3 - # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -190,10 +184,11 @@ typing-extensions==4.3.0 # importlib-metadata # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.11 +urllib3==1.26.12 # via + # docker # flytekit # requests # responses diff --git a/setup.py b/setup.py index 56e2dd5624..e5bc3cfa33 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ "click>=6.6,<9.0", "croniter>=0.3.20,<4.0.0", "deprecated>=1.0,<2.0", - "docker>=5.0.3,<6.0.0", + "docker>=5.0.3,<7.0.0", "python-dateutil>=2.1", "grpcio>=1.43.0,!=1.45.0,<2.0", "grpcio-status>=1.43,!=1.45.0", diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 0fc659135f..57e35e64b3 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt 
@@ -14,7 +14,7 @@ cffi==1.15.1 # via cryptography chardet==5.0.0 # via binaryornot -charset-normalizer==2.1.0 +charset-normalizer==2.1.1 # via requests click==8.1.3 # via @@ -27,9 +27,7 @@ cookiecutter==2.1.1 croniter==1.3.5 # via flytekit cryptography==37.0.4 - # via - # pyopenssl - # secretstorage + # via pyopenssl cycler==0.11.0 # via matplotlib dataclasses-json==0.5.7 @@ -50,7 +48,7 @@ flyteidl==1.1.12 # via flytekit flytekit==1.1.1 # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -fonttools==4.35.0 +fonttools==4.37.1 # via matplotlib googleapis-common-protos==1.56.4 # via @@ -69,10 +67,6 @@ importlib-metadata==4.12.0 # click # flytekit # keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # cookiecutter @@ -87,7 +81,7 @@ kiwisolver==1.4.4 # via matplotlib markupsafe==2.1.1 # via jinja2 -marshmallow==3.17.0 +marshmallow==3.17.1 # via # dataclasses-json # marshmallow-enum @@ -161,7 +155,7 @@ pyyaml==6.0 # via # cookiecutter # flytekit -regex==2022.7.25 +regex==2022.8.17 # via docker-image-py requests==2.28.1 # via @@ -173,8 +167,6 @@ responses==0.21.0 # via flytekit retry==0.9.2 # via flytekit -secretstorage==3.3.3 - # via keyring singledispatchmethod==1.0 # via flytekit six==1.16.0 @@ -195,14 +187,14 @@ typing-extensions==4.3.0 # kiwisolver # responses # typing-inspect -typing-inspect==0.7.1 +typing-inspect==0.8.0 # via dataclasses-json -urllib3==1.26.11 +urllib3==1.26.12 # via # flytekit # requests # responses -websocket-client==1.3.3 +websocket-client==1.4.0 # via docker wheel==0.37.1 # via From 2aaaeaa49ff0e96f76b8d0f233fe5837db14aad7 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Thu, 8 Sep 2022 14:28:55 -0700 Subject: [PATCH 17/27] Set flytekit<2.0 in plugins (#1152) Signed-off-by: Eduardo Apolinario Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- plugins/flytekit-aws-athena/setup.py | 2 +- plugins/flytekit-aws-batch/setup.py | 2 +- plugins/flytekit-aws-sagemaker/setup.py | 2 +- plugins/flytekit-bigquery/setup.py | 2 +- plugins/flytekit-data-fsspec/setup.py | 2 +- plugins/flytekit-deck-standard/setup.py | 2 +- plugins/flytekit-dolt/setup.py | 2 +- plugins/flytekit-greatexpectations/setup.py | 2 +- plugins/flytekit-hive/setup.py | 2 +- plugins/flytekit-k8s-pod/setup.py | 2 +- plugins/flytekit-kf-mpi/setup.py | 2 +- plugins/flytekit-kf-pytorch/setup.py | 2 +- plugins/flytekit-kf-tensorflow/setup.py | 2 +- plugins/flytekit-modin/setup.py | 2 +- plugins/flytekit-onnx-pytorch/setup.py | 2 +- plugins/flytekit-onnx-scikitlearn/setup.py | 2 +- plugins/flytekit-onnx-tensorflow/setup.py | 2 +- plugins/flytekit-pandera/setup.py | 2 +- plugins/flytekit-papermill/setup.py | 2 +- plugins/flytekit-polars/setup.py | 2 +- plugins/flytekit-snowflake/setup.py | 2 +- plugins/flytekit-spark/setup.py | 2 +- plugins/flytekit-sqlalchemy/setup.py | 2 +- 23 files changed, 23 insertions(+), 23 deletions(-) diff --git a/plugins/flytekit-aws-athena/setup.py b/plugins/flytekit-aws-athena/setup.py index 1164b99d00..0cea449a97 100644 --- a/plugins/flytekit-aws-athena/setup.py +++ b/plugins/flytekit-aws-athena/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-aws-batch/setup.py b/plugins/flytekit-aws-batch/setup.py index e176e35aae..68ad62750c 100644 --- a/plugins/flytekit-aws-batch/setup.py 
+++ b/plugins/flytekit-aws-batch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-aws-sagemaker/setup.py b/plugins/flytekit-aws-sagemaker/setup.py index 2781192f41..76a816fe06 100644 --- a/plugins/flytekit-aws-sagemaker/setup.py +++ b/plugins/flytekit-aws-sagemaker/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "sagemaker-training>=3.6.2,<4.0.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "sagemaker-training>=3.6.2,<4.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-bigquery/setup.py b/plugins/flytekit-bigquery/setup.py index e84cd6ce2b..0e7eed5d9d 100644 --- a/plugins/flytekit-bigquery/setup.py +++ b/plugins/flytekit-bigquery/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "google-cloud-bigquery"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "google-cloud-bigquery"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-data-fsspec/setup.py b/plugins/flytekit-data-fsspec/setup.py index 3756b9228b..f7d03690a8 100644 --- a/plugins/flytekit-data-fsspec/setup.py +++ b/plugins/flytekit-data-fsspec/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-data-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "fsspec>=2021.7.0", "botocore>=1.7.48", "pandas>=1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "fsspec>=2021.7.0", "botocore>=1.7.48", "pandas>=1.2.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-deck-standard/setup.py b/plugins/flytekit-deck-standard/setup.py index fe04ab5434..a47bf0f0d0 100644 --- a/plugins/flytekit-deck-standard/setup.py +++ b/plugins/flytekit-deck-standard/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}-standard" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "markdown", "plotly", "pandas_profiling"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "markdown", "plotly", "pandas_profiling"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-dolt/setup.py b/plugins/flytekit-dolt/setup.py index bb8b572bc7..ce6abbc64b 100644 --- a/plugins/flytekit-dolt/setup.py +++ b/plugins/flytekit-dolt/setup.py @@ -6,7 +6,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "dolt_integrations>=0.1.5"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "dolt_integrations>=0.1.5"] dev_requires = ["pytest-mock>=3.6.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-greatexpectations/setup.py b/plugins/flytekit-greatexpectations/setup.py index e50f707624..93c73d5416 100644 --- a/plugins/flytekit-greatexpectations/setup.py +++ b/plugins/flytekit-greatexpectations/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "great-expectations>=0.13.30", "sqlalchemy>=1.4.23"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "great-expectations>=0.13.30", "sqlalchemy>=1.4.23"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-hive/setup.py b/plugins/flytekit-hive/setup.py index f9602500f8..a2f67d982f 100644 --- a/plugins/flytekit-hive/setup.py +++ b/plugins/flytekit-hive/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] 
__version__ = "0.0.0+develop" diff --git a/plugins/flytekit-k8s-pod/setup.py b/plugins/flytekit-k8s-pod/setup.py index 01704e1a6a..29c56922b5 100644 --- a/plugins/flytekit-k8s-pod/setup.py +++ b/plugins/flytekit-k8s-pod/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "kubernetes>=12.0.1", ] diff --git a/plugins/flytekit-kf-mpi/setup.py b/plugins/flytekit-kf-mpi/setup.py index 18f168af18..c8a845fb13 100644 --- a/plugins/flytekit-kf-mpi/setup.py +++ b/plugins/flytekit-kf-mpi/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "flyteidl>=0.21.4"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "flyteidl>=0.21.4"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-pytorch/setup.py b/plugins/flytekit-kf-pytorch/setup.py index 2e0b57a7f8..dc10722bd9 100644 --- a/plugins/flytekit-kf-pytorch/setup.py +++ b/plugins/flytekit-kf-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-kf-tensorflow/setup.py b/plugins/flytekit-kf-tensorflow/setup.py index 45d8fe6b2e..5ec98ea74b 100644 --- a/plugins/flytekit-kf-tensorflow/setup.py +++ b/plugins/flytekit-kf-tensorflow/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" # TODO: Requirements are missing, add them back in later. -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-modin/setup.py b/plugins/flytekit-modin/setup.py index 777a19db47..46c5dbc02e 100644 --- a/plugins/flytekit-modin/setup.py +++ b/plugins/flytekit-modin/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "modin>=0.13.0", "fsspec", "ray", diff --git a/plugins/flytekit-onnx-pytorch/setup.py b/plugins/flytekit-onnx-pytorch/setup.py index 74e3b940ec..0642054565 100644 --- a/plugins/flytekit-onnx-pytorch/setup.py +++ b/plugins/flytekit-onnx-pytorch/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "torch>=1.11.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "torch>=1.11.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-onnx-scikitlearn/setup.py b/plugins/flytekit-onnx-scikitlearn/setup.py index 9815bedaf2..46a2bceaf7 100644 --- a/plugins/flytekit-onnx-scikitlearn/setup.py +++ b/plugins/flytekit-onnx-scikitlearn/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "skl2onnx>=1.10.3"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "skl2onnx>=1.10.3"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-onnx-tensorflow/setup.py b/plugins/flytekit-onnx-tensorflow/setup.py index d2865b083d..53d35e7fbd 100644 --- a/plugins/flytekit-onnx-tensorflow/setup.py +++ b/plugins/flytekit-onnx-tensorflow/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.0.0b0,<1.2.0", "tf2onnx>=1.9.3", "tensorflow>=2.7.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "tf2onnx>=1.9.3", "tensorflow>=2.7.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-pandera/setup.py 
b/plugins/flytekit-pandera/setup.py index cbe9bf4061..0625c138d2 100644 --- a/plugins/flytekit-pandera/setup.py +++ b/plugins/flytekit-pandera/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "pandera>=0.7.1"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "pandera>=0.7.1"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-papermill/setup.py b/plugins/flytekit-papermill/setup.py index 26a3f1b705..46d6296f55 100644 --- a/plugins/flytekit-papermill/setup.py +++ b/plugins/flytekit-papermill/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0", diff --git a/plugins/flytekit-polars/setup.py b/plugins/flytekit-polars/setup.py index ea3feb8582..f4086babb7 100644 --- a/plugins/flytekit-polars/setup.py +++ b/plugins/flytekit-polars/setup.py @@ -5,7 +5,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" plugin_requires = [ - "flytekit>=1.1.0b0,<1.2.0", + "flytekit>=1.1.0b0,<2.0.0", "polars>=0.8.27", ] diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index ebea48f304..e82bb7268f 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index 9fdffe6c22..108fbb1169 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "pyspark>=3.0.0"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "pyspark>=3.0.0"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-sqlalchemy/setup.py b/plugins/flytekit-sqlalchemy/setup.py index 6bf0a8e1ab..aa13aa8fbc 100644 --- a/plugins/flytekit-sqlalchemy/setup.py +++ b/plugins/flytekit-sqlalchemy/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.1.0b0,<1.2.0", "sqlalchemy>=1.4.7"] +plugin_requires = ["flytekit>=1.1.0b0,<2.0.0", "sqlalchemy>=1.4.7"] __version__ = "0.0.0+develop" From 1d1fc852190a06a5204ff32f329ad647ff82f9d3 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 10 Sep 2022 02:58:19 +0800 Subject: [PATCH 18/27] Add literal type to union literal (#1144) * Add literal type to union literal Signed-off-by: Kevin Su * fix test Signed-off-by: Kevin Su * Add tests Signed-off-by: Kevin Su * more tests Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/clis/sdk_in_container/run.py | 5 +- tests/flytekit/unit/cli/pyflyte/test_run.py | 64 ++++++++++++++++++++- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 935cfc1ad3..d0b890ba7b 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -33,7 +33,7 @@ from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase from flytekit.models import literals from flytekit.models.interface import Variable -from flytekit.models.literals import Blob, BlobMetadata, Primitive +from flytekit.models.literals import Blob, BlobMetadata, Primitive, Union from flytekit.models.types import LiteralType, SimpleType 
from flytekit.remote.executions import FlyteWorkflowExecution from flytekit.tools import module_loader, script_mode @@ -270,8 +270,7 @@ def convert_to_union( # and then use flyte converter to convert it to literal. python_val = converter._click_type.convert(value, param, ctx) literal = converter.convert_to_literal(ctx, param, python_val) - self._python_type = python_type - return literal + return Literal(scalar=Scalar(union=Union(literal, variant))) except (Exception or AttributeError) as e: logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 4c652cbaab..b6bb3d44c4 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,10 +1,15 @@ +import functools import os import pathlib +import typing +from enum import Enum +import click import mock import pytest from click.testing import CliRunner +from flytekit import FlyteContextManager from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY @@ -12,11 +17,15 @@ REMOTE_FLAG_KEY, RUN_LEVEL_PARAMS_KEY, FileParamType, + FlyteLiteralConverter, get_entities_in_file, run_command, ) -from flytekit.configuration import Image, ImageConfig +from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task +from flytekit.core.type_engine import TypeEngine +from flytekit.models.types import SimpleType +from flytekit.remote import FlyteRemote WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") IMPERATIVE_WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "imperative_wf.py") @@ -267,3 +276,56 @@ def test_file_param(): assert l.local r = FileParamType().convert("https://tmp/file", m, m) assert r.local is False + + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +@pytest.mark.parametrize( + "python_type, python_value", + [ + (typing.Union[typing.List[int], str, Color], "flyte"), + (typing.Union[typing.List[int], str, Color], "red"), + (typing.Union[typing.List[int], str, Color], [1, 2, 3]), + (typing.List[int], [1, 2, 3]), + (typing.Dict[str, int], {"flyte": 2}), + ], +) +def test_literal_converter(python_type, python_value): + get_upload_url_fn = functools.partial( + FlyteRemote(Config.auto()).client.get_upload_signed_url, project="p", domain="d" + ) + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(python_type) + + lc = FlyteLiteralConverter( + click_ctx, ctx, literal_type=lt, python_type=python_type, get_upload_url_fn=get_upload_url_fn + ) + + assert lc.convert(click_ctx, ctx, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt) + + +def test_enum_converter(): + pt = typing.Union[str, Color] + + get_upload_url_fn = functools.partial(FlyteRemote(Config.auto()).client.get_upload_signed_url) + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(pt) + lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn) + union_lt = 
lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple == SimpleType.STRING + assert union_lt.stored_type.enum_type is None + + pt = typing.Union[Color, str] + lt = TypeEngine.to_literal_type(typing.Union[Color, str]) + lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn) + union_lt = lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple is None + assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"] From e70c6789922e82179b395f60752e14fa2cafc17d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 10 Sep 2022 03:22:31 +0800 Subject: [PATCH 19/27] Fix the type of optional[int] in nested dataclass (#1148) * Fix the type of optional[int] in nested dataclass Signed-off-by: Kevin Su * update tests Signed-off-by: Kevin Su * update comments Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 9 ++++++ tests/flytekit/unit/core/test_type_engine.py | 31 +++++++++++++------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e9ee44b44c..aa7b261e64 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -489,6 +489,8 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> return python_val def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: + if val is None: + return val if t == int: return int(val) @@ -501,6 +503,13 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}}) return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()} + if get_origin(t) is typing.Union and type(None) in get_args(t): + # Handle optional type. e.g. Optional[int], Optional[dataclass] + # Marshmallow doesn't support union type, so the type here is always an optional type. + # https://github.com/marshmallow-code/marshmallow/issues/1191#issuecomment-480831796 + # Note: Union[None, int] is also an optional type, but Marshmallow does not support it. 
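+            # Illustrative example: get_origin(typing.Optional[int]) is typing.Union and
+            # get_args(typing.Optional[int]) == (int, type(None)), so the recursive call
+            # below unwraps Optional[int] to int before re-checking it.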
+ return self._fix_val_int(get_args(t)[0], val) + if dataclasses.is_dataclass(t): return self._fix_dataclass_int(t, val) # type: ignore diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 960fcd05a6..c70adb85d0 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -148,6 +148,7 @@ def test_list_of_dataclass_getting_python_value(): @dataclass_json @dataclass() class Bar(object): + v: typing.Union[int, None] w: typing.Optional[str] x: float y: str @@ -161,7 +162,7 @@ class Foo(object): y: typing.Dict[str, str] z: Bar - foo = Foo(w=1, x=[1], y={"hello": "10"}, z=Bar(w=None, x=1.0, y="hello", z={"world": False})) + foo = Foo(u=5, v=None, w=1, x=[1], y={"hello": "10"}, z=Bar(v=3, w=None, x=1.0, y="hello", z={"world": False})) generic = _json_format.Parse(typing.cast(DataClassJsonMixin, foo).to_json(), _struct.Struct()) lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) @@ -171,16 +172,24 @@ class Foo(object): schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema()) foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") - pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) - assert isinstance(pv, list) - assert pv[0].w == foo.w - assert pv[0].x == foo.x - assert pv[0].y == foo.y - assert pv[0].z.x == foo.z.x - assert type(pv[0].z.x) == float - assert pv[0].z.y == foo.z.y - assert pv[0].z.z == foo.z.z - assert foo == dataclass_from_dict(Foo, asdict(pv[0])) + guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) + pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo]) + assert isinstance(guessed_pv, list) + assert guessed_pv[0].u == pv[0].u + assert guessed_pv[0].v == pv[0].v + assert guessed_pv[0].w == pv[0].w + assert guessed_pv[0].x == pv[0].x + assert guessed_pv[0].y == pv[0].y + assert guessed_pv[0].z.x == pv[0].z.x + assert type(guessed_pv[0].u) == int + assert guessed_pv[0].v is None + assert type(guessed_pv[0].w) == int + assert type(guessed_pv[0].z.v) == int + assert type(guessed_pv[0].z.x) == float + assert guessed_pv[0].z.v == pv[0].z.v + assert guessed_pv[0].z.y == pv[0].z.y + assert guessed_pv[0].z.z == pv[0].z.z + assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0])) def test_file_no_downloader_default(): From f5f5ab7ee02ad66ae9f6e7f932e5fb25f5b754ff Mon Sep 17 00:00:00 2001 From: Vanshika Chowdhary Date: Fri, 9 Sep 2022 16:33:41 -0700 Subject: [PATCH 20/27] Added symlink dereferencing in fast packaging and tests (#1151) * Added symlink dereferencing and tests Signed-off-by: Vanshika Chowdhary * Added flag to register as well Signed-off-by: Vanshika Chowdhary * More flag propagation Signed-off-by: Vanshika Chowdhary Signed-off-by: Vanshika Chowdhary Co-authored-by: Vanshika Chowdhary --- flytekit/clis/sdk_in_container/package.py | 12 +++++++-- flytekit/clis/sdk_in_container/register.py | 9 ++++++- flytekit/clis/sdk_in_container/serialize.py | 10 +++++-- flytekit/tools/fast_registration.py | 5 ++-- flytekit/tools/repo.py | 7 +++-- .../unit/tools/test_fast_registration.py | 26 ++++++++++++++++++- 6 files changed, 59 insertions(+), 10 deletions(-) diff --git a/flytekit/clis/sdk_in_container/package.py b/flytekit/clis/sdk_in_container/package.py index 2a884e29da..1a849d0681 100644 --- a/flytekit/clis/sdk_in_container/package.py +++ b/flytekit/clis/sdk_in_container/package.py 
@@ -77,8 +77,16 @@ default="/root", help="Filesystem path to where the code is copied into within the Dockerfile. look for `COPY . /root` like command.", ) +@click.option( + "--deref-symlinks", + default=False, + is_flag=True, + help="Enables symlink dereferencing when packaging files in fast registration", +) @click.pass_context -def package(ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter): +def package( + ctx, image_config, source, output, force, fast, in_container_source_path, python_interpreter, deref_symlinks +): """ This command produces a Flyte backend registrable package of all entities in Flyte. For tasks, one pb file is produced for each task, representing one TaskTemplate object. @@ -103,6 +111,6 @@ def package(ctx, image_config, source, output, force, fast, in_container_source_ display_help_with_error(ctx, "No packages to scan for flyte entities. Aborting!") try: - serialize_and_package(pkgs, serialization_settings, source, output, fast) + serialize_and_package(pkgs, serialization_settings, source, output, fast, deref_symlinks) except NoSerializableEntitiesError: click.secho(f"No flyte objects found in packages {pkgs}", fg="yellow") diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 03e00d7896..024b70edde 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -99,6 +99,12 @@ type=str, help="Version the package or module is registered with", ) +@click.option( + "--deref-symlinks", + default=False, + is_flag=True, + help="Enables symlink dereferencing when packaging files in fast registration", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -111,6 +117,7 @@ def register( service_account: str, raw_data_prefix: str, version: typing.Optional[str], + deref_symlinks: bool, package_or_module: typing.Tuple[str], ): """ @@ -142,7 +149,7 @@ def register( # Create a zip file containing all the entries. detected_root = find_common_root(package_or_module) cli_logger.debug(f"Using {detected_root} as root folder for project") - zip_file = fast_package(detected_root, output) + zip_file = fast_package(detected_root, output, deref_symlinks) # Upload zip file to Admin using FlyteRemote. 
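    # Illustrative note: zip_file here is the gzipped tarball produced by fast_package above,
    # named after a hexdigest of the included files (see flytekit/tools/fast_registration.py below).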
md5_bytes, native_url = remote._upload_file(pathlib.Path(zip_file)) diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py index 0b12d6b406..33c0b47940 100644 --- a/flytekit/clis/sdk_in_container/serialize.py +++ b/flytekit/clis/sdk_in_container/serialize.py @@ -155,16 +155,22 @@ def fast(ctx): @click.command("workflows") +@click.option( + "--deref-symlinks", + default=False, + is_flag=True, + help="Enables symlink dereferencing when packaging files in fast registration", +) @click.option("-f", "--folder", type=click.Path(exists=True)) @click.pass_context -def fast_workflows(ctx, folder=None): +def fast_workflows(ctx, folder=None, deref_symlinks=False): if folder: click.echo(f"Writing output to {folder}") source_dir = ctx.obj[CTX_LOCAL_SRC_ROOT] # Write using gzip - archive_fname = fast_package(source_dir, folder) + archive_fname = fast_package(source_dir, folder, deref_symlinks) click.echo(f"Wrote compressed archive to {archive_fname}") pkgs = ctx.obj[CTX_PACKAGES] diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index c4ac31a01a..34faadc58c 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -21,12 +21,13 @@ file_access = FlyteContextManager.current_context().file_access -def fast_package(source: os.PathLike, output_dir: os.PathLike) -> os.PathLike: +def fast_package(source: os.PathLike, output_dir: os.PathLike, deref_symlinks: bool = False) -> os.PathLike: """ Takes a source directory and packages everything not covered by common ignores into a tarball named after a hexdigest of the included files. :param os.PathLike source: :param os.PathLike output_dir: + :param bool deref_symlinks: Enables dereferencing symlinks when packaging directory :return os.PathLike: """ ignore = IgnoreGroup(source, [GitIgnore, DockerIgnore, StandardIgnore]) @@ -41,7 +42,7 @@ def fast_package(source: os.PathLike, output_dir: os.PathLike) -> os.PathLike: with tempfile.TemporaryDirectory() as tmp_dir: tar_path = os.path.join(tmp_dir, "tmp.tar") - with tarfile.open(tar_path, "w") as tar: + with tarfile.open(tar_path, "w", dereference=deref_symlinks) as tar: tar.add(source, arcname="", filter=lambda x: ignore.tar_filter(tar_strip_file_attributes(x))) with gzip.GzipFile(filename=archive_fname, mode="wb", mtime=0) as gzipped: with open(tar_path, "rb") as tar_file: diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 167c772184..ceaee36435 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -75,6 +75,7 @@ def package( source: str = ".", output: str = "./flyte-package.tgz", fast: bool = False, + deref_symlinks: bool = False, ): """ Package the given entities and the source code (if fast is enabled) into a package with the given name in output @@ -82,6 +83,7 @@ def package( :param source: source folder :param output: output package name with suffix :param fast: fast enabled implies source code is bundled + :param deref_symlinks: if enabled then symlinks are dereferenced during packaging """ if not registrable_entities: raise NoSerializableEntitiesError("Nothing to package") @@ -95,7 +97,7 @@ def package( if os.path.abspath(output).startswith(os.path.abspath(source)) and os.path.exists(output): click.secho(f"{output} already exists within {source}, deleting and re-creating it", fg="yellow") os.remove(output) - archive_fname = fast_registration.fast_package(source, output_tmpdir) + archive_fname = fast_registration.fast_package(source, output_tmpdir, deref_symlinks) 
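+            # Illustrative note: when deref_symlinks is True, fast_package opens the tar with
+            # dereference=True (see fast_registration.py above), so symlinked files are archived
+            # as regular files containing the target's bytes instead of as symlink entries.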
click.secho(f"Fast mode enabled: compressed archive {archive_fname}", dim=True) with tarfile.open(output, "w:gz") as tar: @@ -110,13 +112,14 @@ def serialize_and_package( source: str = ".", output: str = "./flyte-package.tgz", fast: bool = False, + deref_symlinks: bool = False, options: typing.Optional[Options] = None, ): """ Fist serialize and then package all entities """ registrable_entities = serialize(pkgs, settings, source, options=options) - package(registrable_entities, source, output, fast) + package(registrable_entities, source, output, fast, deref_symlinks) def register( diff --git a/tests/flytekit/unit/tools/test_fast_registration.py b/tests/flytekit/unit/tools/test_fast_registration.py index 0b50d6fdcf..aae3995bcb 100644 --- a/tests/flytekit/unit/tools/test_fast_registration.py +++ b/tests/flytekit/unit/tools/test_fast_registration.py @@ -23,7 +23,10 @@ def flyte_project(tmp_path): "workflows": { "__pycache__": {"some.pyc": ""}, "hello_world.py": "print('Hello World!')", - } + }, + }, + "utils": { + "util.py": "print('Hello from utils!')", }, ".venv": {"lots": "", "of": "", "packages": ""}, ".env": "supersecret", @@ -35,6 +38,7 @@ def flyte_project(tmp_path): } make_tree(tmp_path, tree) + os.symlink(str(tmp_path) + "/utils/util.py", str(tmp_path) + "/src/util") subprocess.run(["git", "init", str(tmp_path)]) return tmp_path @@ -48,9 +52,29 @@ def test_package(flyte_project, tmp_path): ".gitignore", "keep.foo", "src", + "src/util", "src/workflows", "src/workflows/hello_world.py", + "utils", + "utils/util.py", + ] + util = tar.getmember("src/util") + assert util.issym() + assert str(os.path.basename(archive_fname)).startswith(FAST_PREFIX) + assert str(archive_fname).endswith(FAST_FILEENDING) + + +def test_package_with_symlink(flyte_project, tmp_path): + archive_fname = fast_package(source=flyte_project / "src", output_dir=tmp_path, deref_symlinks=True) + with tarfile.open(archive_fname, dereference=True) as tar: + assert tar.getnames() == [ + "", # tar root, output removes leading '/' + "util", + "workflows", + "workflows/hello_world.py", ] + util = tar.getmember("util") + assert util.isfile() assert str(os.path.basename(archive_fname)).startswith(FAST_PREFIX) assert str(archive_fname).endswith(FAST_FILEENDING) From 0d2f7607cdc97d23cc4efc9c61f5893716804293 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Wed, 14 Sep 2022 17:33:37 -0700 Subject: [PATCH 21/27] Strip newline from client secret (#1163) * Strip newline from client secret * Add logging and rework the secret file comparison to work on windows Signed-off-by: Eduardo Apolinario Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- flytekit/configuration/__init__.py | 9 ++++++++- .../configs/creds_secret_location.yaml | 2 +- .../unit/configuration/test_internal.py | 17 +++++++++++++++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index c0edf07ab2..047ce4b3d3 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -94,6 +94,7 @@ from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages from flytekit.configuration.file import ConfigEntry, ConfigFile, get_config_file, read_file_if_exists, set_if_exists +from flytekit.loggers import logger PROJECT_PLACEHOLDER = "{{ registration.project }}" DOMAIN_PLACEHOLDER = "{{ registration.domain }}" @@ -336,10 +337,16 @@ def 
auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file) ) + client_credentials_secret = read_file_if_exists( + _internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file) + ) + if client_credentials_secret and client_credentials_secret.endswith("\n"): + logger.info("Newline stripped from client secret") + client_credentials_secret = client_credentials_secret.strip() kwargs = set_if_exists( kwargs, "client_credentials_secret", - read_file_if_exists(_internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file)), + client_credentials_secret, ) kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file)) kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) diff --git a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml index 9c1ad83a3e..7da41b7c38 100644 --- a/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml +++ b/tests/flytekit/unit/configuration/configs/creds_secret_location.yaml @@ -1,7 +1,7 @@ admin: # For GRPC endpoints you might want to use dns:///flyte.myexample.com endpoint: dns:///flyte.mycorp.io - clientSecretLocation: ../tests/flytekit/unit/configuration/configs/fake_secret + clientSecretLocation: configs/fake_secret authType: Pkce insecure: true clientId: propeller diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 6ba81f309c..7f6be53a55 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -2,7 +2,7 @@ import mock -from flytekit.configuration import get_config_file, read_file_if_exists +from flytekit.configuration import PlatformConfig, get_config_file, read_file_if_exists from flytekit.configuration.internal import AWS, Credentials, Images @@ -31,7 +31,20 @@ def test_client_secret_location(): os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/creds_secret_location.yaml") ) secret_location = Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(cfg) - assert secret_location == "../tests/flytekit/unit/configuration/configs/fake_secret" + assert secret_location == "configs/fake_secret" + + # Modify the path to the secret inline + cfg._yaml_config["admin"]["clientSecretLocation"] = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs/fake_secret" + ) + + # Assert secret contains a newline + with open(cfg._yaml_config["admin"]["clientSecretLocation"], "rb") as f: + assert f.read().decode().endswith("\n") is True + + # Assert that secret in platform config does not contain a newline + platform_cfg = PlatformConfig.auto(cfg) + assert platform_cfg.client_credentials_secret == "hello" def test_read_file_if_exists(): From c31e60e997b4ba49bab4e397f523aadf4da032eb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 18 Aug 2022 04:20:18 +0800 Subject: [PATCH 22/27] Fix the type of optional[int] in dataclass (#1135) Signed-off-by: Kevin Su --- flytekit/core/type_engine.py | 7 ++++++- tests/flytekit/unit/core/test_type_engine.py | 6 ++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index aa7b261e64..686b03bb8c 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -491,7 +491,7 @@ def 
_deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if val is None: return val - if t == int: + if t == int or t == typing.Optional[int]: return int(val) if isinstance(val, list): @@ -1369,6 +1369,11 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac def _get_element_type(element_property: typing.Dict[str, str]) -> Type[T]: element_type = element_property["type"] element_format = element_property["format"] if "format" in element_property else None + + if type(element_type) == list: + # Element type of Optional[int] is [integer, None] + return typing.Optional[_get_element_type({"type": element_type[0]})] + if element_type == "string": return str elif element_type == "integer": diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index c70adb85d0..34251f7d0a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -157,6 +157,8 @@ class Bar(object): @dataclass_json @dataclass() class Foo(object): + u: typing.Optional[int] + v: typing.Optional[int] w: int x: typing.List[int] y: typing.Dict[str, str] @@ -173,6 +175,7 @@ class Foo(object): foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) + print("=====") pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo]) assert isinstance(guessed_pv, list) assert guessed_pv[0].u == pv[0].u @@ -184,9 +187,7 @@ class Foo(object): assert type(guessed_pv[0].u) == int assert guessed_pv[0].v is None assert type(guessed_pv[0].w) == int - assert type(guessed_pv[0].z.v) == int assert type(guessed_pv[0].z.x) == float - assert guessed_pv[0].z.v == pv[0].z.v assert guessed_pv[0].z.y == pv[0].z.y assert guessed_pv[0].z.z == pv[0].z.z assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0])) @@ -1221,6 +1222,7 @@ def test_pass_annotated_to_downstream_tasks(): """ Test to confirm that the loaded dataframe is not affected and can be used in @dynamic. 
""" + # pandas dataframe hash function def hash_pandas_dataframe(df: pd.DataFrame) -> str: return str(pd.util.hash_pandas_object(df)) From 4e0a73c7ae133d2f5789c1db55520811e2287b0f Mon Sep 17 00:00:00 2001 From: Snyk bot Date: Mon, 12 Sep 2022 18:52:23 +0100 Subject: [PATCH 23/27] fix: plugins/flytekit-papermill/dev-requirements.txt to reduce vulnerabilities (#1154) The following vulnerabilities are fixed by pinning transitive dependencies: - https://snyk.io/vuln/SNYK-PYTHON-OAUTHLIB-3021142 - https://snyk.io/vuln/SNYK-PYTHON-PYSPARK-3021131 Signed-off-by: Eduardo Apolinario --- plugins/flytekit-papermill/dev-requirements.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 378ba8e17c..933b2318d1 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -99,7 +99,7 @@ numpy==1.22.1 # via # pandas # pyarrow -oauthlib==3.2.0 +oauthlib==3.2.1 # via requests-oauthlib pandas==1.3.5 # via flytekit @@ -126,7 +126,9 @@ pyasn1==0.4.8 # rsa pyasn1-modules==0.2.8 # via google-auth -pyspark==3.2.1 +pycparser==2.21 + # via cffi +pyspark==3.3.0 # via flytekitplugins-spark python-dateutil==2.8.1 # via From ebdbdd9c58ce4cc0b78facf254ebfd6c5e6b85f4 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 31 Aug 2022 02:12:30 +0800 Subject: [PATCH 24/27] Using sidecar handler to run Papermill task (#1143) * remove nb prefix Signed-off-by: Kevin Su * add tests Signed-off-by: Kevin Su * Update requirements.in Signed-off-by: Kevin Su * remove _ Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- plugins/flytekit-papermill/dev-requirements.txt | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 933b2318d1..378ba8e17c 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -99,7 +99,7 @@ numpy==1.22.1 # via # pandas # pyarrow -oauthlib==3.2.1 +oauthlib==3.2.0 # via requests-oauthlib pandas==1.3.5 # via flytekit @@ -126,9 +126,7 @@ pyasn1==0.4.8 # rsa pyasn1-modules==0.2.8 # via google-auth -pycparser==2.21 - # via cffi -pyspark==3.3.0 +pyspark==3.2.1 # via flytekitplugins-spark python-dateutil==2.8.1 # via From 7e86e79ecb42ef0914f043b2df0a3f9ab067b720 Mon Sep 17 00:00:00 2001 From: Snyk bot Date: Tue, 30 Aug 2022 23:01:52 +0200 Subject: [PATCH 25/27] fix: plugins/flytekit-papermill/dev-requirements.txt to reduce vulnerabilities (#1145) The following vulnerabilities are fixed by pinning transitive dependencies: - https://snyk.io/vuln/SNYK-PYTHON-COOKIECUTTER-2414281 --- plugins/flytekit-papermill/dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 378ba8e17c..459044f746 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -24,7 +24,7 @@ click==7.1.2 # flytekit cloudpickle==2.0.0 # via flytekit -cookiecutter==1.7.3 +cookiecutter==2.1.1 # via flytekit croniter==1.2.0 # via flytekit From c6586fd23574612145701dfddf112e78be46cdc8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Aug 2022 14:04:05 -0700 Subject: [PATCH 26/27] Bump pyspark from 3.2.1 to 3.2.2 in /plugins/flytekit-papermill (#1130) 
Bumps [pyspark](https://github.com/apache/spark) from 3.2.1 to 3.2.2. - [Release notes](https://github.com/apache/spark/releases) - [Commits](https://github.com/apache/spark/compare/v3.2.1...v3.2.2) --- updated-dependencies: - dependency-name: pyspark dependency-type: indirect ... Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- plugins/flytekit-papermill/dev-requirements.txt | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 459044f746..01473c6a68 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -14,6 +14,8 @@ certifi==2021.10.8 # via # kubernetes # requests +cffi==1.15.1 + # via cryptography chardet==4.0.0 # via binaryornot charset-normalizer==2.0.10 @@ -28,6 +30,8 @@ cookiecutter==2.1.1 # via flytekit croniter==1.2.0 # via flytekit +cryptography==37.0.4 + # via secretstorage dataclasses-json==0.5.6 # via flytekit decorator==5.1.1 @@ -70,6 +74,10 @@ idna==3.3 # via requests importlib-metadata==4.10.1 # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage jinja2==3.0.3 # via # cookiecutter @@ -116,7 +124,7 @@ protoc-gen-swagger==0.1.0 # via flyteidl py==1.11.0 # via retry -py4j==0.10.9.3 +py4j==0.10.9.5 # via pyspark pyarrow==6.0.1 # via flytekit @@ -126,7 +134,9 @@ pyasn1==0.4.8 # rsa pyasn1-modules==0.2.8 # via google-auth -pyspark==3.2.1 +pycparser==2.21 + # via cffi +pyspark==3.2.2 # via flytekitplugins-spark python-dateutil==2.8.1 # via @@ -167,6 +177,8 @@ retry==0.9.2 # via flytekit rsa==4.9 # via google-auth +secretstorage==3.3.3 + # via keyring six==1.16.0 # via # cookiecutter From 5b3cca00c33281e6eb0dda65a1cc7049d11f05b4 Mon Sep 17 00:00:00 2001 From: Snyk bot Date: Mon, 12 Sep 2022 18:52:23 +0100 Subject: [PATCH 27/27] fix: plugins/flytekit-papermill/dev-requirements.txt to reduce vulnerabilities (#1154) The following vulnerabilities are fixed by pinning transitive dependencies: - https://snyk.io/vuln/SNYK-PYTHON-OAUTHLIB-3021142 - https://snyk.io/vuln/SNYK-PYTHON-PYSPARK-3021131 --- plugins/flytekit-papermill/dev-requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-papermill/dev-requirements.txt b/plugins/flytekit-papermill/dev-requirements.txt index 01473c6a68..c8294ca254 100644 --- a/plugins/flytekit-papermill/dev-requirements.txt +++ b/plugins/flytekit-papermill/dev-requirements.txt @@ -107,7 +107,7 @@ numpy==1.22.1 # via # pandas # pyarrow -oauthlib==3.2.0 +oauthlib==3.2.1 # via requests-oauthlib pandas==1.3.5 # via flytekit @@ -136,7 +136,7 @@ pyasn1-modules==0.2.8 # via google-auth pycparser==2.21 # via cffi -pyspark==3.2.2 +pyspark==3.3.0 # via flytekitplugins-spark python-dateutil==2.8.1 # via