Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Delay initialization of SynchronousFlyteClient in FlyteRemote #1514

Merged
merged 4 commits into from
Feb 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ def __init__(
if config is None or config.platform is None or config.platform.endpoint is None:
raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.")

self._client = SynchronousFlyteClient(config.platform, **kwargs)
self._kwargs = kwargs
self._client_initialized = False
self._config = config
# read config files, env vars, host, ssl options for admin client
self._default_project = default_project
Expand All @@ -187,6 +188,9 @@ def context(self) -> FlyteContext:
@property
def client(self) -> SynchronousFlyteClient:
"""Return a SynchronousFlyteClient for additional operations."""
if not self._client_initialized:
self._client = SynchronousFlyteClient(self.config.platform, **self._kwargs)
self._client_initialized = True
return self._client

@property
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-spark/tests/test_remote_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def my_python_task(a: str) -> int:

mock_client = MagicMock()
remote._client = mock_client
remote._client_initialized = True

remote.register_task(
my_spark,
Expand Down
19 changes: 14 additions & 5 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,21 @@
DIR_NAME = os.path.dirname(os.path.realpath(__file__))


def test_pyflyte_run_wf():
runner = CliRunner()
module_path = WORKFLOW_FILE
result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False)
@pytest.fixture
def remote():
with mock.patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote._client = mock_client
return flyte_remote

assert result.exit_code == 0

def test_pyflyte_run_wf(remote):
with mock.patch("flytekit.clis.sdk_in_container.helpers.get_and_save_remote_with_click_context"):
runner = CliRunner()
module_path = WORKFLOW_FILE
result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False)

assert result.exit_code == 0


def test_imperative_wf():
Expand Down
14 changes: 10 additions & 4 deletions tests/flytekit/unit/core/test_signal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from flyteidl.admin.signal_pb2 import Signal, SignalList
from mock import MagicMock

Expand All @@ -8,7 +9,14 @@
from flytekit.remote.remote import FlyteRemote


def test_remote_list_signals():
@pytest.fixture
def remote():
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote._client_initialized = True
return flyte_remote


def test_remote_list_signals(remote):
ctx = FlyteContextManager.current_context()
wfeid = WorkflowExecutionIdentifier("p", "d", "execid")
signal_id = SignalIdentifier(signal_id="sigid", execution_id=wfeid).to_flyte_idl()
Expand All @@ -22,13 +30,12 @@ def test_remote_list_signals():
mock_client = MagicMock()
mock_client.list_signals.return_value = SignalList(signals=[signal], token="")

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
res = remote.list_signals("execid", "p", "d", limit=10)
assert len(res) == 1


def test_remote_set_signal():
def test_remote_set_signal(remote):
mock_client = MagicMock()

def checker(request):
Expand All @@ -37,6 +44,5 @@ def checker(request):

mock_client.set_signal.side_effect = checker

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
remote.set_signal("sigid", "execid", 3)
48 changes: 23 additions & 25 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,36 @@
}


@patch("flytekit.clients.friendly.SynchronousFlyteClient")
def test_remote_fetch_execution(mock_client_manager):
@pytest.fixture
def remote():
with patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
flyte_remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
flyte_remote._client_initialized = True
flyte_remote._client = mock_client
return flyte_remote


def test_remote_fetch_execution(remote):
admin_workflow_execution = Execution(
id=WorkflowExecutionIdentifier("p1", "d1", "n1"),
spec=MagicMock(),
closure=MagicMock(),
)

mock_client = MagicMock()
mock_client.get_execution.return_value = admin_workflow_execution

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client
flyte_workflow_execution = remote.fetch_execution(name="n1")
assert flyte_workflow_execution.id == admin_workflow_execution.id


@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
def test_underscore_execute_uses_launch_plan_attributes(mock_wf_exec):
@pytest.fixture
def mock_wf_exec():
return patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")


def test_underscore_execute_uses_launch_plan_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

def local_assertions(*args, **kwargs):
Expand All @@ -93,12 +100,9 @@ def local_assertions(*args, **kwargs):
)


@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
def test_underscore_execute_fall_back_remote_attributes(mock_wf_exec):
def test_underscore_execute_fall_back_remote_attributes(remote, mock_wf_exec):
mock_wf_exec.return_value = True
mock_client = MagicMock()

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

options = Options(
Expand All @@ -124,14 +128,11 @@ def local_assertions(*args, **kwargs):
)


@patch("flytekit.remote.executions.FlyteWorkflowExecution.promote_from_model")
def test_execute_with_wrong_input_key(mock_wf_exec):
def test_execute_with_wrong_input_key(remote, mock_wf_exec):
# mock_url.get.return_value = "localhost"
# mock_insecure.get.return_value = True
mock_wf_exec.return_value = True
mock_client = MagicMock()

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

mock_entity = MagicMock()
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_passing_of_kwargs(mock_client):
"root_certificates": 5,
"certificate_chain": 6,
}
FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args)
FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain", **additional_args).client
assert mock_client.called
assert mock_client.call_args[1] == additional_args

Expand Down Expand Up @@ -273,8 +274,8 @@ def get_compiled_workflow_closure():
return CompiledWorkflowClosure.from_flyte_idl(cwc_pb)


@patch("flytekit.remote.remote.SynchronousFlyteClient")
def test_fetch_lazy(mock_client):
def test_fetch_lazy(remote):
mock_client = remote._client
mock_client.get_task.return_value = Task(
id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), closure=LIST_OF_TASK_CLOSURES[0]
)
Expand All @@ -284,7 +285,6 @@ def test_fetch_lazy(mock_client):
closure=WorkflowClosure(compiled_workflow=get_compiled_workflow_closure()),
)

remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
lw = remote.fetch_workflow_lazy(name="wn", version="v")
assert isinstance(lw, LazyEntity)
assert lw._getter
Expand All @@ -309,8 +309,7 @@ def example_wf(t: datetime, v: int):
tk(t=t, v=v)


@patch("flytekit.remote.remote.SynchronousFlyteClient")
def test_launch_backfill(mock_client):
def test_launch_backfill(remote):
daily_lp = LaunchPlan.get_or_create(
workflow=example_wf,
name="daily2",
Expand All @@ -336,15 +335,14 @@ def test_launch_backfill(mock_client):
for k, v in m.items():
if isinstance(k, PythonTask):
tasks.append(v)
mock_client = remote._client
mock_client.get_launch_plan.return_value = ser_lp
mock_client.get_workflow.return_value = Workflow(
id=Identifier(ResourceType.WORKFLOW, "p", "d", "daily2", "v"),
closure=WorkflowClosure(
compiled_workflow=CompiledWorkflowClosure(primary=ser_wf, sub_workflows=[], tasks=tasks)
),
)
remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1")
remote._client = mock_client

wf = remote.launch_backfill("p", "d", start_date, end_date, "daily2", "v1", dry_run=True)
assert wf