Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create default launch plan when executing WorkflowBase #707

Merged
merged 23 commits into from
Oct 28, 2021
Merged
54 changes: 28 additions & 26 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,33 +125,35 @@ def handler(*args, **kwargs):
"""
max_retries = 3
max_wait_time = 1000
try:
for i in range(max_retries):
try:
return fn(*args, **kwargs)
except _RpcError as e:
if e.code() == _GrpcStatusCode.UNAUTHENTICATED:
# Always retry auth errors.
if i == (max_retries - 1):
# Exit the loop and wrap the authentication error.
raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e))
cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
refresh_handler_fn(args[0])

for i in range(max_retries):
try:
return fn(*args, **kwargs)
except _RpcError as e:
if e.code() == _GrpcStatusCode.UNAUTHENTICATED:
# Always retry auth errors.
if i == (max_retries - 1):
# Exit the loop and wrap the authentication error.
raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e))
cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
refresh_handler_fn(args[0])
# There are two cases that we should throw error immediately
# 1. Entity already exists when we register entity
# 2. Entity not found when we fetch entity
elif e.code() == _GrpcStatusCode.ALREADY_EXISTS:
raise _user_exceptions.FlyteEntityAlreadyExistsException(e)
elif e.code() == _GrpcStatusCode.NOT_FOUND:
raise _user_exceptions.FlyteEntityNotExistException(e)
else:
# No more retries if retry=False or max_retries reached.
if (retry is False) or i == (max_retries - 1):
raise
else:
# No more retries if retry=False or max_retries reached.
if (retry is False) or i == (max_retries - 1):
raise
else:
# Retry: Start with 200ms wait-time and exponentially back-off up to 1 second.
wait_time = min(200 * (2 ** i), max_wait_time)
cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying")
time.sleep(wait_time / 1000)
except _RpcError as e:
if e.code() == _GrpcStatusCode.ALREADY_EXISTS:
raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e))
else:
raise
# Retry: Start with 200ms wait-time and exponentially back-off up to 1 second.
wait_time = min(200 * (2 ** i), max_wait_time)
cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying")
time.sleep(wait_time / 1000)

return handler

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type:
except ValueError:
logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")

# Because the dataclass transformer is handled explicity in the get_transformer code, we have to handle it
# Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
# separately here too.
try:
return cls._DATACLASS_TRANSFORMER.guess_python_type(literal_type=flyte_type)
Expand Down
38 changes: 36 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module defining main Flyte backend entrypoint."""
from __future__ import annotations

import logging
import os
import time
import typing
Expand All @@ -15,9 +16,12 @@

from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.common import utils as common_utils
from flytekit.common.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException
from flytekit.configuration import internal
from flytekit.configuration import platform as platform_config
from flytekit.configuration import sdk as sdk_config
from flytekit.configuration import set_flyte_config_file
from flytekit.core import context_manager
from flytekit.core.interface import Interface
from flytekit.loggers import remote_logger
from flytekit.models import filters as filter_models
Expand Down Expand Up @@ -202,7 +206,6 @@ def __init__(
raise user_exceptions.FlyteAssertion("Cannot find flyte admin url in config file.")

self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure, credentials=grpc_credentials)

# read config files, env vars, host, ssl options for admin client
self._flyte_admin_url = flyte_admin_url
self._insecure = insecure
Expand Down Expand Up @@ -520,6 +523,7 @@ def _serialize(
domain or self.default_domain,
version or self.version,
self.image_config,
env={internal.IMAGE.env_var: self.image_config.default_image.full},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a link to flyteorg/flyte#1359 as a comment here?

),
entity=entity,
)
Expand Down Expand Up @@ -604,6 +608,24 @@ def _(
)
return self.fetch_launch_plan(**resolved_identifiers)

