Skip to content

Commit

Permalink
Enable remote workflow to be invoked in conditional branch (#1890)
Browse files Browse the repository at this point in the history
Signed-off-by: Yue Shang <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
2 people authored and fiedlerNr9 committed Jul 25, 2024
1 parent 104b0ca commit 66588f7
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 6 deletions.
53 changes: 47 additions & 6 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import OrderedDict
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from typing import Dict

import click
import fsspec
Expand Down Expand Up @@ -56,9 +57,10 @@
from flytekit.models.admin import common as admin_common_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.common import Sort
from flytekit.models.core import identifier as id_models
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.core.workflow import BranchNode, Node, NodeMetadata
from flytekit.models.execution import (
ClusterAssignment,
ExecutionMetadata,
Expand Down Expand Up @@ -390,15 +392,54 @@ def fetch_workflow(
wf_templates.extend([swf.template for swf in compiled_wf.sub_workflows])

node_launch_plans = {}
# TODO: Inspect branch nodes for launch plans

def find_launch_plan(
    lp_ref: id_models.Identifier,
    node_launch_plans: Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec],
) -> None:
    """Fetch the spec of the launch plan ``lp_ref`` and record it in ``node_launch_plans``.

    The mapping is updated in place and doubles as a cache: a reference that is
    already present is not fetched again.

    ``self`` is not a parameter of this helper; it is resolved from the
    enclosing scope, whose ``client`` performs the lookup.

    :param lp_ref: identifier of the launch plan referenced by a workflow node.
    :param node_launch_plans: mapping from launch plan identifier to its spec,
        mutated in place.
    """
    if lp_ref not in node_launch_plans:
        admin_launch_plan = self.client.get_launch_plan(lp_ref)
        node_launch_plans[lp_ref] = admin_launch_plan.spec

for wf_template in wf_templates:
for node in FlyteWorkflow.get_non_system_nodes(wf_template.nodes):
if node.workflow_node is not None and node.workflow_node.launchplan_ref is not None:
lp_ref = node.workflow_node.launchplan_ref
if node.workflow_node.launchplan_ref not in node_launch_plans:
admin_launch_plan = self.client.get_launch_plan(lp_ref)
node_launch_plans[node.workflow_node.launchplan_ref] = admin_launch_plan.spec

find_launch_plan(lp_ref, node_launch_plans)

# Inspect conditional branch nodes for launch plans
def get_launch_plan_from_branch(
    branch_node: BranchNode,
    node_launch_plans: Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec],
) -> None:
    """Record every launch plan referenced inside a conditional branch node.

    Visits, in order, the ``then_node`` of the primary ``case``, the
    ``then_node`` of each entry in ``other``, and finally the ``else_node``.
    Each visited node is either a nested conditional (handled recursively) or a
    node whose ``workflow_node.launchplan_ref`` is handed to
    ``find_launch_plan``.

    :param branch_node: the conditional node to inspect; a falsy value or one
        without ``if_else`` is ignored.
    :param node_launch_plans: mapping from launch plan identifier to its spec,
        mutated in place.
    """

    def get_launch_plan_from_then_node(
        child_then_node: Node,
        node_launch_plans: Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec],
    ) -> None:
        """Handle one branch target: recurse into a nested conditional, else record its launch plan."""
        if child_then_node.branch_node:
            get_launch_plan_from_branch(child_then_node.branch_node, node_launch_plans)
        elif child_then_node.workflow_node and child_then_node.workflow_node.launchplan_ref:
            find_launch_plan(child_then_node.workflow_node.launchplan_ref, node_launch_plans)

    if not (branch_node and branch_node.if_else):
        return
    branch = branch_node.if_else

    if branch.case and branch.case.then_node:
        get_launch_plan_from_then_node(branch.case.then_node, node_launch_plans)

    # `other` may be unset; treat that the same as "no additional cases".
    for other_case in branch.other or []:
        if other_case.then_node:
            get_launch_plan_from_then_node(other_case.then_node, node_launch_plans)

    # The else target needs exactly the same treatment as a then target
    # (nested conditional or launch-plan node), so reuse the helper.
    if branch.else_node:
        get_launch_plan_from_then_node(branch.else_node, node_launch_plans)

if node.branch_node:
get_launch_plan_from_branch(node.branch_node, node_launch_plans)
return FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)

def fetch_launch_plan(
Expand Down
83 changes: 83 additions & 0 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from flytekit.models import common as common_models
from flytekit.models import security
from flytekit.models.admin.workflow import Workflow, WorkflowClosure
from flytekit.models.core import condition as _condition
from flytekit.models.core import workflow as _workflow
from flytekit.models.core.compiler import CompiledWorkflowClosure
from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier
from flytekit.models.execution import Execution
Expand Down Expand Up @@ -53,6 +55,52 @@
}


