From 707fc03b88ec08fb10d60ac642f3c80e275bd90c Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Fri, 17 Feb 2023 16:59:36 -0800 Subject: [PATCH] Delay initialization of SynchronousFlyteClient in FlyteRemote (#1514) * Delay initialization of SynchronousFlyteClient in FlyteRemote Signed-off-by: Eduardo Apolinario * Fix spark plugin flyteremote test. Signed-off-by: Eduardo Apolinario * Lint Signed-off-by: Eduardo Apolinario --------- Signed-off-by: Eduardo Apolinario Co-authored-by: Eduardo Apolinario --- flytekit/remote/remote.py | 6 ++- .../tests/test_remote_register.py | 1 + tests/flytekit/unit/cli/pyflyte/test_run.py | 19 ++++++-- tests/flytekit/unit/core/test_signal.py | 14 ++++-- tests/flytekit/unit/remote/test_remote.py | 48 +++++++++---------- 5 files changed, 53 insertions(+), 35 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index cf9533d7da..93badd5374 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -165,7 +165,8 @@ def __init__( if config is None or config.platform is None or config.platform.endpoint is None: raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.") - self._client = SynchronousFlyteClient(config.platform, **kwargs) + self._kwargs = kwargs + self._client_initialized = False self._config = config # read config files, env vars, host, ssl options for admin client self._default_project = default_project @@ -187,6 +188,9 @@ def context(self) -> FlyteContext: @property def client(self) -> SynchronousFlyteClient: """Return a SynchronousFlyteClient for additional operations.""" + if not self._client_initialized: + self._client = SynchronousFlyteClient(self.config.platform, **self._kwargs) + self._client_initialized = True return self._client @property diff --git a/plugins/flytekit-spark/tests/test_remote_register.py b/plugins/flytekit-spark/tests/test_remote_register.py index 8eaf8a0794..3bb65d09bc 100644 --- a/plugins/flytekit-spark/tests/test_remote_register.py +++ b/plugins/flytekit-spark/tests/test_remote_register.py @@ -21,6 +21,7 @@ def my_python_task(a: str) -> int: mock_client = MagicMock() remote._client = mock_client + remote._client_initialized = True remote.register_task( my_spark, diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index a75d2d45c9..b211153f44 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -35,12 +35,21 @@ DIR_NAME = os.path.dirname(os.path.realpath(__file__)) -def test_pyflyte_run_wf(): - runner = CliRunner() - module_path = WORKFLOW_FILE - result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) +@pytest.fixture +def remote(): + with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client: + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote._client = mock_client + return flyte_remote - assert result.exit_code == 0 + +def test_pyflyte_run_wf(remote): + with mock.patch("flytekit.clis.sdk_in_container.helpers.get_and_save_remote_with_click_context"): + runner = CliRunner() + module_path = WORKFLOW_FILE + result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) + + assert result.exit_code == 0 def test_imperative_wf(): diff --git a/tests/flytekit/unit/core/test_signal.py b/tests/flytekit/unit/core/test_signal.py index a37da8955f..a3bee2e4c7 100644 --- a/tests/flytekit/unit/core/test_signal.py +++ b/tests/flytekit/unit/core/test_signal.py @@ -1,3 +1,4 @@ +import pytest from flyteidl.admin.signal_pb2 import Signal, SignalList from mock import MagicMock @@ -8,7 +9,14 @@ from flytekit.remote.remote import FlyteRemote -def test_remote_list_signals(): +@pytest.fixture +def remote(): + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote._client_initialized = True + return flyte_remote + + +def test_remote_list_signals(remote): ctx = FlyteContextManager.current_context() wfeid = WorkflowExecutionIdentifier("p", "d", "execid") signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl() @@ -22,13 +30,12 @@ def test_remote_list_signals(): mock_client = MagicMock() mock_client.list_signals.return_value = SignalList(signals=[signal], token="") - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client res = remote.list_signals("execid", "p", "d", limit=10) assert len(res) == 1 -def test_remote_set_signal(): +def test_remote_set_signal(remote): mock_client = MagicMock() def checker(request): @@ -37,6 +44,5 @@ def checker(request): mock_client.set_signal.side_effect = checker - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client remote.set_signal("sigid", "execid", 3) diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index f9346e90e3..4b8f82fb7e 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -44,29 +44,36 @@ } -@patch("flytekit.clients.friendly.SynchronousFlyteClient") -def test_remote_fetch_execution(mock_client_manager): +@pytest.fixture +def remote(): + with patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client: + flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + flyte_remote._client_initialized = True + flyte_remote._client = mock_client + return flyte_remote + + +def test_remote_fetch_execution(remote): admin_workflow_execution = Execution( id=WorkflowExecutionIdentifier("p1", "d1", "n1"), spec=MagicMock(), closure=MagicMock(), ) - mock_client = MagicMock() mock_client.get_execution.return_value = admin_workflow_execution - - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client flyte_workflow_execution = remote.fetch_execution(name="n1") assert flyte_workflow_execution.id == admin_workflow_execution.id -@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") -def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec): +@pytest.fixture +def mock_wf_exec(): + return patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") + + +def test_underscore_execute_uses_launch_plan_attributes(remote, mock_wf_exec): mock_wf_exec.return_value = True mock_client = MagicMock() - - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client def local_assertions(*args, **kwargs): @@ -93,12 +100,9 @@ def local_assertions(*args, **kwargs): ) -@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") -def test_underscore_execute_fall_back_remote_attributes(mock_wf_exec): +def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec): mock_wf_exec.return_value = True mock_client = MagicMock() - - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client options = Options( @@ -124,14 +128,11 @@ def local_assertions(*args, **kwargs): ) -@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model") -def test_execute_with_wrong_input_key(mock_wf_exec): +def test_execute_with_wrong_input_key(remote, mock_wf_exec): # mock_url.get.return_value = "localhost" # mock_insecure.get.return_value = True mock_wf_exec.return_value = True mock_client = MagicMock() - - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") remote._client = mock_client mock_entity = MagicMock() @@ -162,7 +163,7 @@ def test_passing_of_kwargs(mock_client): "root_certificates": 5, "certificate_chain": 6, } - FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args) + FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args).client assert mock_client.called assert mock_client.call_args[1] == additional_args @@ -273,8 +274,8 @@ def get_compiled_workflow_closure(): return CompiledWorkflowClosure.from_flyte_idl(cwc_pb) -@patch("flytekit.remote.remote.SynchronousFlyteClient") -def test_fetch_lazy(mock_client): +def test_fetch_lazy(remote): + mock_client = remote._client mock_client.get_task.return_value = Task( id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), closure=LIST_OF_TASK_CLOSURES[0] ) @@ -284,7 +285,6 @@ def test_fetch_lazy(mock_client): closure=WorkflowClosure(compiled_workflow=get_compiled_workflow_closure()), ) - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") lw = remote.fetch_workflow_lazy(name="wn", version="v") assert isinstance(lw, LazyEntity) assert lw._getter @@ -309,8 +309,7 @@ def example_wf(t: datetime, v: int): tk(t=t, v=v) -@patch("flytekit.remote.remote.SynchronousFlyteClient") -def test_launch_backfill(mock_client): +def test_launch_backfill(remote): daily_lp = LaunchPlan.get_or_create( workflow=example_wf, name="daily2", @@ -336,6 +335,7 @@ def test_launch_backfill(mock_client): for k, v in m.items(): if isinstance(k, PythonTask): tasks.append(v) + mock_client = remote._client mock_client.get_launch_plan.return_value = ser_lp mock_client.get_workflow.return_value = Workflow( id=Identifier(ResourceType.WORKFLOW, "p", "d", "daily2", "v"), @@ -343,8 +343,6 @@ def test_launch_backfill(mock_client): compiled_workflow=CompiledWorkflowClosure(primary=ser_wf, sub_workflows=[], tasks=tasks) ), ) - remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") - remote._client = mock_client wf = remote.launch_backfill("p", "d", start_date, end_date, "daily2", "v1", dry_run=True) assert wf