Skip to content

Commit

Permalink
Delay initialization of SynchronousFlyteClient in FlyteRemote (#1514)
Browse files Browse the repository at this point in the history
* Delay initialization of SynchronousFlyteClient in FlyteRemote

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix spark plugin flyteremote test.

Signed-off-by: Eduardo Apolinario <[email protected]>

* Lint

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Feb 18, 2023
1 parent 99d3d50 commit 707fc03
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 35 deletions.
6 changes: 5 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ def __init__(
if config is None or config.platform is None or config.platform.endpoint is None:
raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.")

self._client = SynchronousFlyteClient(config.platform, **kwargs)
self._kwargs = kwargs
self._client_initialized = False
self._config = config
# read config files, env vars, host, ssl options for admin client
self._default_project = default_project
Expand All @@ -187,6 +188,9 @@ def context(self) -> FlyteContext:
@property
def client(self) -> SynchronousFlyteClient:
"""Return a SynchronousFlyteClient for additional operations."""
if not self._client_initialized:
self._client = SynchronousFlyteClient(self.config.platform, **self._kwargs)
self._client_initialized = True
return self._client

@property
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-spark/tests/test_remote_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def my_python_task(a: str) -> int:

mock_client = MagicMock()
remote._client = mock_client
remote._client_initialized = True

remote.register_task(
my_spark,
Expand Down
19 changes: 14 additions & 5 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,21 @@
DIR_NAME = os.path.dirname(os.path.realpath(__file__))


def test_pyflyte_run_wf():
runner = CliRunner()
module_path = WORKFLOW_FILE
result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False)
@pytest.fixture
def remote():
with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote._client = mock_client
return flyte_remote

assert result.exit_code == 0

def test_pyflyte_run_wf(remote):
with mock.patch("flytekit.clis.sdk_in_container.helpers.get_and_save_remote_with_click_context"):
runner = CliRunner()
module_path = WORKFLOW_FILE
result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False)

assert result.exit_code == 0


def test_imperative_wf():
Expand Down
14 changes: 10 additions & 4 deletions tests/flytekit/unit/core/test_signal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from flyteidl.admin.signal_pb2 import Signal, SignalList
from mock import MagicMock

Expand All @@ -8,7 +9,14 @@
from flytekit.remote.remote import FlyteRemote


def test_remote_list_signals():
@pytest.fixture
def remote():
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote._client_initialized = True
return flyte_remote


def test_remote_list_signals(remote):
ctx = FlyteContextManager.current_context()
wfeid = WorkflowExecutionIdentifier("p", "d", "execid")
signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl()
Expand All @@ -22,13 +30,12 @@ def test_remote_list_signals():
mock_client = MagicMock()
mock_client.list_signals.return_value = SignalList(signals=[signal], token="")

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
res = remote.list_signals("execid", "p", "d", limit=10)
assert len(res) == 1


def test_remote_set_signal():
def test_remote_set_signal(remote):
mock_client = MagicMock()

def checker(request):
Expand All @@ -37,6 +44,5 @@ def checker(request):

mock_client.set_signal.side_effect = checker

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
remote.set_signal("sigid", "execid", 3)
48 changes: 23 additions & 25 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,36 @@
}


@patch("flytekit.clients.friendly.SynchronousFlyteClient")
def test_remote_fetch_execution(mock_client_manager):
@pytest.fixture
def remote():
with patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote._client_initialized = True
flyte_remote._client = mock_client
return flyte_remote


def test_remote_fetch_execution(remote):
admin_workflow_execution = Execution(
id=WorkflowExecutionIdentifier("p1", "d1", "n1"),
spec=MagicMock(),
closure=MagicMock(),
)

mock_client = MagicMock()
mock_client.get_execution.return_value = admin_workflow_execution

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
flyte_workflow_execution = remote.fetch_execution(name="n1")
assert flyte_workflow_execution.id == admin_workflow_execution.id


@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec):
@pytest.fixture
def mock_wf_exec():
return patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")


def test_underscore_execute_uses_launch_plan_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

def local_assertions(*args, **kwargs):
Expand All @@ -93,12 +100,9 @@ def local_assertions(*args, **kwargs):
)


@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
def test_underscore_execute_fall_back_remote_attributes(mock_wf_exec):
def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

options = Options(
Expand All @@ -124,14 +128,11 @@ def local_assertions(*args, **kwargs):
)


@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
def test_execute_with_wrong_input_key(mock_wf_exec):
def test_execute_with_wrong_input_key(remote, mock_wf_exec):
# mock_url.get.return_value = "localhost"
# mock_insecure.get.return_value = True
mock_wf_exec.return_value = True
mock_client = MagicMock()

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

mock_entity = MagicMock()
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_passing_of_kwargs(mock_client):
"root_certificates": 5,
"certificate_chain": 6,
}
FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args)
FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args).client
assert mock_client.called
assert mock_client.call_args[1] == additional_args

Expand Down Expand Up @@ -273,8 +274,8 @@ def get_compiled_workflow_closure():
return CompiledWorkflowClosure.from_flyte_idl(cwc_pb)


@patch("flytekit.remote.remote.SynchronousFlyteClient")
def test_fetch_lazy(mock_client):
def test_fetch_lazy(remote):
mock_client = remote._client
mock_client.get_task.return_value = Task(
id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), closure=LIST_OF_TASK_CLOSURES[0]
)
Expand All @@ -284,7 +285,6 @@ def test_fetch_lazy(mock_client):
closure=WorkflowClosure(compiled_workflow=get_compiled_workflow_closure()),
)

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
lw = remote.fetch_workflow_lazy(name="wn", version="v")
assert isinstance(lw, LazyEntity)
assert lw._getter
Expand All @@ -309,8 +309,7 @@ def example_wf(t: datetime, v: int):
tk(t=t, v=v)


@patch("flytekit.remote.remote.SynchronousFlyteClient")
def test_launch_backfill(mock_client):
def test_launch_backfill(remote):
daily_lp = LaunchPlan.get_or_create(
workflow=example_wf,
name="daily2",
Expand All @@ -336,15 +335,14 @@ def test_launch_backfill(mock_client):
for k, v in m.items():
if isinstance(k, PythonTask):
tasks.append(v)
mock_client = remote._client
mock_client.get_launch_plan.return_value = ser_lp
mock_client.get_workflow.return_value = Workflow(
id=Identifier(ResourceType.WORKFLOW, "p", "d", "daily2", "v"),
closure=WorkflowClosure(
compiled_workflow=CompiledWorkflowClosure(primary=ser_wf, sub_workflows=[], tasks=tasks)
),
)
remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

wf = remote.launch_backfill("p", "d", start_date, end_date, "daily2", "v1", dry_run=True)
assert wf

0 comments on commit 707fc03

Please sign in to comment.