# Node whose workflow_node references the launch plan "LAUNCH_PLAN".
obj = _workflow.Node(
    id="some:node:id",
    metadata="1",
    inputs=[],
    upstream_node_ids=[],
    output_aliases=[],
    workflow_node=_workflow.WorkflowNode(launchplan_ref="LAUNCH_PLAN"),
)
# Conditional node whose single case leads directly to the launch-plan node above.
node1 = _workflow.Node(
    id="some:node:id",
    metadata="1",
    inputs=[],
    upstream_node_ids=[],
    output_aliases=[],
    branch_node=_workflow.BranchNode(
        _workflow.IfElseBlock(
            case=_workflow.IfBlock(
                condition=_condition.BooleanExpression(),
                then_node=obj,
            )
        )
    ),
)
nodes = [node1]

# Node constructed without a workflow_node or a branch_node.
obj2 = _workflow.Node(id="some:node:id", metadata="1", inputs=[], upstream_node_ids=[], output_aliases=[])

# Conditional node whose then-target is `obj2` and whose else-target is the
# conditional `node1`, so "LAUNCH_PLAN" is only reachable through a nested branch.
# Bound to `node2` only: `node1` keeps referring to the inner conditional.
node2 = _workflow.Node(
    id="some:node:id",
    metadata="1",
    inputs=[],
    upstream_node_ids=[],
    output_aliases=[],
    branch_node=_workflow.BranchNode(
        _workflow.IfElseBlock(
            case=_workflow.IfBlock(
                condition=_condition.BooleanExpression(),
                then_node=obj2,
            ),
            else_node=node1,
        )
    ),
)
nodes2 = [node2]


@pytest.fixture
def remote():
with patch("flytekit.clients.friendly.SynchronousFlyteClient") as mock_client:
Expand Down Expand Up @@ -387,6 +435,41 @@ def test_launch_backfill(remote):
assert wf.workflow_metadata.on_failure == WorkflowFailurePolicy.FAIL_IMMEDIATELY


@patch("flytekit.remote.entities.FlyteWorkflow.get_non_system_nodes", return_value=nodes)
@patch("flytekit.remote.entities.FlyteWorkflow.promote_from_closure")
def test_fetch_workflow_with_branch(mock_promote, mock_workflow, remote):
    """fetch_workflow passes on the launch plan found in a conditional branch.

    ``get_non_system_nodes`` is patched to return ``nodes``; the test then
    checks the launch plan mapping handed to ``promote_from_closure``.
    """
    # Launch plan returned by the mocked client for any lookup.
    fetched_launch_plan = MagicMock()
    fetched_launch_plan.spec = {"workflow_id": 123}

    client = remote._client
    client.get_launch_plan.return_value = fetched_launch_plan
    client.get_workflow.return_value = Workflow(
        closure=WorkflowClosure(compiled_workflow=MagicMock()),
        id=Identifier(ResourceType.TASK, "p", "d", "n", "v"),
    )

    remote.fetch_workflow(name="n", version="v")

    mock_promote.assert_called_with(ANY, {"LAUNCH_PLAN": {"workflow_id": 123}})


@patch("flytekit.remote.entities.FlyteWorkflow.get_non_system_nodes", return_value=nodes2)
@patch("flytekit.remote.entities.FlyteWorkflow.promote_from_closure")
def test_fetch_workflow_with_nested_branch(mock_promote, mock_workflow, remote):
    """fetch_workflow passes on a launch plan found in a nested conditional branch.

    ``get_non_system_nodes`` is patched to return ``nodes2``; the test then
    checks the launch plan mapping handed to ``promote_from_closure``.
    """
    # Launch plan returned by the mocked client for any lookup.
    fetched_launch_plan = MagicMock()
    fetched_launch_plan.spec = {"workflow_id": 123}

    client = remote._client
    client.get_launch_plan.return_value = fetched_launch_plan
    client.get_workflow.return_value = Workflow(
        closure=WorkflowClosure(compiled_workflow=MagicMock()),
        id=Identifier(ResourceType.TASK, "p", "d", "n", "v"),
    )

    remote.fetch_workflow(name="n", version="v")

    mock_promote.assert_called_with(ANY, {"LAUNCH_PLAN": {"workflow_id": 123}})


@mock.patch("pathlib.Path.read_bytes")
@mock.patch("flytekit.remote.remote.FlyteRemote._version_from_hash")
@mock.patch("flytekit.remote.remote.FlyteRemote.register_workflow")
Expand Down

0 comments on commit 66588f7

Please sign in to comment.