Skip to content

Commit

Permalink
Remote fetch array node (flyteorg#2442)
Browse files: browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored May 28, 2024
1 parent 1a01368 commit 5c9cb97
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 62 deletions.
72 changes: 19 additions & 53 deletions flytekit/models/node_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing
from datetime import timezone as _timezone

import flyteidl.admin.node_execution_pb2 as _node_execution_pb2
import flyteidl.admin.node_execution_pb2 as admin_node_execution_pb2

from flytekit.models import common as _common_models
from flytekit.models.core import catalog as catalog_models
Expand All @@ -19,13 +19,13 @@ def __init__(self, execution_id: _identifier.WorkflowExecutionIdentifier):
def execution_id(self) -> _identifier.WorkflowExecutionIdentifier:
return self._execution_id

def to_flyte_idl(self) -> _node_execution_pb2.WorkflowNodeMetadata:
return _node_execution_pb2.WorkflowNodeMetadata(
def to_flyte_idl(self) -> admin_node_execution_pb2.WorkflowNodeMetadata:
return admin_node_execution_pb2.WorkflowNodeMetadata(
executionId=self.execution_id.to_flyte_idl(),
)

@classmethod
def from_flyte_idl(cls, p: _node_execution_pb2.WorkflowNodeMetadata) -> "WorkflowNodeMetadata":
def from_flyte_idl(cls, p: admin_node_execution_pb2.WorkflowNodeMetadata) -> "WorkflowNodeMetadata":
return cls(
execution_id=_identifier.WorkflowExecutionIdentifier.from_flyte_idl(p.executionId),
)
Expand All @@ -44,14 +44,14 @@ def id(self) -> _identifier.Identifier:
def compiled_workflow(self) -> core_compiler_models.CompiledWorkflowClosure:
return self._compiled_workflow

def to_flyte_idl(self) -> _node_execution_pb2.DynamicWorkflowNodeMetadata:
return _node_execution_pb2.DynamicWorkflowNodeMetadata(
def to_flyte_idl(self) -> admin_node_execution_pb2.DynamicWorkflowNodeMetadata:
return admin_node_execution_pb2.DynamicWorkflowNodeMetadata(
id=self.id.to_flyte_idl(),
compiled_workflow=self.compiled_workflow.to_flyte_idl(),
)

@classmethod
def from_flyte_idl(cls, p: _node_execution_pb2.DynamicWorkflowNodeMetadata) -> "DynamicWorkflowNodeMetadata":
def from_flyte_idl(cls, p: admin_node_execution_pb2.DynamicWorkflowNodeMetadata) -> "DynamicWorkflowNodeMetadata":
yy = cls(
id=_identifier.Identifier.from_flyte_idl(p.id),
compiled_workflow=core_compiler_models.CompiledWorkflowClosure.from_flyte_idl(p.compiled_workflow),
Expand All @@ -72,14 +72,14 @@ def cache_status(self) -> int:
def catalog_key(self) -> catalog_models.CatalogMetadata:
return self._catalog_key

def to_flyte_idl(self) -> _node_execution_pb2.TaskNodeMetadata:
return _node_execution_pb2.TaskNodeMetadata(
def to_flyte_idl(self) -> admin_node_execution_pb2.TaskNodeMetadata:
return admin_node_execution_pb2.TaskNodeMetadata(
cache_status=self.cache_status,
catalog_key=self.catalog_key.to_flyte_idl(),
)

@classmethod
def from_flyte_idl(cls, p: _node_execution_pb2.TaskNodeMetadata) -> "TaskNodeMetadata":
def from_flyte_idl(cls, p: admin_node_execution_pb2.TaskNodeMetadata) -> "TaskNodeMetadata":
return cls(
cache_status=p.cache_status,
catalog_key=catalog_models.CatalogMetadata.from_flyte_idl(p.catalog_key),
Expand Down Expand Up @@ -185,7 +185,7 @@ def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.node_execution_pb2.NodeExecutionClosure
"""
obj = _node_execution_pb2.NodeExecutionClosure(
obj = admin_node_execution_pb2.NodeExecutionClosure(
phase=self.phase,
output_uri=self.output_uri,
deck_uri=self.deck_uri,
Expand Down Expand Up @@ -227,47 +227,13 @@ def from_flyte_idl(cls, p):
)


class NodeExecutionMetaData(_common_models.FlyteIdlEntity):
    """Model holding the three fields of the admin ``NodeExecutionMetaData`` message.

    Instances convert to and from the protobuf message via ``to_flyte_idl`` and
    ``from_flyte_idl``.
    """

    def __init__(self, retry_group: str, is_parent_node: bool, spec_node_id: str):
        """
        :param retry_group: value carried in the message's ``retry_group`` field.
        :param is_parent_node: value carried in the message's ``is_parent_node`` field.
        :param spec_node_id: value carried in the message's ``spec_node_id`` field.
        """
        self._spec_node_id = spec_node_id
        self._is_parent_node = is_parent_node
        self._retry_group = retry_group

    @property
    def retry_group(self) -> str:
        """The retry group given at construction."""
        return self._retry_group

    @property
    def is_parent_node(self) -> bool:
        """The parent-node flag given at construction."""
        return self._is_parent_node

    @property
    def spec_node_id(self) -> str:
        """The spec node id given at construction."""
        return self._spec_node_id

    def to_flyte_idl(self) -> _node_execution_pb2.NodeExecutionMetaData:
        """Serialize this model into its protobuf message."""
        message = _node_execution_pb2.NodeExecutionMetaData(
            retry_group=self.retry_group,
            is_parent_node=self.is_parent_node,
            spec_node_id=self.spec_node_id,
        )
        return message

    @classmethod
    def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecutionMetaData) -> "NodeExecutionMetaData":
        """Build a model instance from the protobuf message ``p``."""
        field_values = {
            "retry_group": p.retry_group,
            "is_parent_node": p.is_parent_node,
            "spec_node_id": p.spec_node_id,
        }
        return cls(**field_values)


class NodeExecution(_common_models.FlyteIdlEntity):
def __init__(self, id, input_uri, closure, metadata):
def __init__(self, id, input_uri, closure, metadata: admin_node_execution_pb2.NodeExecutionMetaData):
"""
:param flytekit.models.core.identifier.NodeExecutionIdentifier id:
:param Text input_uri:
:param NodeExecutionClosure closure:
:param NodeExecutionMetaData metadata:
:param metadata:
"""
self._id = id
self._input_uri = input_uri
Expand Down Expand Up @@ -296,22 +262,22 @@ def closure(self):
return self._closure

@property
def metadata(self) -> NodeExecutionMetaData:
def metadata(self) -> admin_node_execution_pb2.NodeExecutionMetaData:
return self._metadata

def to_flyte_idl(self) -> _node_execution_pb2.NodeExecution:
return _node_execution_pb2.NodeExecution(
def to_flyte_idl(self) -> admin_node_execution_pb2.NodeExecution:
return admin_node_execution_pb2.NodeExecution(
id=self.id.to_flyte_idl(),
input_uri=self.input_uri,
closure=self.closure.to_flyte_idl(),
metadata=self.metadata.to_flyte_idl(),
metadata=self.metadata,
)

@classmethod
def from_flyte_idl(cls, p: _node_execution_pb2.NodeExecution) -> "NodeExecution":
def from_flyte_idl(cls, p: admin_node_execution_pb2.NodeExecution) -> "NodeExecution":
return cls(
id=_identifier.NodeExecutionIdentifier.from_flyte_idl(p.id),
input_uri=p.input_uri,
closure=NodeExecutionClosure.from_flyte_idl(p.closure),
metadata=NodeExecutionMetaData.from_flyte_idl(p.metadata),
metadata=p.metadata,
)
7 changes: 6 additions & 1 deletion flytekit/remote/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,12 @@ def promote_from_model(cls, model: _workflow_model.GateNode):
class FlyteArrayNode(_workflow_model.ArrayNode):
    """``ArrayNode`` subclass that can be constructed from an existing ``ArrayNode`` model."""

    @classmethod
    def promote_from_model(cls, model: _workflow_model.ArrayNode):
        """Create a ``FlyteArrayNode`` carrying the same settings as ``model``.

        Every value is passed by keyword so it reaches the constructor parameter
        of the same name, independent of the constructor's positional order.

        :param model: the ``ArrayNode`` model to copy the node, parallelism,
            min_successes and min_success_ratio values from.
        :return: a new instance of ``cls``.
        """
        return cls(
            node=model._node,
            parallelism=model._parallelism,
            min_successes=model._min_successes,
            min_success_ratio=model._min_success_ratio,
        )


class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node):
Expand Down
4 changes: 3 additions & 1 deletion flytekit/remote/executions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing
from abc import abstractmethod
from typing import Dict, List, Optional, Union

Expand All @@ -9,6 +10,7 @@
from flytekit.models import node_execution as node_execution_models
from flytekit.models.admin import task_execution as admin_task_execution_models
from flytekit.models.core import execution as core_execution_models
from flytekit.models.interface import TypedInterface
from flytekit.remote.entities import FlyteTask, FlyteWorkflow


Expand Down Expand Up @@ -148,7 +150,7 @@ def __init__(self, *args, **kwargs):
self._task_executions = None
self._workflow_executions = []
self._underlying_node_executions = None
self._interface = None
self._interface: typing.Optional[TypedInterface] = None
self._flyte_node = None

@property
Expand Down
15 changes: 14 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,7 @@ def sync_node_execution(
return execution

# If a node ran a static subworkflow or a dynamic subworkflow then the parent flag will be set.
if execution.metadata.is_parent_node:
if execution.metadata.is_parent_node or execution.metadata.is_array:
# We'll need to query child node executions regardless since this is a parent node
child_node_executions = iterate_node_executions(
self.client,
Expand Down Expand Up @@ -2115,6 +2115,19 @@ def sync_node_execution(
"not have inputs and outputs filled in"
)
return execution
elif execution._node.array_node is not None:
# if there's a task node underneath the array node, let's fetch the interface for it
if execution._node.array_node.node.task_node is not None:
tid = execution._node.array_node.node.task_node.reference_id
t = self.fetch_task(tid.project, tid.domain, tid.name, tid.version)
if t.interface:
execution._interface = t.interface
else:
logger.error(f"Fetched map task does not have an interface, skipping i/o {t}")
return execution
else:
logger.error(f"Array node not over task, skipping i/o {t}")
return execution
else:
logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}")
raise Exception(f"Node execution undeterminable, entity has type {type(execution._node)}")
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,16 @@ def my_wf(a: int, b: str) -> (int, str):
assert execution.spec.envs.envs == {"foo": "bar"}
assert execution.spec.tags == ["flyte"]
assert execution.spec.cluster_assignment.cluster_pool == "gpu"


def test_execute_workflow_with_maptask(register):
    """Execute the ``workflow_with_maptask`` launch plan remotely and check its output.

    Inputs are ``data=[1, 2, 3]`` and ``y=3``; the expected ``o0`` output is
    ``[4, 5, 6]``.
    """
    flyte_remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
    data: typing.List[int] = [1, 2, 3]
    launch_plan = flyte_remote.fetch_launch_plan(
        name="basic.array_map.workflow_with_maptask",
        version=VERSION,
    )
    # wait=True blocks until the execution finishes so outputs can be read.
    finished = flyte_remote.execute(
        launch_plan,
        inputs={"data": data, "y": 3},
        version=VERSION,
        wait=True,
    )
    assert finished.outputs["o0"] == [4, 5, 6]
14 changes: 14 additions & 0 deletions tests/flytekit/integration/remote/workflows/basic/array_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from functools import partial

from flytekit import map_task, task, workflow


@task
def fn(x: int, y: int) -> int:
    """Return the sum of ``x`` and ``y``."""
    total = x + y
    return total


@workflow
def workflow_with_maptask(data: list[int], y: int) -> list[int]:
    """Apply ``fn`` to each element of ``data`` via ``map_task``, with ``y`` bound by ``partial``."""
    # Bind y up front so only x is left for map_task to map over.
    return map_task(partial(fn, y=y))(x=data)
6 changes: 0 additions & 6 deletions tests/flytekit/unit/models/admin/test_node_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
from tests.flytekit.unit.common_tests.test_workflow_promote import get_compiled_workflow_closure


def test_metadata():
    """Round-trip ``NodeExecutionMetaData`` through its IDL form and compare with the original."""
    original = node_execution_models.NodeExecutionMetaData(
        retry_group="0",
        is_parent_node=True,
        spec_node_id="n0",
    )
    as_idl = original.to_flyte_idl()
    restored = node_execution_models.NodeExecutionMetaData.from_flyte_idl(as_idl)
    assert original == restored


def test_workflow_node_metadata():
wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name")

Expand Down

0 comments on commit 5c9cb97

Please sign in to comment.