def _register_entity_if_not_exists(self, entity: WorkflowBase, resolved_identifiers_dict: dict):
# Try to register all the entity in WorkflowBase including LaunchPlan, PythonTask, or subworkflow.
node_identifiers_dict = deepcopy(resolved_identifiers_dict)
for node in entity.nodes:
try:
node_identifiers_dict["name"] = node.flyte_entity.name
if isinstance(node.flyte_entity, WorkflowBase):
self._register_entity_if_not_exists(node.flyte_entity, node_identifiers_dict)
self.register(node.flyte_entity, **node_identifiers_dict)
elif isinstance(node.flyte_entity, PythonTask) or isinstance(node.flyte_entity, LaunchPlan):
self.register(node.flyte_entity, **node_identifiers_dict)
else:
raise NotImplementedError(f"We don't support registering this kind of entity: {node.flyte_entity}")
except FlyteEntityAlreadyExistsException:
logging.info(f"{entity.name} already exists")
except Exception as e:
logging.info(f"Failed to register entity {entity.name} with error {e}")

####################
# Execute Entities #
####################
Expand Down Expand Up @@ -884,11 +906,23 @@ def _(
"""Execute an @workflow-decorated function."""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
resolved_identifiers_dict = asdict(resolved_identifiers)

self._register_entity_if_not_exists(entity, resolved_identifiers_dict)
try:
flyte_workflow: FlyteWorkflow = self.fetch_workflow(**resolved_identifiers_dict)
except Exception:
except FlyteEntityNotExistException:
logging.info("Try to register FlyteWorkflow because it wasn't found in Flyte Admin!")
flyte_workflow: FlyteWorkflow = self.register(entity, **resolved_identifiers_dict)
Copy link
Member Author

@pingsutw pingsutw Oct 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wild-endeavor we will register parent workflow here

flyte_workflow.guessed_python_interface = entity.python_interface

ctx = context_manager.FlyteContext.current_context()
try:
self.fetch_launch_plan(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
logging.info("Try to register default launch plan because it wasn't found in Flyte Admin!")
default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
self.register(default_lp, **resolved_identifiers_dict)

return self.execute(
flyte_workflow,
inputs,
Expand Down
30 changes: 28 additions & 2 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

from flytekit import kwtypes
from flytekit.common.exceptions.user import FlyteAssertion
from flytekit.common.exceptions.user import FlyteAssertion, FlyteEntityNotExistException
from flytekit.core.launch_plan import LaunchPlan
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
from flytekit.remote.remote import FlyteRemote
Expand Down Expand Up @@ -238,8 +238,8 @@ def test_execute_python_workflow_list_of_floats(flyteclient, flyte_workflows_reg

# make sure the task name is the same as the name used during registration
my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")

remote = FlyteRemote.from_config(PROJECT, "development")

xs: typing.List[float] = [42.24, 999.1, 0.0001]
execution = remote.execute(my_wf, inputs={"xs": xs}, version=f"v{VERSION}", wait=True)
assert execution.outputs["o0"] == "[42.24, 999.1, 0.0001]"
Expand Down Expand Up @@ -282,3 +282,29 @@ def test_execute_joblib_workflow(flyteclient, flyte_workflows_register, flyte_re
output_obj = joblib.load(joblib_output.path)
assert execution.outputs["o0"].extension() == "joblib"
assert output_obj == input_obj


def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
from mock_flyte_repo.workflows.basic.subworkflows import parent_wf

# make sure the task name is the same as the name used during registration
parent_wf._name = parent_wf.name.replace("mock_flyte_repo.", "")

remote = FlyteRemote.from_config(PROJECT, "development")
execution = remote.execute(parent_wf, {"a": 101}, version=f"v{VERSION}", wait=True)
# check node execution inputs and outputs
assert execution.node_executions["n0"].inputs == {"a": 101}
assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
assert execution.node_executions["n1"].inputs == {"a": 103}
assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

# check subworkflow task execution inputs and outputs
subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}


def test_fetch_not_exist_launch_plan(flyteclient):
remote = FlyteRemote.from_config(PROJECT, "development")
with pytest.raises(FlyteEntityNotExistException):
remote.fetch_launch_plan(name="workflows.basic.list_float_wf.fake_wf", version=f"v{VERSION}")