Skip to content

Commit

Permalink
Create default launch plan when executing WorkflowBase (#707)
Browse files Browse the repository at this point in the history
* Create default launchplan

Signed-off-by: Kevin Su <[email protected]>

* Update comment

Signed-off-by: Kevin Su <[email protected]>

* Added test

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* Register subworkflow, launchplan node

Signed-off-by: Kevin Su <[email protected]>

* Fixed lint

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* Fixed test

Signed-off-by: Kevin Su <[email protected]>

* Add link

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Oct 28, 2021
1 parent 88b590c commit 87131a0
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 31 deletions.
54 changes: 28 additions & 26 deletions flytekit/clients/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,33 +125,35 @@ def handler(*args, **kwargs):
"""
max_retries = 3
max_wait_time = 1000
try:
for i in range(max_retries):
try:
return fn(*args, **kwargs)
except _RpcError as e:
if e.code() == _GrpcStatusCode.UNAUTHENTICATED:
# Always retry auth errors.
if i == (max_retries - 1):
# Exit the loop and wrap the authentication error.
raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e))
cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
refresh_handler_fn(args[0])

for i in range(max_retries):
try:
return fn(*args, **kwargs)
except _RpcError as e:
if e.code() == _GrpcStatusCode.UNAUTHENTICATED:
# Always retry auth errors.
if i == (max_retries - 1):
# Exit the loop and wrap the authentication error.
raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e))
cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
refresh_handler_fn(args[0])
# There are two cases that we should throw error immediately
# 1. Entity already exists when we register entity
# 2. Entity not found when we fetch entity
elif e.code() == _GrpcStatusCode.ALREADY_EXISTS:
raise _user_exceptions.FlyteEntityAlreadyExistsException(e)
elif e.code() == _GrpcStatusCode.NOT_FOUND:
raise _user_exceptions.FlyteEntityNotExistException(e)
else:
# No more retries if retry=False or max_retries reached.
if (retry is False) or i == (max_retries - 1):
raise
else:
# No more retries if retry=False or max_retries reached.
if (retry is False) or i == (max_retries - 1):
raise
else:
# Retry: Start with 200ms wait-time and exponentially back-off up to 1 second.
wait_time = min(200 * (2 ** i), max_wait_time)
cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying")
time.sleep(wait_time / 1000)
except _RpcError as e:
if e.code() == _GrpcStatusCode.ALREADY_EXISTS:
raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e))
else:
raise
# Retry: Start with 200ms wait-time and exponentially back-off up to 1 second.
wait_time = min(200 * (2 ** i), max_wait_time)
cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying")
time.sleep(wait_time / 1000)

return handler

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type:
except ValueError:
logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")

# Because the dataclass transformer is handled explicity in the get_transformer code, we have to handle it
# Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
# separately here too.
try:
return cls._DATACLASS_TRANSFORMER.guess_python_type(literal_type=flyte_type)
Expand Down
39 changes: 37 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module defining main Flyte backend entrypoint."""
from __future__ import annotations

import logging
import os
import time
import typing
Expand All @@ -17,9 +18,12 @@
import flytekit.models.admin.launch_plan
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.common import utils as common_utils
from flytekit.common.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException
from flytekit.configuration import internal
from flytekit.configuration import platform as platform_config
from flytekit.configuration import sdk as sdk_config
from flytekit.configuration import set_flyte_config_file
from flytekit.core import context_manager
from flytekit.core.interface import Interface
from flytekit.loggers import remote_logger
from flytekit.models import filters as filter_models
Expand Down Expand Up @@ -204,7 +208,6 @@ def __init__(
raise user_exceptions.FlyteAssertion("Cannot find flyte admin url in config file.")

self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure, credentials=grpc_credentials)

# read config files, env vars, host, ssl options for admin client
self._flyte_admin_url = flyte_admin_url
self._insecure = insecure
Expand Down Expand Up @@ -522,6 +525,8 @@ def _serialize(
domain or self.default_domain,
version or self.version,
self.image_config,
# https://github.com/flyteorg/flyte/issues/1359
env={internal.IMAGE.env_var: self.image_config.default_image.full},
),
entity=entity,
)
Expand Down Expand Up @@ -606,6 +611,24 @@ def _(
)
return self.fetch_launch_plan(**resolved_identifiers)

