From 10836be8d0e6ac4d69834141296bd3ab31884e8f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 11 Dec 2023 10:46:46 -0500 Subject: [PATCH 01/19] Enable flyte cli to be pluggable Signed-off-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/helpers.py | 24 ++------- flytekit/clis/sdk_in_container/plugin.py | 53 +++++++++++++++++++ flytekit/clis/sdk_in_container/pyflyte.py | 3 ++ flytekit/clis/sdk_in_container/run.py | 5 +- plugins/flytekit-papermill/tests/test_task.py | 2 +- .../flytekit-sqlalchemy/tests/test_task.py | 2 +- .../unit/cli/pyflyte/test_backfill.py | 2 +- .../unit/cli/pyflyte/test_launchplan.py | 2 +- .../flytekit/unit/cli/pyflyte/test_plugin.py | 28 ++++++++++ .../unit/cli/pyflyte/test_register.py | 17 +++--- tests/flytekit/unit/cli/pyflyte/test_run.py | 2 +- 11 files changed, 104 insertions(+), 36 deletions(-) create mode 100644 flytekit/clis/sdk_in_container/plugin.py create mode 100644 tests/flytekit/unit/cli/pyflyte/test_plugin.py diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index e8df92f7ca..72ec2a7f39 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -1,34 +1,16 @@ -import typing from dataclasses import replace from typing import Optional import rich_click as click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.configuration import Config, ImageConfig, get_config_file -from flytekit.loggers import cli_logger +from flytekit.clis.sdk_in_container.plugin import cli_plugin +from flytekit.configuration import ImageConfig from flytekit.remote.remote import FlyteRemote FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" -def get_remote( - cfg_file_path: typing.Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None -) -> FlyteRemote: - cfg_file = get_config_file(cfg_file_path) - if cfg_file is None: - cfg_obj = Config.for_sandbox() - cli_logger.info("No config files found, creating remote with sandbox config") - else: - cfg_obj = Config.auto(cfg_file_path) - cli_logger.info( - f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_path}" if cfg_file_path else "") - ) - return FlyteRemote( - cfg_obj, default_project=project, default_domain=domain, data_upload_location=data_upload_location - ) - - def get_and_save_remote_with_click_context( ctx: click.Context, project: str, @@ -50,7 +32,7 @@ def get_and_save_remote_with_click_context( if ctx.obj.get(FLYTE_REMOTE_INSTANCE_KEY) is not None: return ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - r = get_remote(cfg_file_location, project, domain, data_upload_location) + r = cli_plugin.get_remote(cfg_file_location, project, domain, data_upload_location) if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r return r diff --git a/flytekit/clis/sdk_in_container/plugin.py b/flytekit/clis/sdk_in_container/plugin.py new file mode 100644 index 0000000000..646c1791c0 --- /dev/null +++ b/flytekit/clis/sdk_in_container/plugin.py @@ -0,0 +1,53 @@ +from typing import Optional +import sys + + +from importlib_metadata import entry_points +from flytekit.configuration import Config, get_config_file +from flytekit.loggers import cli_logger +from flytekit.remote import FlyteRemote + + +class PyFlyteCLIPlugin: + @staticmethod + def get_remote( + config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None + ) -> FlyteRemote: + """Get FlyteRemote object for CLI session.""" + cfg_file = get_config_file(config) + if cfg_file is None: + cfg_obj = Config.for_sandbox() + cli_logger.info("No config files found, creating remote with sandbox config") + else: + cfg_obj = Config.auto(config) + cli_logger.info(f"Creating remote with config {cfg_obj}" + (f" with file {config}" if config else "")) + return FlyteRemote( + cfg_obj, default_project=project, default_domain=domain, data_upload_location=data_upload_location + ) + + @staticmethod + def configure_pyflyte_cli(main): + """Configure pyflyte's CLI.""" + return main + + +def get_cli_plugin(): + """Get plugin for entrypoint.""" + cli_plugins = list(entry_points(group="flytekit.cli.plugin")) + + if not cli_plugins: + return PyFlyteCLIPlugin + + if len(cli_plugins) >= 2: + plugin_names = [p.name for p in cli_plugins] + cli_logger.info(f"Multiple plugins seen for flytekit.cli.plugin: {plugin_names}") + + cli_plugin_to_load = cli_plugins[0] + cli_logger.info(f"Loading plugin: {cli_plugin_to_load.name}") + return cli_plugin_to_load.load() + + +if "pytest" in sys.modules: + cli_plugin = PyFlyteCLIPlugin +else: + cli_plugin = get_cli_plugin() diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 890198a1ef..2eaea8ea9f 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -14,6 +14,7 @@ from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.metrics import metrics from flytekit.clis.sdk_in_container.package import package +from flytekit.clis.sdk_in_container.plugin import cli_plugin from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize @@ -88,5 +89,7 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: bool): main.add_command(get) main.epilog +cli_plugin.configure_pyflyte_cli(main) + if __name__ == "__main__": main() diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 7c5e105409..325faaf0e6 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -14,7 +14,8 @@ from rich.progress import Progress from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal -from flytekit.clis.sdk_in_container.helpers import get_remote, patch_image_config +from flytekit.clis.sdk_in_container.plugin import cli_plugin +from flytekit.clis.sdk_in_container.helpers import patch_image_config from flytekit.clis.sdk_in_container.utils import ( PyFlyteParams, domain_option, @@ -261,7 +262,7 @@ def remote_instance(self) -> FlyteRemote: data_upload_location = None if self.is_remote: data_upload_location = remote_fs.REMOTE_PLACEHOLDER - self._remote = get_remote(self.config_file, self.project, self.domain, data_upload_location) + self._remote = cli_plugin.get_remote(self.config_file, self.project, self.domain, data_upload_location) return self._remote @property diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 3fe4e83c4b..47f8a63cf3 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -236,7 +236,7 @@ def wf(a: float) -> typing.List[float]: assert wf(a=3.14) == [9.8596, 9.8596] -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_notebook_task(mock_client, mock_remote): mock_remote._client = mock_client diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index c001f7b543..d450f76f3d 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -205,7 +205,7 @@ def test_task_serialization_deserialization_with_secret(sql_server): assert r.iat[0, 0] == 1 -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_sql_task(mock_client, mock_remote): mock_remote._client = mock_client diff --git a/tests/flytekit/unit/cli/pyflyte/test_backfill.py b/tests/flytekit/unit/cli/pyflyte/test_backfill.py index 0fd328e638..9918e6233a 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_backfill.py +++ b/tests/flytekit/unit/cli/pyflyte/test_backfill.py @@ -20,7 +20,7 @@ def test_resolve_backfill_window(): resolve_backfill_window() -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) def test_pyflyte_backfill(mock_remote): mock_remote.generate_console_url.return_value = "ex" runner = CliRunner() diff --git a/tests/flytekit/unit/cli/pyflyte/test_launchplan.py b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py index 1a461bfd35..4f707f640e 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_launchplan.py +++ b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py @@ -6,7 +6,7 @@ from flytekit.remote import FlyteRemote -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) @pytest.mark.parametrize( ("action", "expected_state"), [ diff --git a/tests/flytekit/unit/cli/pyflyte/test_plugin.py b/tests/flytekit/unit/cli/pyflyte/test_plugin.py new file mode 100644 index 0000000000..ccac1e99dd --- /dev/null +++ b/tests/flytekit/unit/cli/pyflyte/test_plugin.py @@ -0,0 +1,28 @@ +from unittest.mock import patch, Mock + +from flytekit.clis.sdk_in_container.plugin import get_cli_plugin, PyFlyteCLIPlugin + + +@patch("flytekit.clis.sdk_in_container.plugin.entry_points") +def test_get_plugin_default(entry_points): + entry_points.side_effect = lambda *args, **kwargs: [] + + default_plugin = get_cli_plugin() + assert default_plugin is PyFlyteCLIPlugin + + +@patch("flytekit.clis.sdk_in_container.plugin.entry_points") +def test_get_plugin_load_other_plugin(entry_points, caplog): + loaded_plugin_1 = Mock() + entry_1 = Mock() + entry_1.name = "entry_1" + entry_1.load.side_effect = lambda: loaded_plugin_1 + + entry_2 = Mock() + entry_points.side_effect = lambda *args, **kwargs: [entry_1, entry_2] + + plugin = get_cli_plugin() + assert plugin is loaded_plugin_1 + + assert entry_1.load.call_count == 1 + assert entry_2.load.call_count == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index d9c9565ac4..b42325fc37 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -8,7 +8,8 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context, get_remote +from flytekit.clis.sdk_in_container.plugin import PyFlyteCLIPlugin +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context from flytekit.configuration import Config from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR from flytekit.core import context_manager @@ -47,16 +48,16 @@ def reset_flytectl_config_env_var() -> pytest.fixture(): return os.environ[FLYTECTL_CONFIG_ENV_VAR] -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote") +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote") def test_get_remote(mock_remote, reset_flytectl_config_env_var): - r = get_remote(None, "p", "d") + r = PyFlyteCLIPlugin.get_remote(None, "p", "d") assert r is not None mock_remote.assert_called_once_with( Config.for_sandbox(), default_project="p", default_domain="d", data_upload_location=None ) -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote") +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote") def test_saving_remote(mock_remote): mock_context = mock.MagicMock mock_context.obj = {} @@ -78,7 +79,7 @@ def test_register_with_no_package_or_module_argument(): ) -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_with_no_output_dir_passed(mock_client, mock_remote): ctx = FlyteContextManager.current_context() @@ -100,7 +101,7 @@ def test_register_with_no_output_dir_passed(mock_client, mock_remote): shutil.rmtree("core1") -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_shell_task(mock_client, mock_remote): mock_remote._client = mock_client @@ -120,7 +121,7 @@ def test_register_shell_task(mock_client, mock_remote): shutil.rmtree("core2") -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_non_fast_register(mock_client, mock_remote): ctx = FlyteContextManager.current_context() @@ -140,7 +141,7 @@ def test_non_fast_register(mock_client, mock_remote): shutil.rmtree("core2") -@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_non_fast_register_require_version(mock_client, mock_remote): mock_remote._client = mock_client diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 8090b6a95f..65ba549eb2 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -38,7 +38,7 @@ def remote(): ], ) def test_pyflyte_run_wf(remote, remote_flag): - with mock.patch("flytekit.clis.sdk_in_container.helpers.get_remote"): + with mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote"): runner = CliRunner() module_path = WORKFLOW_FILE result = runner.invoke( From 939b5178eb050ea052b2c14aa1f60f190bed356a Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 11 Dec 2023 13:04:52 -0500 Subject: [PATCH 02/19] TST Adds unit test for running configure_pyflyte_cli Signed-off-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/plugin.py | 5 ++- .../flytekit/unit/cli/pyflyte/test_plugin.py | 40 +++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/flytekit/clis/sdk_in_container/plugin.py b/flytekit/clis/sdk_in_container/plugin.py index 646c1791c0..03be0a5540 100644 --- a/flytekit/clis/sdk_in_container/plugin.py +++ b/flytekit/clis/sdk_in_container/plugin.py @@ -1,7 +1,7 @@ from typing import Optional import sys - +from click import Command from importlib_metadata import entry_points from flytekit.configuration import Config, get_config_file from flytekit.loggers import cli_logger @@ -26,7 +26,7 @@ def get_remote( ) @staticmethod - def configure_pyflyte_cli(main): + def configure_pyflyte_cli(main: Command): """Configure pyflyte's CLI.""" return main @@ -47,6 +47,7 @@ def get_cli_plugin(): return cli_plugin_to_load.load() +# Ensure that cli_plugin is always configured to PyFlyteCLIPlugin during pytest runs if "pytest" in sys.modules: cli_plugin = PyFlyteCLIPlugin else: diff --git a/tests/flytekit/unit/cli/pyflyte/test_plugin.py b/tests/flytekit/unit/cli/pyflyte/test_plugin.py index ccac1e99dd..1c526e03de 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_plugin.py +++ b/tests/flytekit/unit/cli/pyflyte/test_plugin.py @@ -1,4 +1,5 @@ from unittest.mock import patch, Mock +import click from flytekit.clis.sdk_in_container.plugin import get_cli_plugin, PyFlyteCLIPlugin @@ -26,3 +27,42 @@ def test_get_plugin_load_other_plugin(entry_points, caplog): assert entry_1.load.call_count == 1 assert entry_2.load.call_count == 0 + + +class CustomPlugin(PyFlyteCLIPlugin): + @staticmethod + def configure_pyflyte_cli(main): + """Make config hidden in main CLI.""" + for p in main.params: + if p.name == "config": + p.hidden = True + return main + + +@click.command +@click.option( + "-c", + "--config", + required=False, + type=str, + help="Path to config file for use within container", +) +def click_main(config): + pass + + +@patch("flytekit.clis.sdk_in_container.plugin.entry_points") +def test_get_plugin_custom(entry_points): + entry_1 = Mock() + entry_1.name.side_effect = "custom_plugin" + entry_1.load.side_effect = lambda: CustomPlugin + + entry_points.side_effect = lambda *args, **kwargs: [entry_1] + + plugin = get_cli_plugin() + assert plugin is CustomPlugin + + assert not click_main.params[0].hidden + + plugin.configure_pyflyte_cli(click_main) + assert click_main.params[0].hidden From 01bf753ca5c8c5e05d52f0d0ed45debd590547fe Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 11 Dec 2023 13:06:44 -0500 Subject: [PATCH 03/19] Adds type hints Signed-off-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/clis/sdk_in_container/plugin.py b/flytekit/clis/sdk_in_container/plugin.py index 03be0a5540..03518fb761 100644 --- a/flytekit/clis/sdk_in_container/plugin.py +++ b/flytekit/clis/sdk_in_container/plugin.py @@ -26,7 +26,7 @@ def get_remote( ) @staticmethod - def configure_pyflyte_cli(main: Command): + def configure_pyflyte_cli(main: Command) -> Command: """Configure pyflyte's CLI.""" return main From df0e4552053dc45d39d12b824590496ef072b71f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 13 Dec 2023 15:03:39 -0500 Subject: [PATCH 04/19] ENH Rename to flytekit plugin Signed-off-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/helpers.py | 4 +-- flytekit/clis/sdk_in_container/pyflyte.py | 4 +-- flytekit/clis/sdk_in_container/run.py | 4 +-- .../plugin.py | 35 ++++++++++--------- .../flytekit/unit/cli/pyflyte/test_plugin.py | 15 ++++---- .../unit/cli/pyflyte/test_register.py | 4 +-- 6 files changed, 35 insertions(+), 31 deletions(-) rename flytekit/{clis/sdk_in_container => configuration}/plugin.py (62%) diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 72ec2a7f39..ce1ba0cd37 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -4,7 +4,7 @@ import rich_click as click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.clis.sdk_in_container.plugin import cli_plugin +from flytekit.clis.sdk_in_container.plugin import plugin from flytekit.configuration import ImageConfig from flytekit.remote.remote import FlyteRemote @@ -32,7 +32,7 @@ def get_and_save_remote_with_click_context( if ctx.obj.get(FLYTE_REMOTE_INSTANCE_KEY) is not None: return ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - r = cli_plugin.get_remote(cfg_file_location, project, domain, data_upload_location) + r = plugin.get_remote(cfg_file_location, project, domain, data_upload_location) if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r return r diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 2eaea8ea9f..d989448695 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -14,7 +14,7 @@ from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.metrics import metrics from flytekit.clis.sdk_in_container.package import package -from flytekit.clis.sdk_in_container.plugin import cli_plugin +from flytekit.clis.sdk_in_container.plugin import plugin from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize @@ -89,7 +89,7 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: bool): main.add_command(get) main.epilog -cli_plugin.configure_pyflyte_cli(main) +plugin.configure_pyflyte_cli(main) if __name__ == "__main__": main() diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 325faaf0e6..2d0a33b3f9 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -14,8 +14,8 @@ from rich.progress import Progress from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal -from flytekit.clis.sdk_in_container.plugin import cli_plugin from flytekit.clis.sdk_in_container.helpers import patch_image_config +from flytekit.clis.sdk_in_container.plugin import plugin from flytekit.clis.sdk_in_container.utils import ( PyFlyteParams, domain_option, @@ -262,7 +262,7 @@ def remote_instance(self) -> FlyteRemote: data_upload_location = None if self.is_remote: data_upload_location = remote_fs.REMOTE_PLACEHOLDER - self._remote = cli_plugin.get_remote(self.config_file, self.project, self.domain, data_upload_location) + self._remote = plugin.get_remote(self.config_file, self.project, self.domain, data_upload_location) return self._remote @property diff --git a/flytekit/clis/sdk_in_container/plugin.py b/flytekit/configuration/plugin.py similarity index 62% rename from flytekit/clis/sdk_in_container/plugin.py rename to flytekit/configuration/plugin.py index 03518fb761..d5dbabdaee 100644 --- a/flytekit/clis/sdk_in_container/plugin.py +++ b/flytekit/configuration/plugin.py @@ -1,14 +1,16 @@ -from typing import Optional +import os import sys +from typing import Optional from click import Command from importlib_metadata import entry_points + from flytekit.configuration import Config, get_config_file from flytekit.loggers import cli_logger from flytekit.remote import FlyteRemote -class PyFlyteCLIPlugin: +class FlytekitPlugin: @staticmethod def get_remote( config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None @@ -31,24 +33,25 @@ def configure_pyflyte_cli(main: Command) -> Command: return main -def get_cli_plugin(): +def get_plugin(): """Get plugin for entrypoint.""" - cli_plugins = list(entry_points(group="flytekit.cli.plugin")) + plugins = list(entry_points(group="flytekit.plugin")) - if not cli_plugins: - return PyFlyteCLIPlugin + if not plugins: + return FlytekitPlugin - if len(cli_plugins) >= 2: - plugin_names = [p.name for p in cli_plugins] - cli_logger.info(f"Multiple plugins seen for flytekit.cli.plugin: {plugin_names}") + if len(plugins) >= 2: + plugin_names = [p.name for p in plugins] + cli_logger.info(f"Multiple plugins seen for flytekit.plugin: {plugin_names}") - cli_plugin_to_load = cli_plugins[0] - cli_logger.info(f"Loading plugin: {cli_plugin_to_load.name}") - return cli_plugin_to_load.load() + plugin_to_load = plugins[0] + cli_logger.info(f"Loading plugin: {plugin_to_load.name}") + return plugin_to_load.load() -# Ensure that cli_plugin is always configured to PyFlyteCLIPlugin during pytest runs -if "pytest" in sys.modules: - cli_plugin = PyFlyteCLIPlugin +# Ensure that plugin is always configured to FlytekitPlugin during pytest runs +# Set USE_FLYTEKIT_PLUGIN=0 for testing other plugins +if "pytest" in sys.modules and os.environ.get("USE_FLYTEKIT_PLUGIN", "1") == "1": + plugin = FlytekitPlugin else: - cli_plugin = get_cli_plugin() + plugin = get_plugin() diff --git a/tests/flytekit/unit/cli/pyflyte/test_plugin.py b/tests/flytekit/unit/cli/pyflyte/test_plugin.py index 1c526e03de..2c84fe8442 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_plugin.py +++ b/tests/flytekit/unit/cli/pyflyte/test_plugin.py @@ -1,15 +1,16 @@ -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch + import click -from flytekit.clis.sdk_in_container.plugin import get_cli_plugin, PyFlyteCLIPlugin +from flytekit.configuration.plugin import FlytekitPlugin, get_plugin @patch("flytekit.clis.sdk_in_container.plugin.entry_points") def test_get_plugin_default(entry_points): entry_points.side_effect = lambda *args, **kwargs: [] - default_plugin = get_cli_plugin() - assert default_plugin is PyFlyteCLIPlugin + default_plugin = get_plugin() + assert default_plugin is FlytekitPlugin @patch("flytekit.clis.sdk_in_container.plugin.entry_points") @@ -22,14 +23,14 @@ def test_get_plugin_load_other_plugin(entry_points, caplog): entry_2 = Mock() entry_points.side_effect = lambda *args, **kwargs: [entry_1, entry_2] - plugin = get_cli_plugin() + plugin = get_plugin() assert plugin is loaded_plugin_1 assert entry_1.load.call_count == 1 assert entry_2.load.call_count == 0 -class CustomPlugin(PyFlyteCLIPlugin): +class CustomPlugin(FlytekitPlugin): @staticmethod def configure_pyflyte_cli(main): """Make config hidden in main CLI.""" @@ -59,7 +60,7 @@ def test_get_plugin_custom(entry_points): entry_points.side_effect = lambda *args, **kwargs: [entry_1] - plugin = get_cli_plugin() + plugin = get_plugin() assert plugin is CustomPlugin assert not click_main.params[0].hidden diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index b42325fc37..f6668902fb 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -8,8 +8,8 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.plugin import PyFlyteCLIPlugin from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.clis.sdk_in_container.plugin import FlytekitPlugin from flytekit.configuration import Config from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR from flytekit.core import context_manager @@ -50,7 +50,7 @@ def reset_flytectl_config_env_var() -> pytest.fixture(): @mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote") def test_get_remote(mock_remote, reset_flytectl_config_env_var): - r = PyFlyteCLIPlugin.get_remote(None, "p", "d") + r = FlytekitPlugin.get_remote(None, "p", "d") assert r is not None mock_remote.assert_called_once_with( Config.for_sandbox(), default_project="p", default_domain="d", data_upload_location=None From 08ab1228033a51c0edfa5702c07fe3f1a74e8e1e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 13 Dec 2023 15:15:06 -0500 Subject: [PATCH 05/19] Migrate to plugin to configuration Signed-off-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/helpers.py | 2 +- flytekit/clis/sdk_in_container/pyflyte.py | 2 +- flytekit/clis/sdk_in_container/run.py | 2 +- plugins/flytekit-papermill/tests/test_task.py | 2 +- plugins/flytekit-sqlalchemy/tests/test_task.py | 2 +- tests/flytekit/unit/cli/pyflyte/test_backfill.py | 2 +- tests/flytekit/unit/cli/pyflyte/test_launchplan.py | 2 +- tests/flytekit/unit/cli/pyflyte/test_plugin.py | 6 +++--- tests/flytekit/unit/cli/pyflyte/test_register.py | 14 +++++++------- tests/flytekit/unit/cli/pyflyte/test_run.py | 2 +- 10 files changed, 18 insertions(+), 18 deletions(-) diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index ce1ba0cd37..1f81e00521 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -4,8 +4,8 @@ import rich_click as click from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.clis.sdk_in_container.plugin import plugin from flytekit.configuration import ImageConfig +from flytekit.configuration.plugin import plugin from flytekit.remote.remote import FlyteRemote FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index d989448695..1a93bc820c 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -14,7 +14,6 @@ from flytekit.clis.sdk_in_container.local_cache import local_cache from flytekit.clis.sdk_in_container.metrics import metrics from flytekit.clis.sdk_in_container.package import package -from flytekit.clis.sdk_in_container.plugin import plugin from flytekit.clis.sdk_in_container.register import register from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize @@ -23,6 +22,7 @@ from flytekit.clis.version import info from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR, FLYTECTL_CONFIG_ENV_VAR_OVERRIDE from flytekit.configuration.internal import LocalSDK +from flytekit.configuration.plugin import plugin from flytekit.loggers import cli_logger diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 2d0a33b3f9..c64d329ba4 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -15,7 +15,6 @@ from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, Literal from flytekit.clis.sdk_in_container.helpers import patch_image_config -from flytekit.clis.sdk_in_container.plugin import plugin from flytekit.clis.sdk_in_container.utils import ( PyFlyteParams, domain_option, @@ -25,6 +24,7 @@ project_option, ) from flytekit.configuration import DefaultImages, FastSerializationSettings, ImageConfig, SerializationSettings +from flytekit.configuration.plugin import plugin from flytekit.core import context_manager from flytekit.core.base_task import PythonTask from flytekit.core.data_persistence import FileAccessProvider diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 47f8a63cf3..8c229f71f9 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -236,7 +236,7 @@ def wf(a: float) -> typing.List[float]: assert wf(a=3.14) == [9.8596, 9.8596] -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_notebook_task(mock_client, mock_remote): mock_remote._client = mock_client diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index d450f76f3d..4956cbc88d 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -205,7 +205,7 @@ def test_task_serialization_deserialization_with_secret(sql_server): assert r.iat[0, 0] == 1 -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_sql_task(mock_client, mock_remote): mock_remote._client = mock_client diff --git a/tests/flytekit/unit/cli/pyflyte/test_backfill.py b/tests/flytekit/unit/cli/pyflyte/test_backfill.py index 9918e6233a..ee61b25dcf 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_backfill.py +++ b/tests/flytekit/unit/cli/pyflyte/test_backfill.py @@ -20,7 +20,7 @@ def test_resolve_backfill_window(): resolve_backfill_window() -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) def test_pyflyte_backfill(mock_remote): mock_remote.generate_console_url.return_value = "ex" runner = CliRunner() diff --git a/tests/flytekit/unit/cli/pyflyte/test_launchplan.py b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py index 4f707f640e..7b0571b71f 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_launchplan.py +++ b/tests/flytekit/unit/cli/pyflyte/test_launchplan.py @@ -6,7 +6,7 @@ from flytekit.remote import FlyteRemote -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) @pytest.mark.parametrize( ("action", "expected_state"), [ diff --git a/tests/flytekit/unit/cli/pyflyte/test_plugin.py b/tests/flytekit/unit/cli/pyflyte/test_plugin.py index 2c84fe8442..997798081a 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_plugin.py +++ b/tests/flytekit/unit/cli/pyflyte/test_plugin.py @@ -5,7 +5,7 @@ from flytekit.configuration.plugin import FlytekitPlugin, get_plugin -@patch("flytekit.clis.sdk_in_container.plugin.entry_points") +@patch("flytekit.configuration.plugin.entry_points") def test_get_plugin_default(entry_points): entry_points.side_effect = lambda *args, **kwargs: [] @@ -13,7 +13,7 @@ def test_get_plugin_default(entry_points): assert default_plugin is FlytekitPlugin -@patch("flytekit.clis.sdk_in_container.plugin.entry_points") +@patch("flytekit.configuration.plugin.entry_points") def test_get_plugin_load_other_plugin(entry_points, caplog): loaded_plugin_1 = Mock() entry_1 = Mock() @@ -52,7 +52,7 @@ def click_main(config): pass -@patch("flytekit.clis.sdk_in_container.plugin.entry_points") +@patch("flytekit.configuration.plugin.entry_points") def test_get_plugin_custom(entry_points): entry_1 = Mock() entry_1.name.side_effect = "custom_plugin" diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index f6668902fb..5bd702661f 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -9,9 +9,9 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context -from flytekit.clis.sdk_in_container.plugin import FlytekitPlugin from flytekit.configuration import Config from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR +from flytekit.configuration.plugin import FlytekitPlugin from flytekit.core import context_manager from flytekit.core.context_manager import FlyteContextManager from flytekit.remote.remote import FlyteRemote @@ -48,7 +48,7 @@ def reset_flytectl_config_env_var() -> pytest.fixture(): return os.environ[FLYTECTL_CONFIG_ENV_VAR] -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote") +@mock.patch("flytekit.configuration.plugin.FlyteRemote") def test_get_remote(mock_remote, reset_flytectl_config_env_var): r = FlytekitPlugin.get_remote(None, "p", "d") assert r is not None @@ -57,7 +57,7 @@ def test_get_remote(mock_remote, reset_flytectl_config_env_var): ) -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote") +@mock.patch("flytekit.configuration.plugin.FlyteRemote") def test_saving_remote(mock_remote): mock_context = mock.MagicMock mock_context.obj = {} @@ -79,7 +79,7 @@ def test_register_with_no_package_or_module_argument(): ) -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_with_no_output_dir_passed(mock_client, mock_remote): ctx = FlyteContextManager.current_context() @@ -101,7 +101,7 @@ def test_register_with_no_output_dir_passed(mock_client, mock_remote): shutil.rmtree("core1") -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_register_shell_task(mock_client, mock_remote): mock_remote._client = mock_client @@ -121,7 +121,7 @@ def test_register_shell_task(mock_client, mock_remote): shutil.rmtree("core2") -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_non_fast_register(mock_client, mock_remote): ctx = FlyteContextManager.current_context() @@ -141,7 +141,7 @@ def test_non_fast_register(mock_client, mock_remote): shutil.rmtree("core2") -@mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_non_fast_register_require_version(mock_client, mock_remote): mock_remote._client = mock_client diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index 65ba549eb2..345425f0b0 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -38,7 +38,7 @@ def remote(): ], ) def test_pyflyte_run_wf(remote, remote_flag): - with mock.patch("flytekit.clis.sdk_in_container.plugin.FlyteRemote"): + with mock.patch("flytekit.configuration.plugin.FlyteRemote"): runner = CliRunner() module_path = WORKFLOW_FILE result = runner.invoke( From 1aa63356c2de9080d6efbdcc3a61a970f735df81 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 13 Dec 2023 15:15:52 -0500 Subject: [PATCH 06/19] Use USE_DEFAULT_FLYTEKIT_PLUGIN Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index d5dbabdaee..9dc98c0cff 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -50,8 +50,8 @@ def get_plugin(): # Ensure that plugin is always configured to FlytekitPlugin during pytest runs -# Set USE_FLYTEKIT_PLUGIN=0 for testing other plugins -if "pytest" in sys.modules and os.environ.get("USE_FLYTEKIT_PLUGIN", "1") == "1": +# Set USE_DEFAULT_FLYTEKIT_PLUGIN=0 for testing other plugins +if "pytest" in sys.modules and os.environ.get("USE_DEFAULT_FLYTEKIT_PLUGIN", "1") == "1": plugin = FlytekitPlugin else: plugin = get_plugin() From c56079be0e4b86c83bd586e5de6c55b63caf548e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 13 Dec 2023 15:44:43 -0500 Subject: [PATCH 07/19] pragma for coverage Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 9dc98c0cff..1cc0977ed4 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -20,7 +20,7 @@ def get_remote( if cfg_file is None: cfg_obj = Config.for_sandbox() cli_logger.info("No config files found, creating remote with sandbox config") - else: + else: # pragma: no cover cfg_obj = Config.auto(config) cli_logger.info(f"Creating remote with config {cfg_obj}" + (f" with file {config}" if config else "")) return FlyteRemote( @@ -53,5 +53,5 @@ def get_plugin(): # Set USE_DEFAULT_FLYTEKIT_PLUGIN=0 for testing other plugins if "pytest" in sys.modules and os.environ.get("USE_DEFAULT_FLYTEKIT_PLUGIN", "1") == "1": plugin = FlytekitPlugin -else: +else: # pragma: no cover plugin = get_plugin() From bffdd495d7fb4dbf7a389c0adc1b81a50c39d50b Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 13 Dec 2023 15:50:09 -0500 Subject: [PATCH 08/19] Rename to USE_DEFAULT_FLYTEKIT_PLUGIN_FOR_TESTING Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 1cc0977ed4..f3216bc6cc 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -50,8 +50,8 @@ def get_plugin(): # Ensure that plugin is always configured to FlytekitPlugin during pytest runs -# Set USE_DEFAULT_FLYTEKIT_PLUGIN=0 for testing other plugins -if "pytest" in sys.modules and os.environ.get("USE_DEFAULT_FLYTEKIT_PLUGIN", "1") == "1": +# Set USE_DEFAULT_FLYTEKIT_PLUGIN_FOR_TESTING=0 for testing other plugins +if "pytest" in sys.modules and os.environ.get("USE_DEFAULT_FLYTEKIT_PLUGIN_FOR_TESTING", "1") == "1": plugin = FlytekitPlugin else: # pragma: no cover plugin = get_plugin() From 9b90029b71232340e3378407abffb780596df9ca Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 13 Dec 2023 22:51:53 -0500 Subject: [PATCH 09/19] Patch plugin when running tests Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 9 +-------- tests/flytekit/conftest.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 8 deletions(-) create mode 100644 tests/flytekit/conftest.py diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index f3216bc6cc..eb41768de6 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -1,5 +1,3 @@ -import os -import sys from typing import Optional from click import Command @@ -49,9 +47,4 @@ def get_plugin(): return plugin_to_load.load() -# Ensure that plugin is always configured to FlytekitPlugin during pytest runs -# Set USE_DEFAULT_FLYTEKIT_PLUGIN_FOR_TESTING=0 for testing other plugins -if "pytest" in sys.modules and os.environ.get("USE_DEFAULT_FLYTEKIT_PLUGIN_FOR_TESTING", "1") == "1": - plugin = FlytekitPlugin -else: # pragma: no cover - plugin = get_plugin() +plugin = get_plugin() diff --git a/tests/flytekit/conftest.py b/tests/flytekit/conftest.py new file mode 100644 index 0000000000..6b4b6c0541 --- /dev/null +++ b/tests/flytekit/conftest.py @@ -0,0 +1,14 @@ +import pytest + +import flytekit.clis.sdk_in_container.helpers +import flytekit.clis.sdk_in_container.pyflyte +from flytekit.configuration.plugin import FlytekitPlugin + + +@pytest.fixture(autouse=True, scope="session") +def configure_plugin(): + """If a plugin is installed then the plugin variable points to a external plugin. + For testing, we want to test against flytekit's own plugin, so we override the plugins.""" + flytekit.configuration.plugin.plugin = FlytekitPlugin + flytekit.clis.sdk_in_container.pyflyte.plugin = FlytekitPlugin + flytekit.clis.sdk_in_container.helpers.plugin = FlytekitPlugin From d0dae6a34ff3a935808653a1ef94c9035f2dc189 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 14 Dec 2023 08:18:09 -0500 Subject: [PATCH 10/19] Better solution for managing the plugin state Signed-off-by: Thomas J. Fan --- flytekit/clis/sdk_in_container/helpers.py | 4 ++-- flytekit/clis/sdk_in_container/pyflyte.py | 4 ++-- flytekit/clis/sdk_in_container/run.py | 4 ++-- flytekit/configuration/plugin.py | 9 +++++++-- tests/flytekit/conftest.py | 11 ++++------- tests/flytekit/unit/cli/pyflyte/test_plugin.py | 8 ++++---- 6 files changed, 21 insertions(+), 19 deletions(-) diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 1f81e00521..5ec4b9b262 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -5,7 +5,7 @@ from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE from flytekit.configuration import ImageConfig -from flytekit.configuration.plugin import plugin +from flytekit.configuration.plugin import get_plugin from flytekit.remote.remote import FlyteRemote FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" @@ -32,7 +32,7 @@ def get_and_save_remote_with_click_context( if ctx.obj.get(FLYTE_REMOTE_INSTANCE_KEY) is not None: return ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - r = plugin.get_remote(cfg_file_location, project, domain, data_upload_location) + r = get_plugin().get_remote(cfg_file_location, project, domain, data_upload_location) if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r return r diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 1a93bc820c..3171c2bcfc 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -22,7 +22,7 @@ from flytekit.clis.version import info from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR, FLYTECTL_CONFIG_ENV_VAR_OVERRIDE from flytekit.configuration.internal import LocalSDK -from flytekit.configuration.plugin import plugin +from flytekit.configuration.plugin import get_plugin from flytekit.loggers import cli_logger @@ -89,7 +89,7 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: bool): main.add_command(get) main.epilog -plugin.configure_pyflyte_cli(main) +get_plugin().configure_pyflyte_cli(main) if __name__ == "__main__": main() diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index c64d329ba4..2bf8708268 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -24,7 +24,7 @@ project_option, ) from flytekit.configuration import DefaultImages, FastSerializationSettings, ImageConfig, SerializationSettings -from flytekit.configuration.plugin import plugin +from flytekit.configuration.plugin import get_plugin from flytekit.core import context_manager from flytekit.core.base_task import PythonTask from flytekit.core.data_persistence import FileAccessProvider @@ -262,7 +262,7 @@ def remote_instance(self) -> FlyteRemote: data_upload_location = None if self.is_remote: data_upload_location = remote_fs.REMOTE_PLACEHOLDER - self._remote = plugin.get_remote(self.config_file, self.project, self.domain, data_upload_location) + self._remote = get_plugin().get_remote(self.config_file, self.project, self.domain, data_upload_location) return self._remote @property diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index eb41768de6..ce29fb51fe 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -31,7 +31,7 @@ def configure_pyflyte_cli(main: Command) -> Command: return main -def get_plugin(): +def _get_plugin(): """Get plugin for entrypoint.""" plugins = list(entry_points(group="flytekit.plugin")) @@ -47,4 +47,9 @@ def get_plugin(): return plugin_to_load.load() -plugin = get_plugin() +_GLOBAL_PLUGIN_STATE = {"plugin": _get_plugin()} + + +def get_plugin(): + """Get current plugin""" + return _GLOBAL_PLUGIN_STATE["plugin"] diff --git a/tests/flytekit/conftest.py b/tests/flytekit/conftest.py index 6b4b6c0541..72cb39074a 100644 --- a/tests/flytekit/conftest.py +++ b/tests/flytekit/conftest.py @@ -1,14 +1,11 @@ import pytest -import flytekit.clis.sdk_in_container.helpers -import flytekit.clis.sdk_in_container.pyflyte +import flytekit.configuration.plugin from flytekit.configuration.plugin import FlytekitPlugin @pytest.fixture(autouse=True, scope="session") def configure_plugin(): - """If a plugin is installed then the plugin variable points to a external plugin. - For testing, we want to test against flytekit's own plugin, so we override the plugins.""" - flytekit.configuration.plugin.plugin = FlytekitPlugin - flytekit.clis.sdk_in_container.pyflyte.plugin = FlytekitPlugin - flytekit.clis.sdk_in_container.helpers.plugin = FlytekitPlugin + """If a plugin is installed then the global plugin refers to an external plugin. + For testing, we want to test against flytekit's own plugin, so we override the state.""" + flytekit.configuration.plugin._GLOBAL_PLUGIN_STATE["plugin"] = FlytekitPlugin diff --git a/tests/flytekit/unit/cli/pyflyte/test_plugin.py b/tests/flytekit/unit/cli/pyflyte/test_plugin.py index 997798081a..3bd80ad9cc 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_plugin.py +++ b/tests/flytekit/unit/cli/pyflyte/test_plugin.py @@ -2,14 +2,14 @@ import click -from flytekit.configuration.plugin import FlytekitPlugin, get_plugin +from flytekit.configuration.plugin import FlytekitPlugin, _get_plugin @patch("flytekit.configuration.plugin.entry_points") def test_get_plugin_default(entry_points): entry_points.side_effect = lambda *args, **kwargs: [] - default_plugin = get_plugin() + default_plugin = _get_plugin() assert default_plugin is FlytekitPlugin @@ -23,7 +23,7 @@ def test_get_plugin_load_other_plugin(entry_points, caplog): entry_2 = Mock() entry_points.side_effect = lambda *args, **kwargs: [entry_1, entry_2] - plugin = get_plugin() + plugin = _get_plugin() assert plugin is loaded_plugin_1 assert entry_1.load.call_count == 1 @@ -60,7 +60,7 @@ def test_get_plugin_custom(entry_points): entry_points.side_effect = lambda *args, **kwargs: [entry_1] - plugin = get_plugin() + plugin = _get_plugin() assert plugin is CustomPlugin assert not click_main.params[0].hidden From 9cf1c62a637519db03793bad0497fc3e23d5548a Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 14 Dec 2023 08:48:23 -0500 Subject: [PATCH 11/19] Use a better name Signed-off-by: Thomas J. Fan --- tests/flytekit/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/conftest.py b/tests/flytekit/conftest.py index 72cb39074a..18fa6fc2f0 100644 --- a/tests/flytekit/conftest.py +++ b/tests/flytekit/conftest.py @@ -8,4 +8,4 @@ def configure_plugin(): """If a plugin is installed then the global plugin refers to an external plugin. For testing, we want to test against flytekit's own plugin, so we override the state.""" - flytekit.configuration.plugin._GLOBAL_PLUGIN_STATE["plugin"] = FlytekitPlugin + flytekit.configuration.plugin._GLOBAL_CONFIG["plugin"] = FlytekitPlugin From f4a75b4be46954a00b793a68c7af879f0cfd2e3d Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 14 Dec 2023 08:48:41 -0500 Subject: [PATCH 12/19] Use a better name Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index ce29fb51fe..f80b8ce28e 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -47,9 +47,9 @@ def _get_plugin(): return plugin_to_load.load() -_GLOBAL_PLUGIN_STATE = {"plugin": _get_plugin()} +_GLOBAL_CONFIG = {"plugin": _get_plugin()} def get_plugin(): """Get current plugin""" - return _GLOBAL_PLUGIN_STATE["plugin"] + return _GLOBAL_CONFIG["plugin"] From cb90c9ac3e84688b3583b6e609758e4beed78c20 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 14 Dec 2023 11:19:21 -0500 Subject: [PATCH 13/19] Adds more docs Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index f80b8ce28e..19818e4e6d 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -1,3 +1,22 @@ +"""Defines a plugin API allowing other libraries to modify the behavior of flytekit. + +Libraries can register by defining an object that follows the same API as FlytekitPlugin +and providing entry pont with the group name "flytekit.plugin". In `setuptools`, +you can specific them with: + +```python +setup(entry_points={ + "flytekit.plugin": ["my_plugin=..."] +}) +``` + +or in pyproject.toml: + +```toml +[project.entry-points."flytekit.plugin"] +my_plugin = "..." +``` +""" from typing import Optional from click import Command From 1bd720366d730a250cb91e8b71a24046cf6e243c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 15 Dec 2023 09:21:05 -0500 Subject: [PATCH 14/19] Better variable names Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 18 +++++++++--------- tests/flytekit/unit/cli/pyflyte/test_plugin.py | 8 ++++---- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 19818e4e6d..3fb631140c 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -1,20 +1,20 @@ """Defines a plugin API allowing other libraries to modify the behavior of flytekit. Libraries can register by defining an object that follows the same API as FlytekitPlugin -and providing entry pont with the group name "flytekit.plugin". In `setuptools`, +and providing an entrypoint with the group name "flytekit.plugin". In `setuptools`, you can specific them with: ```python setup(entry_points={ - "flytekit.plugin": ["my_plugin=..."] + "flytekit.configuration.plugin": ["my_plugin=my_module:MyCustomPlugin"] }) ``` or in pyproject.toml: ```toml -[project.entry-points."flytekit.plugin"] -my_plugin = "..." +[project.entry-points."flytekit.configuration.plugin"] +my_plugin = "my_module:MyCustomPlugin" ``` """ from typing import Optional @@ -50,23 +50,23 @@ def configure_pyflyte_cli(main: Command) -> Command: return main -def _get_plugin(): - """Get plugin for entrypoint.""" - plugins = list(entry_points(group="flytekit.plugin")) +def _get_plugin_from_entrypoint(): + """Get plugin from entrypoint.""" + plugins = list(entry_points(group="flytekit.configuration.plugin")) if not plugins: return FlytekitPlugin if len(plugins) >= 2: plugin_names = [p.name for p in plugins] - cli_logger.info(f"Multiple plugins seen for flytekit.plugin: {plugin_names}") + cli_logger.info(f"Multiple plugins seen for flytekit.configuration.plugin: {plugin_names}") plugin_to_load = plugins[0] cli_logger.info(f"Loading plugin: {plugin_to_load.name}") return plugin_to_load.load() -_GLOBAL_CONFIG = {"plugin": _get_plugin()} +_GLOBAL_CONFIG = {"plugin": _get_plugin_from_entrypoint()} def get_plugin(): diff --git a/tests/flytekit/unit/cli/pyflyte/test_plugin.py b/tests/flytekit/unit/cli/pyflyte/test_plugin.py index 3bd80ad9cc..f67ac5dbb4 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_plugin.py +++ b/tests/flytekit/unit/cli/pyflyte/test_plugin.py @@ -2,14 +2,14 @@ import click -from flytekit.configuration.plugin import FlytekitPlugin, _get_plugin +from flytekit.configuration.plugin import FlytekitPlugin, _get_plugin_from_entrypoint @patch("flytekit.configuration.plugin.entry_points") def test_get_plugin_default(entry_points): entry_points.side_effect = lambda *args, **kwargs: [] - default_plugin = _get_plugin() + default_plugin = _get_plugin_from_entrypoint() assert default_plugin is FlytekitPlugin @@ -23,7 +23,7 @@ def test_get_plugin_load_other_plugin(entry_points, caplog): entry_2 = Mock() entry_points.side_effect = lambda *args, **kwargs: [entry_1, entry_2] - plugin = _get_plugin() + plugin = _get_plugin_from_entrypoint() assert plugin is loaded_plugin_1 assert entry_1.load.call_count == 1 @@ -60,7 +60,7 @@ def test_get_plugin_custom(entry_points): entry_points.side_effect = lambda *args, **kwargs: [entry_1] - plugin = _get_plugin() + plugin = _get_plugin_from_entrypoint() assert plugin is CustomPlugin assert not click_main.params[0].hidden From 2ab3c64fa5ef0481a096c27654c0a9f2bae0d807 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 15 Dec 2023 09:56:42 -0500 Subject: [PATCH 15/19] CLN Use group name as variable Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 3fb631140c..00efa63022 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -52,14 +52,15 @@ def configure_pyflyte_cli(main: Command) -> Command: def _get_plugin_from_entrypoint(): """Get plugin from entrypoint.""" - plugins = list(entry_points(group="flytekit.configuration.plugin")) + group = "flytekit.configuration.plugin" + plugins = list(entry_points(group=group)) if not plugins: return FlytekitPlugin if len(plugins) >= 2: plugin_names = [p.name for p in plugins] - cli_logger.info(f"Multiple plugins seen for flytekit.configuration.plugin: {plugin_names}") + cli_logger.info(f"Multiple plugins seen for {group}: {plugin_names}") plugin_to_load = plugins[0] cli_logger.info(f"Loading plugin: {plugin_to_load.name}") From dde46d9126aa936b9ec74ec63fa9f140c6bdba79 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 15 Dec 2023 10:01:57 -0500 Subject: [PATCH 16/19] Add protocol Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 15 ++++++++++++++- tests/flytekit/unit/cli/pyflyte/test_plugin.py | 6 +++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 00efa63022..1b1804173d 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -17,7 +17,7 @@ my_plugin = "my_module:MyCustomPlugin" ``` """ -from typing import Optional +from typing import Optional, Protocol, runtime_checkable from click import Command from importlib_metadata import entry_points @@ -27,6 +27,19 @@ from flytekit.remote import FlyteRemote +@runtime_checkable +class FlytekitPluginProtocol(Protocol): + @staticmethod + def get_remote( + config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None + ) -> FlyteRemote: + """Get FlyteRemote object for CLI session.""" + + @staticmethod + def configure_pyflyte_cli(main: Command) -> Command: + """Configure pyflyte's CLI.""" + + class FlytekitPlugin: @staticmethod def get_remote( diff --git a/tests/flytekit/unit/cli/pyflyte/test_plugin.py b/tests/flytekit/unit/cli/pyflyte/test_plugin.py index f67ac5dbb4..6e3423a1ca 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_plugin.py +++ b/tests/flytekit/unit/cli/pyflyte/test_plugin.py @@ -2,7 +2,7 @@ import click -from flytekit.configuration.plugin import FlytekitPlugin, _get_plugin_from_entrypoint +from flytekit.configuration.plugin import FlytekitPlugin, FlytekitPluginProtocol, _get_plugin_from_entrypoint @patch("flytekit.configuration.plugin.entry_points") @@ -67,3 +67,7 @@ def test_get_plugin_custom(entry_points): plugin.configure_pyflyte_cli(click_main) assert click_main.params[0].hidden + + +def test_plugin_follow_protocol(): + assert issubclass(FlytekitPlugin, FlytekitPluginProtocol) From 279ece89fddd088cdd52c8922dda92eb1fe71f34 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 18 Dec 2023 13:57:48 -0500 Subject: [PATCH 17/19] Makes secret groups optional and configurable Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 21 +++++++++++++++---- flytekit/core/context_manager.py | 19 ++++++++++++----- flytekit/models/security.py | 5 +++-- .../unit/core/test_context_manager.py | 20 ++++++++++++++++++ .../unit/models/core/test_security.py | 12 +++++++++++ 5 files changed, 66 insertions(+), 11 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 1b1804173d..82535b48d9 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -17,14 +17,16 @@ my_plugin = "my_module:MyCustomPlugin" ``` """ -from typing import Optional, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable from click import Command from importlib_metadata import entry_points from flytekit.configuration import Config, get_config_file from flytekit.loggers import cli_logger -from flytekit.remote import FlyteRemote + +if TYPE_CHECKING: + from flytekit.remote import FlyteRemote @runtime_checkable @@ -32,20 +34,26 @@ class FlytekitPluginProtocol(Protocol): @staticmethod def get_remote( config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None - ) -> FlyteRemote: + ) -> "FlyteRemote": """Get FlyteRemote object for CLI session.""" @staticmethod def configure_pyflyte_cli(main: Command) -> Command: """Configure pyflyte's CLI.""" + @staticmethod + def secret_requires_group() -> bool: + """Return True if secrets require group entry.""" + class FlytekitPlugin: @staticmethod def get_remote( config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None - ) -> FlyteRemote: + ) -> "FlyteRemote": """Get FlyteRemote object for CLI session.""" + from flytekit.remote import FlyteRemote + cfg_file = get_config_file(config) if cfg_file is None: cfg_obj = Config.for_sandbox() @@ -62,6 +70,11 @@ def configure_pyflyte_cli(main: Command) -> Command: """Configure pyflyte's CLI.""" return main + @staticmethod + def secret_requires_group() -> bool: + """Return True if secrets require group entry.""" + return True + def _get_plugin_from_entrypoint(): """Get plugin from entrypoint.""" diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 833c7d8562..4e95292201 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -28,6 +28,7 @@ from typing import Generator, List, Optional, Union from flytekit.configuration import Config, SecretsConfig, SerializationSettings +from flytekit.configuration.plugin import get_plugin from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider @@ -350,7 +351,11 @@ def __getattr__(self, item: str) -> _GroupSecrets: return self._GroupSecrets(item, self) def get( - self, group: str, key: Optional[str] = None, group_version: Optional[str] = None, encode_mode: str = "r" + self, + group: Optional[str] = None, + key: Optional[str] = None, + group_version: Optional[str] = None, + encode_mode: str = "r", ) -> str: """ Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError @@ -370,7 +375,9 @@ def get( f"in Env Var:{env_var} and FilePath: {fpath}" ) - def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: + def get_secrets_env_var( + self, group: Optional[str] = None, key: Optional[str] = None, group_version: Optional[str] = None + ) -> str: """ Returns a string that matches the ENV Variable to look for the secrets """ @@ -378,7 +385,9 @@ def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_versi l = [k.upper() for k in filter(None, (group, group_version, key))] return f"{self._env_prefix}{'_'.join(l)}" - def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: + def get_secrets_file( + self, group: Optional[str] = None, key: Optional[str] = None, group_version: Optional[str] = None + ) -> str: """ Returns a path that matches the file to look for the secrets """ @@ -388,8 +397,8 @@ def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: return os.path.join(self._base_dir, *l) @staticmethod - def check_group_key(group: str): - if group is None or group == "": + def check_group_key(group: Optional[str]): + if get_plugin().secret_requires_group() and (group is None or group == ""): raise ValueError("secrets group is a mandatory field.") diff --git a/flytekit/models/security.py b/flytekit/models/security.py index 9af90a4b8a..b1a9a890dc 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -4,6 +4,7 @@ from flyteidl.core import security_pb2 as _sec +from flytekit.configuration.plugin import get_plugin from flytekit.models import common as _common @@ -35,13 +36,13 @@ class MountType(Enum): Caution: May not be supported in all environments """ - group: str + group: Optional[str] = None key: Optional[str] = None group_version: Optional[str] = None mount_requirement: MountType = MountType.ANY def __post_init__(self): - if self.group is None: + if get_plugin().secret_requires_group() and self.group is None: raise ValueError("Group is a required parameter") def to_flyte_idl(self) -> _sec.Secret: diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 2ec7eb8e19..8634170c49 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -1,6 +1,8 @@ import base64 import os from datetime import datetime +from pathlib import Path +from unittest.mock import Mock, patch import mock import py @@ -128,6 +130,24 @@ def test_secrets_manager_get_envvar(): assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP" +@patch("flytekit.core.context_manager.get_plugin") +def test_secret_manager_no_group(get_plugin_mock): + plugin_mock = Mock() + plugin_mock.secret_requires_group.return_value = False + get_plugin_mock.return_value = plugin_mock + + sec = SecretsManager() + cfg = SecretsConfig.auto() + sec.check_group_key(None) + sec.check_group_key("") + + assert sec.get_secrets_env_var(key="ABC") == f"{cfg.env_prefix}ABC" + + default_path = Path(cfg.default_dir) + expected_path = default_path / f"{cfg.file_prefix}abc" + assert sec.get_secrets_file(key="ABC") == str(expected_path) + + def test_secrets_manager_get_file(): sec = SecretsManager() with pytest.raises(ValueError): diff --git a/tests/flytekit/unit/models/core/test_security.py b/tests/flytekit/unit/models/core/test_security.py index c2933f9353..2c6766821f 100644 --- a/tests/flytekit/unit/models/core/test_security.py +++ b/tests/flytekit/unit/models/core/test_security.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock, patch + from flytekit.models.security import Secret @@ -11,3 +13,13 @@ def test_secret(): obj2 = Secret.from_flyte_idl(obj.to_flyte_idl()) assert obj2.key is None assert obj2.group_version == "v1" + + +@patch("flytekit.models.security.get_plugin") +def test_secret_no_group(get_plugin_mock): + plugin_mock = Mock() + plugin_mock.secret_requires_group.return_value = False + get_plugin_mock.return_value = plugin_mock + + s = Secret(key="key") + assert s.group is None From c7e26ddd2242d65dd3a934e29bc9c1f7c76d615a Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 21 Dec 2023 10:54:48 -0500 Subject: [PATCH 18/19] Fix circular reference Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index af78479120..07596673da 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -50,7 +50,7 @@ class FlytekitPlugin: @staticmethod def get_remote( config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None - ) -> FlyteRemote: + ) -> "FlyteRemote": """Get FlyteRemote object for CLI session.""" from flytekit.remote import FlyteRemote From 8fdb32c6f85c2bf90e4e6525131f69e9ccb37449 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 21 Dec 2023 11:14:45 -0500 Subject: [PATCH 19/19] Fixes circular import Signed-off-by: Thomas J. Fan --- flytekit/configuration/plugin.py | 12 ++++-------- flytekit/core/context_manager.py | 3 ++- flytekit/models/security.py | 3 ++- tests/flytekit/unit/core/test_context_manager.py | 9 +++++---- tests/flytekit/unit/models/core/test_security.py | 9 +++++---- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/flytekit/configuration/plugin.py b/flytekit/configuration/plugin.py index 07596673da..fe96324c85 100644 --- a/flytekit/configuration/plugin.py +++ b/flytekit/configuration/plugin.py @@ -17,16 +17,14 @@ my_plugin = "my_module:MyCustomPlugin" ``` """ -from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable +from typing import Optional, Protocol, runtime_checkable from click import Command from importlib_metadata import entry_points from flytekit.configuration import Config, get_config_file from flytekit.loggers import logger - -if TYPE_CHECKING: - from flytekit.remote import FlyteRemote +from flytekit.remote import FlyteRemote @runtime_checkable @@ -34,7 +32,7 @@ class FlytekitPluginProtocol(Protocol): @staticmethod def get_remote( config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None - ) -> "FlyteRemote": + ) -> FlyteRemote: """Get FlyteRemote object for CLI session.""" @staticmethod @@ -50,10 +48,8 @@ class FlytekitPlugin: @staticmethod def get_remote( config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None - ) -> "FlyteRemote": + ) -> FlyteRemote: """Get FlyteRemote object for CLI session.""" - from flytekit.remote import FlyteRemote - cfg_file = get_config_file(config) if cfg_file is None: cfg_obj = Config.for_sandbox() diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 4e95292201..d30eba0918 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -28,7 +28,6 @@ from typing import Generator, List, Optional, Union from flytekit.configuration import Config, SecretsConfig, SerializationSettings -from flytekit.configuration.plugin import get_plugin from flytekit.core import mock_stats, utils from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider @@ -398,6 +397,8 @@ def get_secrets_file( @staticmethod def check_group_key(group: Optional[str]): + from flytekit.configuration.plugin import get_plugin + if get_plugin().secret_requires_group() and (group is None or group == ""): raise ValueError("secrets group is a mandatory field.") diff --git a/flytekit/models/security.py b/flytekit/models/security.py index b1a9a890dc..a9ee7e7cb9 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -4,7 +4,6 @@ from flyteidl.core import security_pb2 as _sec -from flytekit.configuration.plugin import get_plugin from flytekit.models import common as _common @@ -42,6 +41,8 @@ class MountType(Enum): mount_requirement: MountType = MountType.ANY def __post_init__(self): + from flytekit.configuration.plugin import get_plugin + if get_plugin().secret_requires_group() and self.group is None: raise ValueError("Group is a required parameter") diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 8634170c49..ca22f359c9 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -2,12 +2,13 @@ import os from datetime import datetime from pathlib import Path -from unittest.mock import Mock, patch +from unittest.mock import Mock import mock import py import pytest +import flytekit.configuration.plugin from flytekit.configuration import ( SERIALIZED_CONTEXT_ENV_VAR, FastSerializationSettings, @@ -130,11 +131,11 @@ def test_secrets_manager_get_envvar(): assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP" -@patch("flytekit.core.context_manager.get_plugin") -def test_secret_manager_no_group(get_plugin_mock): +def test_secret_manager_no_group(monkeypatch): plugin_mock = Mock() plugin_mock.secret_requires_group.return_value = False - get_plugin_mock.return_value = plugin_mock + mock_global_plugin = {"plugin": plugin_mock} + monkeypatch.setattr(flytekit.configuration.plugin, "_GLOBAL_CONFIG", mock_global_plugin) sec = SecretsManager() cfg = SecretsConfig.auto() diff --git a/tests/flytekit/unit/models/core/test_security.py b/tests/flytekit/unit/models/core/test_security.py index 2c6766821f..a7ed006174 100644 --- a/tests/flytekit/unit/models/core/test_security.py +++ b/tests/flytekit/unit/models/core/test_security.py @@ -1,5 +1,6 @@ -from unittest.mock import Mock, patch +from unittest.mock import Mock +import flytekit.configuration.plugin from flytekit.models.security import Secret @@ -15,11 +16,11 @@ def test_secret(): assert obj2.group_version == "v1" -@patch("flytekit.models.security.get_plugin") -def test_secret_no_group(get_plugin_mock): +def test_secret_no_group(monkeypatch): plugin_mock = Mock() plugin_mock.secret_requires_group.return_value = False - get_plugin_mock.return_value = plugin_mock + mock_global_plugin = {"plugin": plugin_mock} + monkeypatch.setattr(flytekit.configuration.plugin, "_GLOBAL_CONFIG", mock_global_plugin) s = Secret(key="key") assert s.group is None