def _register_entity_if_not_exists(self, entity: WorkflowBase, resolved_identifiers_dict: dict):
    """Register every entity referenced by the nodes of ``entity``.

    Each node's ``flyte_entity`` is registered under a copy of
    ``resolved_identifiers_dict`` whose ``"name"`` key is replaced by that
    entity's name. Sub-workflows are walked recursively first, so their own
    node entities are registered before the sub-workflow itself.

    Registration is best-effort: an entity that already exists, or any other
    failure for a single node, is logged and the loop moves on to the next node.

    :param entity: The workflow whose node entities should be registered.
    :param resolved_identifiers_dict: Identifier keyword arguments forwarded to
        ``self.register``; the caller's dict is not modified.
    """
    # Work on a copy so overwriting "name" below never leaks back to the caller.
    node_identifiers_dict = deepcopy(resolved_identifiers_dict)
    for node in entity.nodes:
        # Tracked separately so the log messages name the entity that was actually
        # being registered rather than the enclosing workflow.
        node_entity_name = None
        try:
            node_entity = node.flyte_entity
            node_entity_name = node_entity.name
            node_identifiers_dict["name"] = node_entity_name
            if isinstance(node_entity, WorkflowBase):
                # Register the sub-workflow's own nodes before the sub-workflow itself.
                self._register_entity_if_not_exists(node_entity, node_identifiers_dict)
                self.register(node_entity, **node_identifiers_dict)
            elif isinstance(node_entity, (PythonTask, LaunchPlan)):
                self.register(node_entity, **node_identifiers_dict)
            else:
                raise NotImplementedError(f"We don't support registering this kind of entity: {node_entity}")
        except FlyteEntityAlreadyExistsException:
            logging.info(f"{node_entity_name} already exists")
        except Exception as e:
            # Deliberately best-effort: one failing node must not abort the remaining registrations.
            logging.info(f"Failed to register entity {node_entity_name} with error {e}")

####################
# Execute Entities #
####################
Expand Down Expand Up @@ -886,11 +909,23 @@ def _(
"""Execute an @workflow-decorated function."""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
resolved_identifiers_dict = asdict(resolved_identifiers)

try:
flyte_workflow: FlyteWorkflow = self.fetch_workflow(**resolved_identifiers_dict)
except Exception:
except FlyteEntityNotExistException:
logging.info("Try to register FlyteWorkflow because it wasn't found in Flyte Admin!")
self._register_entity_if_not_exists(entity, resolved_identifiers_dict)
flyte_workflow: FlyteWorkflow = self.register(entity, **resolved_identifiers_dict)
flyte_workflow.guessed_python_interface = entity.python_interface

ctx = context_manager.FlyteContext.current_context()
try:
self.fetch_launch_plan(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
logging.info("Try to register default launch plan because it wasn't found in Flyte Admin!")
default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
self.register(default_lp, **resolved_identifiers_dict)

return self.execute(
flyte_workflow,
inputs,
Expand Down
30 changes: 28 additions & 2 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pytest

from flytekit import kwtypes
from flytekit.common.exceptions.user import FlyteAssertion
from flytekit.common.exceptions.user import FlyteAssertion, FlyteEntityNotExistException
from flytekit.core.launch_plan import LaunchPlan
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
from flytekit.remote.remote import FlyteRemote
Expand Down Expand Up @@ -238,8 +238,8 @@ def test_execute_python_workflow_list_of_floats(flyteclient, flyte_workflows_reg

# make sure the task name is the same as the name used during registration
my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")

remote = FlyteRemote.from_config(PROJECT, "development")

xs: typing.List[float] = [42.24, 999.1, 0.0001]
execution = remote.execute(my_wf, inputs={"xs": xs}, version=f"v{VERSION}", wait=True)
assert execution.outputs["o0"] == "[42.24, 999.1, 0.0001]"
Expand Down Expand Up @@ -282,3 +282,29 @@ def test_execute_joblib_workflow(flyteclient, flyte_workflows_register, flyte_re
output_obj = joblib.load(joblib_output.path)
assert execution.outputs["o0"].extension() == "joblib"
assert output_obj == input_obj


def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
    """Execute ``parent_wf`` remotely and check node and sub-workflow inputs/outputs."""
    from mock_flyte_repo.workflows.basic.subworkflows import parent_wf

    # make sure the task name is the same as the name used during registration
    parent_wf._name = parent_wf.name.replace("mock_flyte_repo.", "")

    remote = FlyteRemote.from_config(PROJECT, "development")
    execution = remote.execute(parent_wf, {"a": 101}, version=f"v{VERSION}", wait=True)
    # check node execution inputs and outputs
    assert execution.node_executions["n0"].inputs == {"a": 101}
    assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
    assert execution.node_executions["n1"].inputs == {"a": 103}
    assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}

    # check subworkflow task execution inputs and outputs
    # These comparisons were previously bare expressions, so they never verified anything.
    subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
    assert subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
    assert subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}


def test_fetch_not_exist_launch_plan(flyteclient):
    """Fetching a launch plan name that was never registered raises FlyteEntityNotExistException."""
    flyte_remote = FlyteRemote.from_config(PROJECT, "development")
    missing_launch_plan = {
        "name": "workflows.basic.list_float_wf.fake_wf",
        "version": f"v{VERSION}",
    }
    with pytest.raises(FlyteEntityNotExistException):
        flyte_remote.fetch_launch_plan(**missing_launch_plan)

0 comments on commit 87131a0

Please sign in to comment.