diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index bf4c1e2d0c..7c4439d83d 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -79,7 +79,6 @@ def handler(self, create_request): cli_logger.error(_MessageToJson(create_request)) cli_logger.error("Details returned from the flyte admin: ") cli_logger.error(e.details) - e.details += "create_request: " + _MessageToJson(create_request) # Re-raise since we're not handling the error here and add the create_request details raise e @@ -260,7 +259,6 @@ def _refresh_credentials_from_command(self): :param self: RawSynchronousFlyteClient :return: """ - command = self._cfg.command if not command: raise FlyteAuthenticationException("No command specified in configuration for command authentication") diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 1556e343bf..afb6d613fe 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -107,6 +107,12 @@ is_flag=True, help="Enables to skip zipping and uploading the package", ) +@click.option( + "--dry-run", + default=False, + is_flag=True, + help="Execute registration in dry-run mode. Skips actual registration to remote", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -122,6 +128,7 @@ def register( deref_symlinks: bool, non_fast: bool, package_or_module: typing.Tuple[str], + dry_run: bool, ): """ see help @@ -156,6 +163,7 @@ def register( # Create and save FlyteRemote, remote = get_and_save_remote_with_click_context(ctx, project, domain) + click.secho(f"Registering against {remote.config.platform.endpoint}") try: repo.register( project, @@ -170,6 +178,7 @@ def register( fast=not non_fast, package_or_module=package_or_module, remote=remote, + dry_run=dry_run, ) except Exception as e: raise e diff --git a/flytekit/core/node.py b/flytekit/core/node.py index d849ef5397..d8b43f2728 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -51,6 +51,10 @@ def __rshift__(self, other: Node): self.runs_before(other) return other + @property + def name(self) -> str: + return self._id + @property def outputs(self): if self._outputs is None: diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 9f7f84e4bb..53048cb03f 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -870,7 +870,6 @@ def create_and_link_node_from_remote( ) flytekit_node = Node( - # TODO: Better naming, probably a derivative of the function name. id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}", metadata=entity.construct_node_metadata(), bindings=sorted(bindings, key=lambda b: b.var), diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index bcb80f34ca..81f6739a39 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -217,12 +217,7 @@ def compile_into_workflow( for entity, model in model_entities.items(): # We only care about gathering tasks here. Launch plans are handled by # propeller. Subworkflows should already be in the workflow spec. - if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskTemplate): - continue - - # Handle FlyteTask - if isinstance(entity, task_models.TaskTemplate): - tts.append(entity) + if not isinstance(entity, Task) and not isinstance(entity, task_models.TaskSpec): continue # We are currently not supporting reference tasks since these will diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 643d613231..174928a5b4 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ -85,10 +85,14 @@ """ -from flytekit.remote.component_nodes import FlyteTaskNode, FlyteWorkflowNode +from flytekit.remote.entities import ( + FlyteBranchNode, + FlyteLaunchPlan, + FlyteNode, + FlyteTask, + FlyteTaskNode, + FlyteWorkflow, + FlyteWorkflowNode, +) from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution -from flytekit.remote.launch_plan import FlyteLaunchPlan -from flytekit.remote.nodes import FlyteNode from flytekit.remote.remote import FlyteRemote -from flytekit.remote.task import FlyteTask -from flytekit.remote.workflow import FlyteWorkflow diff --git a/flytekit/remote/component_nodes.py b/flytekit/remote/component_nodes.py deleted file mode 100644 index bdf5fab38a..0000000000 --- a/flytekit/remote/component_nodes.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Dict - -from flytekit.exceptions import system as _system_exceptions -from flytekit.loggers import remote_logger -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_model - - -class FlyteTaskNode(_workflow_model.TaskNode): - """ - A class encapsulating a task that a Flyte node needs to execute. - """ - - def __init__(self, flyte_task: "flytekit.remote.task.FlyteTask"): - self._flyte_task = flyte_task - super(FlyteTaskNode, self).__init__(None) - - @property - def reference_id(self) -> id_models.Identifier: - """ - A globally unique identifier for the task. - """ - return self._flyte_task.id - - @property - def flyte_task(self) -> "flytekit.remote.tasks.task.FlyteTask": - return self._flyte_task - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.TaskNode, - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteTaskNode": - """ - Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the - FlyteTask control plane. - - :param base_model: - :param tasks: - """ - from flytekit.remote.task import FlyteTask - - if base_model.reference_id in tasks: - task = tasks[base_model.reference_id] - remote_logger.debug(f"Found existing task template for {task.id}, will not retrieve from Admin") - flyte_task = FlyteTask.promote_from_model(task) - return cls(flyte_task) - - raise _system_exceptions.FlyteSystemException(f"Task template {base_model.reference_id} not found.") - - -class FlyteWorkflowNode(_workflow_model.WorkflowNode): - """A class encapsulating a workflow that a Flyte node needs to execute.""" - - def __init__( - self, - flyte_workflow: "flytekit.remote.workflow.FlyteWorkflow" = None, - flyte_launch_plan: "flytekit.remote.launch_plan.FlyteLaunchPlan" = None, - ): - if flyte_workflow and flyte_launch_plan: - raise _system_exceptions.FlyteSystemException( - "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick " - f"one. workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}", - ) - - self._flyte_workflow = flyte_workflow - self._flyte_launch_plan = flyte_launch_plan - super(FlyteWorkflowNode, self).__init__( - launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None, - sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None, - ) - - def __repr__(self) -> str: - if self.flyte_workflow is not None: - return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}" - return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" - - @property - def launchplan_ref(self) -> id_models.Identifier: - """A globally unique identifier for the launch plan, which should map to Admin.""" - return self._flyte_launch_plan.id if self._flyte_launch_plan else None - - @property - def sub_workflow_ref(self): - return self._flyte_workflow.id if self._flyte_workflow else None - - @property - def flyte_launch_plan(self) -> "flytekit.remote.launch_plan.FlyteLaunchPlan": - return self._flyte_launch_plan - - @property - def flyte_workflow(self) -> "flytekit.remote.workflow.FlyteWorkflow": - return self._flyte_workflow - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.WorkflowNode, - sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], - node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteWorkflowNode": - from flytekit.remote import launch_plan as _launch_plan - from flytekit.remote import workflow as _workflow - - if base_model.launchplan_ref is not None: - return cls( - flyte_launch_plan=_launch_plan.FlyteLaunchPlan.promote_from_model( - base_model.launchplan_ref, node_launch_plans[base_model.launchplan_ref] - ) - ) - elif base_model.sub_workflow_ref is not None: - # the workflow templates for sub-workflows should have been included in the original response - if base_model.reference in sub_workflows: - return cls( - flyte_workflow=_workflow.FlyteWorkflow.promote_from_model( - sub_workflows[base_model.reference], - sub_workflows=sub_workflows, - node_launch_plans=node_launch_plans, - tasks=tasks, - ) - ) - raise _system_exceptions.FlyteSystemException(f"Subworkflow {base_model.reference} not found.") - - raise _system_exceptions.FlyteSystemException( - "Bad workflow node model, neither subworkflow nor launchplan specified." - ) - - -class FlyteBranchNode(_workflow_model.BranchNode): - def __init__(self, if_else: _workflow_model.IfElseBlock): - super().__init__(if_else) - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_model.BranchNode, - sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], - node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], - tasks: Dict[id_models.Identifier, _task_model.TaskTemplate], - ) -> "FlyteBranchNode": - - from flytekit.remote.nodes import FlyteNode - - block = base_model.if_else - - else_node = None - if block.else_node: - else_node = FlyteNode.promote_from_model(block.else_node, sub_workflows, node_launch_plans, tasks) - - block.case._then_node = FlyteNode.promote_from_model( - block.case.then_node, sub_workflows, node_launch_plans, tasks - ) - - for o in block.other: - o._then_node = FlyteNode.promote_from_model(o.then_node, sub_workflows, node_launch_plans, tasks) - - new_if_else_block = _workflow_model.IfElseBlock(block.case, block.other, else_node, block.error) - - return cls(new_if_else_block) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py new file mode 100644 index 0000000000..0c745c11bb --- /dev/null +++ b/flytekit/remote/entities.py @@ -0,0 +1,791 @@ +"""This module contains shadow entities for all Flyte entities as represented in Flyte Admin / Control Plane. +The goal is to enable easy access, manipulation of these entities. """ +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple, Union + +from flytekit.core import constants as _constants +from flytekit.core import hash as _hash_mixin +from flytekit.core import hash as hash_mixin +from flytekit.exceptions import system as _system_exceptions +from flytekit.exceptions import user as _user_exceptions +from flytekit.loggers import remote_logger +from flytekit.models import interface as _interface_models +from flytekit.models import launch_plan as _launch_plan_model +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import launch_plan as launch_plan_models +from flytekit.models import task as _task_model +from flytekit.models import task as _task_models +from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.core import compiler as compiler_models +from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import identifier as id_models +from flytekit.models.core import workflow as _workflow_model +from flytekit.models.core import workflow as _workflow_models +from flytekit.models.core.identifier import Identifier +from flytekit.models.core.workflow import Node, WorkflowMetadata, WorkflowMetadataDefaults +from flytekit.models.interface import TypedInterface +from flytekit.models.literals import Binding +from flytekit.models.task import TaskSpec +from flytekit.remote import interface as _interface +from flytekit.remote import interface as _interfaces +from flytekit.remote.remote_callable import RemoteEntity + + +class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, TaskSpec): + """A class encapsulating a remote Flyte task.""" + + def __init__( + self, + id, + type, + metadata, + interface, + custom, + container=None, + task_type_version: int = 0, + config=None, + should_register: bool = False, + ): + super(FlyteTask, self).__init__( + template=_task_model.TaskTemplate( + id, + type, + metadata, + interface, + custom, + container=container, + task_type_version=task_type_version, + config=config, + ) + ) + self._should_register = should_register + + @property + def id(self): + """ + This is generated by the system and uniquely identifies the task. + :rtype: flytekit.models.core.identifier.Identifier + """ + return self.template.id + + @property + def type(self): + """ + This is used to identify additional extensions for use by Propeller or SDK. + :rtype: Text + """ + return self.template.type + + @property + def metadata(self): + """ + This contains information needed at runtime to determine behavior such as whether or not outputs are + discoverable, timeouts, and retries. + :rtype: TaskMetadata + """ + return self.template.metadata + + @property + def interface(self): + """ + The interface definition for this task. + :rtype: flytekit.models.interface.TypedInterface + """ + return self.template.interface + + @property + def custom(self): + """ + Arbitrary dictionary containing metadata for custom plugins. + :rtype: dict[Text, T] + """ + return self.template.custom + + @property + def task_type_version(self): + return self.template.task_type_version + + @property + def container(self): + """ + If not None, the target of execution should be a container. + :rtype: Container + """ + return self.template.container + + @property + def config(self): + """ + Arbitrary dictionary containing metadata for parsing and handling custom plugins. + :rtype: dict[Text, T] + """ + return self.template.config + + @property + def security_context(self): + return self.template.security_context + + @property + def k8s_pod(self): + return self.template.k8s_pod + + @property + def sql(self): + return self.template.sql + + @property + def should_register(self) -> bool: + return self._should_register + + @property + def name(self) -> str: + return self.template.id.name + + @property + def resource_type(self) -> _identifier_model.ResourceType: + return _identifier_model.ResourceType.TASK + + @property + def entity_type_text(self) -> str: + return "Task" + + @classmethod + def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask: + t = cls( + id=base_model.id, + type=base_model.type, + metadata=base_model.metadata, + interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), + custom=base_model.custom, + container=base_model.container, + task_type_version=base_model.task_type_version, + ) + # Override the newly generated name if one exists in the base model + if not base_model.id.is_empty: + t._id = base_model.id + + return t + + +class FlyteTaskNode(_workflow_model.TaskNode): + """ + A class encapsulating a task that a Flyte node needs to execute. + """ + + def __init__(self, flyte_task: FlyteTask): + super(FlyteTaskNode, self).__init__(None) + self._flyte_task = flyte_task + + @property + def reference_id(self) -> id_models.Identifier: + """ + A globally unique identifier for the task. + """ + return self._flyte_task.id + + @property + def flyte_task(self) -> FlyteTask: + return self._flyte_task + + @classmethod + def promote_from_model(cls, task: FlyteTask) -> FlyteTaskNode: + """ + Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the + FlyteTask control plane. + """ + return cls(flyte_task=task) + + +class FlyteWorkflowNode(_workflow_model.WorkflowNode): + """A class encapsulating a workflow that a Flyte node needs to execute.""" + + def __init__( + self, + flyte_workflow: FlyteWorkflow = None, + flyte_launch_plan: FlyteLaunchPlan = None, + ): + if flyte_workflow and flyte_launch_plan: + raise _system_exceptions.FlyteSystemException( + "FlyteWorkflowNode cannot be called with both a workflow and a launchplan specified, please pick " + f"one. workflow: {flyte_workflow} launchPlan: {flyte_launch_plan}", + ) + + self._flyte_workflow = flyte_workflow + self._flyte_launch_plan = flyte_launch_plan + super(FlyteWorkflowNode, self).__init__( + launchplan_ref=self._flyte_launch_plan.id if self._flyte_launch_plan else None, + sub_workflow_ref=self._flyte_workflow.id if self._flyte_workflow else None, + ) + + def __repr__(self) -> str: + if self.flyte_workflow is not None: + return f"FlyteWorkflowNode with workflow: {self.flyte_workflow}" + return f"FlyteWorkflowNode with launch plan: {self.flyte_launch_plan}" + + @property + def launchplan_ref(self) -> id_models.Identifier: + """A globally unique identifier for the launch plan, which should map to Admin.""" + return self._flyte_launch_plan.id if self._flyte_launch_plan else None + + @property + def sub_workflow_ref(self): + return self._flyte_workflow.id if self._flyte_workflow else None + + @property + def flyte_launch_plan(self) -> FlyteLaunchPlan: + return self._flyte_launch_plan + + @property + def flyte_workflow(self) -> FlyteWorkflow: + return self._flyte_workflow + + @classmethod + def _promote_workflow( + cls, + wf: _workflow_models.WorkflowTemplate, + sub_workflows: Optional[Dict[Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[Identifier, FlyteTask]] = None, + node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None, + ) -> FlyteWorkflow: + return FlyteWorkflow.promote_from_model( + wf, + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.WorkflowNode, + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[FlyteWorkflowNode, Dict[id_models.Identifier, FlyteWorkflow]]: + if base_model.launchplan_ref is not None: + return ( + cls( + flyte_launch_plan=FlyteLaunchPlan.promote_from_model( + base_model.launchplan_ref, node_launch_plans[base_model.launchplan_ref] + ) + ), + converted_sub_workflows, + ) + elif base_model.sub_workflow_ref is not None: + # the workflow templates for sub-workflows should have been included in the original response + if base_model.reference in sub_workflows: + wf = None + if base_model.reference not in converted_sub_workflows: + wf = cls._promote_workflow( + sub_workflows[base_model.reference], + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + converted_sub_workflows[base_model.reference] = wf + else: + wf = converted_sub_workflows[base_model.reference] + return cls(flyte_workflow=wf), converted_sub_workflows + raise _system_exceptions.FlyteSystemException(f"Subworkflow {base_model.reference} not found.") + + raise _system_exceptions.FlyteSystemException( + "Bad workflow node model, neither subworkflow nor launchplan specified." + ) + + +class FlyteBranchNode(_workflow_model.BranchNode): + def __init__(self, if_else: _workflow_model.IfElseBlock): + super().__init__(if_else) + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_model.BranchNode, + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[id_models.Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[FlyteBranchNode, Dict[id_models.Identifier, FlyteWorkflow]]: + + block = base_model.if_else + block.case._then_node, converted_sub_workflows = FlyteNode.promote_from_model( + block.case.then_node, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + + for o in block.other: + o._then_node, converted_sub_workflows = FlyteNode.promote_from_model( + o.then_node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows + ) + + else_node = None + if block.else_node: + else_node, converted_sub_workflows = FlyteNode.promote_from_model( + block.else_node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows + ) + + new_if_else_block = _workflow_model.IfElseBlock(block.case, block.other, else_node, block.error) + + return cls(new_if_else_block), converted_sub_workflows + + +class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): + """A class encapsulating a remote Flyte node.""" + + def __init__( + self, + id, + upstream_nodes, + bindings, + metadata, + task_node: FlyteTaskNode = None, + workflow_node: FlyteWorkflowNode = None, + branch_node: FlyteBranchNode = None, + ): + if not task_node and not workflow_node and not branch_node: + raise _user_exceptions.FlyteAssertion( + "An Flyte node must have one of task|workflow|branch entity specified at once" + ) + # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from + # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it. + if task_node: + self._flyte_entity = task_node.flyte_task + elif workflow_node: + self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan + else: + self._flyte_entity = branch_node + + super(FlyteNode, self).__init__( + id=id, + metadata=metadata, + inputs=bindings, + upstream_node_ids=[n.id for n in upstream_nodes], + output_aliases=[], + task_node=task_node, + workflow_node=workflow_node, + branch_node=branch_node, + ) + self._upstream = upstream_nodes + + @property + def flyte_entity(self) -> Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan, FlyteBranchNode]: + return self._flyte_entity + + @classmethod + def _promote_task_node(cls, t: FlyteTask) -> FlyteTaskNode: + return FlyteTaskNode.promote_from_model(t) + + @classmethod + def _promote_workflow_node( + cls, + wn: _workflow_model.WorkflowNode, + sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate], + node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec], + tasks: Dict[Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[FlyteWorkflowNode, Dict[id_models.Identifier, FlyteWorkflow]]: + return FlyteWorkflowNode.promote_from_model( + wn, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + + @classmethod + def promote_from_model( + cls, + model: _workflow_model.Node, + sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], + node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], + tasks: Dict[id_models.Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[Optional[FlyteNode], Dict[id_models.Identifier, FlyteWorkflow]]: + node_model_id = model.id + # TODO: Consider removing + if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: + remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") + return None, converted_sub_workflows + + flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None + if model.task_node is not None: + if model.task_node.reference_id not in tasks: + raise RuntimeError( + f"Remote Workflow closure does not have task with id {model.task_node.reference_id}." + ) + flyte_task_node = cls._promote_task_node(tasks[model.task_node.reference_id]) + elif model.workflow_node is not None: + flyte_workflow_node, converted_sub_workflows = cls._promote_workflow_node( + model.workflow_node, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + elif model.branch_node is not None: + flyte_branch_node, converted_sub_workflows = FlyteBranchNode.promote_from_model( + model.branch_node, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + else: + raise _system_exceptions.FlyteSystemException( + f"Bad Node model, neither task nor workflow detected, node: {model}" + ) + + # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a + # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. + # TODO: Consider removing + for model_input in model.inputs: + if ( + model_input.binding.promise is not None + and model_input.binding.promise.node_id == _constants.START_NODE_ID + ): + model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID + + return ( + cls( + id=node_model_id, + upstream_nodes=[], # set downstream, model doesn't contain this information + bindings=model.inputs, + metadata=model.metadata, + task_node=flyte_task_node, + workflow_node=flyte_workflow_node, + branch_node=flyte_branch_node, + ), + converted_sub_workflows, + ) + + @property + def upstream_nodes(self) -> List[FlyteNode]: + return self._upstream + + @property + def upstream_node_ids(self) -> List[str]: + return list(sorted(n.id for n in self.upstream_nodes)) + + def __repr__(self) -> str: + return f"Node(ID: {self.id})" + + +class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, RemoteEntity, WorkflowSpec): + """A class encapsulating a remote Flyte workflow.""" + + def __init__( + self, + id: id_models.Identifier, + nodes: List[FlyteNode], + interface, + output_bindings, + metadata, + metadata_defaults, + subworkflows: Optional[List[FlyteWorkflow]] = None, + tasks: Optional[List[FlyteTask]] = None, + launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, + compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, + should_register: bool = False, + ): + # TODO: Remove check + for node in nodes: + for upstream in node.upstream_nodes: + if upstream.id is None: + raise _user_exceptions.FlyteAssertion( + "Some nodes contained in the workflow were not found in the workflow description. Please " + "ensure all nodes are either assigned to attributes within the class or an element in a " + "list, dict, or tuple which is stored as an attribute in the class." + ) + + self._flyte_sub_workflows = subworkflows + template_subworkflows = [] + if subworkflows: + template_subworkflows = [swf.template for swf in subworkflows] + + super(FlyteWorkflow, self).__init__( + template=_workflow_models.WorkflowTemplate( + id=id, + metadata=metadata, + metadata_defaults=metadata_defaults, + interface=interface, + nodes=nodes, + outputs=output_bindings, + ), + sub_workflows=template_subworkflows, + ) + self._flyte_nodes = nodes + + # Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure + self._tasks = tasks + self._launch_plans = launch_plans + self._compiled_closure = compiled_closure + self._node_map = None + self._name = id.name + self._should_register = should_register + + @property + def name(self) -> str: + return self._name + + @property + def flyte_tasks(self) -> Optional[List[FlyteTask]]: + return self._tasks + + @property + def should_register(self) -> bool: + return self._should_register + + @property + def flyte_sub_workflows(self) -> List[FlyteWorkflow]: + return self._flyte_sub_workflows + + @property + def entity_type_text(self) -> str: + return "Workflow" + + @property + def resource_type(self): + return id_models.ResourceType.WORKFLOW + + @property + def flyte_nodes(self) -> List[FlyteNode]: + return self._flyte_nodes + + @property + def id(self) -> Identifier: + """ + This is an autogenerated id by the system. The id is globally unique across Flyte. + """ + return self.template.id + + @property + def metadata(self) -> WorkflowMetadata: + """ + This contains information on how to run the workflow. + """ + return self.template.metadata + + @property + def metadata_defaults(self) -> WorkflowMetadataDefaults: + """ + This contains information on how to run the workflow. + :rtype: WorkflowMetadataDefaults + """ + return self.template.metadata_defaults + + @property + def interface(self) -> TypedInterface: + """ + Defines a strongly typed interface for the Workflow (inputs, outputs). This can include some optional + parameters. + """ + return self.template.interface + + @property + def nodes(self) -> List[Node]: + """ + A list of nodes. In addition, "globals" is a special reserved node id that can be used to consume + workflow inputs + """ + return self.template.nodes + + @property + def outputs(self) -> List[Binding]: + """ + A list of output bindings that specify how to construct workflow outputs. Bindings can + pull node outputs or specify literals. All workflow outputs specified in the interface field must be bound + in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes + to execute successfully in order to bind final outputs. + """ + return self.template.outputs + + @property + def failure_node(self) -> Node: + """ + Node failure_node: A catch-all node. This node is executed whenever the execution engine determines the + workflow has failed. The interface of this node must match the Workflow interface with an additional input + named "error" of type pb.lyft.flyte.core.Error. + """ + return self.template.failure_node + + @classmethod + def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]: + return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] + + @classmethod + def _promote_node( + cls, + model: _workflow_model.Node, + sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], + node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], + tasks: Dict[id_models.Identifier, FlyteTask], + converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], + ) -> Tuple[Optional[FlyteNode], Dict[id_models.Identifier, FlyteWorkflow]]: + return FlyteNode.promote_from_model(model, sub_workflows, node_launch_plans, tasks, converted_sub_workflows) + + @classmethod + def promote_from_model( + cls, + base_model: _workflow_models.WorkflowTemplate, + sub_workflows: Optional[Dict[Identifier, _workflow_models.WorkflowTemplate]] = None, + tasks: Optional[Dict[Identifier, FlyteTask]] = None, + node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None, + ) -> FlyteWorkflow: + + base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) + + node_map = {} + converted_sub_workflows = {} + for node in base_model_non_system_nodes: + flyte_node, converted_sub_workflows = cls._promote_node( + node, sub_workflows, node_launch_plans, tasks, converted_sub_workflows + ) + node_map[node.id] = flyte_node + + # Set upstream nodes for each node + for n in base_model_non_system_nodes: + current = node_map[n.id] + for upstream_id in n.upstream_node_ids: + upstream_node = node_map[upstream_id] + current._upstream.append(upstream_node) + + subworkflow_list = [] + if converted_sub_workflows: + subworkflow_list = [v for _, v in converted_sub_workflows.items()] + + task_list = [] + if tasks: + task_list = [t for _, t in tasks.items()] + + # No inputs/outputs specified, see the constructor for more information on the overrides. + wf = cls( + id=base_model.id, + nodes=list(node_map.values()), + metadata=base_model.metadata, + metadata_defaults=base_model.metadata_defaults, + interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), + output_bindings=base_model.outputs, + subworkflows=subworkflow_list, + tasks=task_list, + launch_plans=node_launch_plans, + ) + + wf._node_map = node_map + + return wf + + @classmethod + def _promote_task(cls, t: _task_models.TaskTemplate) -> FlyteTask: + return FlyteTask.promote_from_model(t) + + @classmethod + def promote_from_closure( + cls, + closure: compiler_models.CompiledWorkflowClosure, + node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, + ): + """ + Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane. + + :param closure: This is the closure returned by Admin + :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans. + It only has subworkflows and tasks. Why this is is unclear. If supplied, this map of launch plans will be + :return: + """ + sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows} + tasks = {} + if closure.tasks: + tasks = {t.template.id: cls._promote_task(t.template) for t in closure.tasks} + + flyte_wf = cls.promote_from_model( + base_model=closure.primary.template, + sub_workflows=sub_workflows, + node_launch_plans=node_launch_plans, + tasks=tasks, + ) + flyte_wf._compiled_closure = closure + return flyte_wf + + +class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec): + """A class encapsulating a remote Flyte launch plan.""" + + def __init__(self, id, *args, **kwargs): + super(FlyteLaunchPlan, self).__init__(*args, **kwargs) + # Set all the attributes we expect this class to have + self._id = id + self._name = id.name + + # The interface is not set explicitly unless fetched in an engine context + self._interface = None + # If fetched when creating this object, can store it here. + self._flyte_workflow = None + + @property + def name(self) -> str: + return self._name + + @property + def flyte_workflow(self) -> Optional[FlyteWorkflow]: + return self._flyte_workflow + + @classmethod + def promote_from_model(cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec) -> FlyteLaunchPlan: + lp = cls( + id=id, + workflow_id=model.workflow_id, + default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters), + fixed_inputs=model.fixed_inputs, + entity_metadata=model.entity_metadata, + labels=model.labels, + annotations=model.annotations, + auth_role=model.auth_role, + raw_output_data_config=model.raw_output_data_config, + max_parallelism=model.max_parallelism, + security_context=model.security_context, + ) + return lp + + @property + def id(self) -> id_models.Identifier: + return self._id + + @property + def is_scheduled(self) -> bool: + if self.entity_metadata.schedule.cron_expression: + return True + elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value: + return True + elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule: + return True + else: + return False + + @property + def workflow_id(self) -> id_models.Identifier: + return self._workflow_id + + @property + def interface(self) -> Optional[_interface.TypedInterface]: + """ + The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and + from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the= + object and get a node. + """ + return self._interface + + @property + def resource_type(self) -> id_models.ResourceType: + return id_models.ResourceType.LAUNCH_PLAN + + @property + def entity_type_text(self) -> str: + return "Launch Plan" + + def __repr__(self) -> str: + return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface}) - Spec {super().__repr__()})" diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index 607b15c889..292b6f0218 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -9,8 +9,7 @@ from flytekit.models import node_execution as node_execution_models from flytekit.models.admin import task_execution as admin_task_execution_models from flytekit.models.core import execution as core_execution_models -from flytekit.remote.task import FlyteTask -from flytekit.remote.workflow import FlyteWorkflow +from flytekit.remote.entities import FlyteTask, FlyteWorkflow class RemoteExecutionBase(object): diff --git a/flytekit/remote/launch_plan.py b/flytekit/remote/launch_plan.py deleted file mode 100644 index b6c8e1f9e6..0000000000 --- a/flytekit/remote/launch_plan.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from flytekit.core import hash as hash_mixin -from flytekit.models import interface as _interface_models -from flytekit.models import launch_plan as _launch_plan_models -from flytekit.models.core import identifier as id_models -from flytekit.remote import interface as _interface -from flytekit.remote.remote_callable import RemoteEntity - - -class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_plan_models.LaunchPlanSpec): - """A class encapsulating a remote Flyte launch plan.""" - - def __init__(self, id, *args, **kwargs): - super(FlyteLaunchPlan, self).__init__(*args, **kwargs) - # Set all the attributes we expect this class to have - self._id = id - self._name = id.name - - # The interface is not set explicitly unless fetched in an engine context - self._interface = None - - @property - def name(self) -> str: - return self._name - - # If fetched when creating this object, can store it here. - self._flyte_workflow = None - - @property - def flyte_workflow(self) -> Optional["FlyteWorkflow"]: - return self._flyte_workflow - - @classmethod - def promote_from_model( - cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec - ) -> "FlyteLaunchPlan": - lp = cls( - id=id, - workflow_id=model.workflow_id, - default_inputs=_interface_models.ParameterMap(model.default_inputs.parameters), - fixed_inputs=model.fixed_inputs, - entity_metadata=model.entity_metadata, - labels=model.labels, - annotations=model.annotations, - auth_role=model.auth_role, - raw_output_data_config=model.raw_output_data_config, - max_parallelism=model.max_parallelism, - security_context=model.security_context, - ) - return lp - - @property - def id(self) -> id_models.Identifier: - return self._id - - @property - def is_scheduled(self) -> bool: - if self.entity_metadata.schedule.cron_expression: - return True - elif self.entity_metadata.schedule.rate and self.entity_metadata.schedule.rate.value: - return True - elif self.entity_metadata.schedule.cron_schedule and self.entity_metadata.schedule.cron_schedule.schedule: - return True - else: - return False - - @property - def workflow_id(self) -> id_models.Identifier: - return self._workflow_id - - @property - def interface(self) -> Optional[_interface.TypedInterface]: - """ - The interface is not technically part of the admin.LaunchPlanSpec in the IDL, however the workflow ID is, and - from the workflow ID, fetch will fill in the interface. This is nice because then you can __call__ the= - object and get a node. - """ - return self._interface - - @property - def resource_type(self) -> id_models.ResourceType: - return id_models.ResourceType.LAUNCH_PLAN - - @property - def entity_type_text(self) -> str: - return "Launch Plan" - - def __repr__(self) -> str: - return f"FlyteLaunchPlan(ID: {self.id} Interface: {self.interface}) - Spec {super().__repr__()})" diff --git a/flytekit/remote/lazy_entity.py b/flytekit/remote/lazy_entity.py new file mode 100644 index 0000000000..b40c6e3ff7 --- /dev/null +++ b/flytekit/remote/lazy_entity.py @@ -0,0 +1,62 @@ +import typing +from threading import Lock + +from flytekit import FlyteContext +from flytekit.remote.remote_callable import RemoteEntity + +T = typing.TypeVar("T", bound=RemoteEntity) + + +class LazyEntity(RemoteEntity, typing.Generic[T]): + """ + Fetches the entity when the entity is called or when the entity is retrieved. + The entity is derived from RemoteEntity so that it behaves exactly like the mimiced entity. + """ + + def __init__(self, name: str, getter: typing.Callable[[], T], *args, **kwargs): + super().__init__(*args, **kwargs) + self._entity = None + self._getter = getter + self._name = name + if not self._getter: + raise ValueError("getter method is required to create a Lazy loadable Remote Entity.") + self._mutex = Lock() + + @property + def name(self) -> str: + return self._name + + def entity_fetched(self) -> bool: + with self._mutex: + return self._entity is not None + + @property + def entity(self) -> T: + """ + If not already fetched / available, then the entity will be force fetched. + """ + with self._mutex: + if self._entity is None: + self._entity = self._getter() + return self._entity + + def __getattr__(self, item: str) -> typing.Any: + """ + Forwards all other attributes to entity, causing the entity to be fetched! + """ + return getattr(self.entity, item) + + def compile(self, ctx: FlyteContext, *args, **kwargs): + return self.entity.compile(ctx, *args, **kwargs) + + def __call__(self, *args, **kwargs): + """ + Forwards the call to the underlying entity. The entity will be fetched if not already present + """ + return self.entity(*args, **kwargs) + + def __repr__(self) -> str: + return str(self) + + def __str__(self) -> str: + return f"Promise for entity [{self._name}]" diff --git a/flytekit/remote/nodes.py b/flytekit/remote/nodes.py deleted file mode 100644 index 0d73678b7e..0000000000 --- a/flytekit/remote/nodes.py +++ /dev/null @@ -1,164 +0,0 @@ -from __future__ import annotations - -from typing import Dict, List, Optional, Union - -from flytekit.core import constants as _constants -from flytekit.core import hash as _hash_mixin -from flytekit.core.promise import NodeOutput -from flytekit.exceptions import system as _system_exceptions -from flytekit.exceptions import user as _user_exceptions -from flytekit.loggers import remote_logger -from flytekit.models import launch_plan as _launch_plan_model -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_model -from flytekit.remote import component_nodes as _component_nodes - - -class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): - """A class encapsulating a remote Flyte node.""" - - def __init__( - self, - id, - upstream_nodes, - bindings, - metadata, - flyte_task: Optional["FlyteTask"] = None, - flyte_workflow: Optional["FlyteWorkflow"] = None, - flyte_launch_plan: Optional["FlyteLaunchPlan"] = None, - flyte_branch_node: Optional["FlyteBranchNode"] = None, - ): - # todo: flyte_branch_node is the only non-entity here, feels wrong, it should probably be a Condition - # or the other ones changed. - non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch_node])) - if len(non_none_entities) != 1: - raise _user_exceptions.FlyteAssertion( - "An Flyte node must have one underlying entity specified at once. Received the following " - "entities: {}".format(non_none_entities) - ) - # todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from - # the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it. - self._flyte_entity = flyte_task or flyte_workflow or flyte_launch_plan or flyte_branch_node - - workflow_node = None - if flyte_workflow is not None: - workflow_node = _component_nodes.FlyteWorkflowNode(flyte_workflow=flyte_workflow) - elif flyte_launch_plan is not None: - workflow_node = _component_nodes.FlyteWorkflowNode(flyte_launch_plan=flyte_launch_plan) - - task_node = None - if flyte_task: - task_node = _component_nodes.FlyteTaskNode(flyte_task) - - super(FlyteNode, self).__init__( - id=id, - metadata=metadata, - inputs=bindings, - upstream_node_ids=[n.id for n in upstream_nodes], - output_aliases=[], - task_node=task_node, - workflow_node=workflow_node, - branch_node=flyte_branch_node, - ) - self._upstream = upstream_nodes - - @property - def flyte_entity(self) -> Union["FlyteTask", "FlyteWorkflow", "FlyteLaunchPlan"]: - return self._flyte_entity - - @classmethod - def promote_from_model( - cls, - model: _workflow_model.Node, - sub_workflows: Optional[Dict[id_models.Identifier, _workflow_model.WorkflowTemplate]], - node_launch_plans: Optional[Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec]], - tasks: Optional[Dict[id_models.Identifier, _task_model.TaskTemplate]], - ) -> FlyteNode: - node_model_id = model.id - # TODO: Consider removing - if id in {_constants.START_NODE_ID, _constants.END_NODE_ID}: - remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") - return None - - flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None - if model.task_node is not None: - flyte_task_node = _component_nodes.FlyteTaskNode.promote_from_model(model.task_node, tasks) - elif model.workflow_node is not None: - flyte_workflow_node = _component_nodes.FlyteWorkflowNode.promote_from_model( - model.workflow_node, - sub_workflows, - node_launch_plans, - tasks, - ) - elif model.branch_node is not None: - flyte_branch_node = _component_nodes.FlyteBranchNode.promote_from_model( - model.branch_node, sub_workflows, node_launch_plans, tasks - ) - else: - raise _system_exceptions.FlyteSystemException( - f"Bad Node model, neither task nor workflow detected, node: {model}" - ) - - # When WorkflowTemplate models (containing node models) are returned by Admin, they've been compiled with a - # start node. In order to make the promoted FlyteWorkflow look the same, we strip the start-node text back out. - # TODO: Consider removing - for model_input in model.inputs: - if ( - model_input.binding.promise is not None - and model_input.binding.promise.node_id == _constants.START_NODE_ID - ): - model_input.binding.promise._node_id = _constants.GLOBAL_INPUT_NODE_ID - - if flyte_task_node is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_task=flyte_task_node.flyte_task, - ) - elif flyte_workflow_node is not None: - if flyte_workflow_node.flyte_workflow is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_workflow=flyte_workflow_node.flyte_workflow, - ) - elif flyte_workflow_node.flyte_launch_plan is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_launch_plan=flyte_workflow_node.flyte_launch_plan, - ) - raise _system_exceptions.FlyteSystemException( - "Bad FlyteWorkflowNode model, both launch plan and workflow are None" - ) - elif flyte_branch_node is not None: - return cls( - id=node_model_id, - upstream_nodes=[], # set downstream, model doesn't contain this information - bindings=model.inputs, - metadata=model.metadata, - flyte_branch_node=flyte_branch_node, - ) - raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty") - - @property - def upstream_nodes(self) -> List[FlyteNode]: - return self._upstream - - @property - def upstream_node_ids(self) -> List[str]: - return list(sorted(n.id for n in self.upstream_nodes)) - - @property - def outputs(self) -> Dict[str, NodeOutput]: - return self._outputs - - def __repr__(self) -> str: - return f"Node(ID: {self.id})" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 14cd7e11bb..6473d46ec9 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -53,13 +53,11 @@ NotificationList, WorkflowExecutionGetDataResponse, ) +from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteWorkflow from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution from flytekit.remote.interface import TypedInterface -from flytekit.remote.launch_plan import FlyteLaunchPlan -from flytekit.remote.nodes import FlyteNode +from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity -from flytekit.remote.task import FlyteTask -from flytekit.remote.workflow import FlyteWorkflow from flytekit.tools.fast_registration import fast_package from flytekit.tools.script_mode import fast_register_single_script, hash_file from flytekit.tools.translator import ( @@ -75,6 +73,14 @@ MOST_RECENT_FIRST = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING) +class RegistrationSkipped(Exception): + """ + RegistrationSkipped error is raised when trying to register an entity that is not registrable. + """ + + pass + + @dataclass class ResolvedIdentifiers: project: str @@ -190,6 +196,20 @@ def remote_context(self): FlyteContextManager.current_context().with_file_access(self.file_access) ) + def fetch_task_lazy( + self, project: str = None, domain: str = None, name: str = None, version: str = None + ) -> LazyEntity: + """ + Similar to fetch_task, just that it returns a LazyEntity, which will fetch the workflow lazily. + """ + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") + + def _fetch(): + return self.fetch_task(project=project, domain=domain, name=name, version=version) + + return LazyEntity(name=name, getter=_fetch) + def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask: """Fetch a task entity from flyte admin. @@ -213,14 +233,28 @@ def fetch_task(self, project: str = None, domain: str = None, name: str = None, ) admin_task = self.client.get_task(task_id) flyte_task = FlyteTask.promote_from_model(admin_task.closure.compiled_task.template) - flyte_task._id = task_id + flyte_task.template._id = task_id return flyte_task + def fetch_workflow_lazy( + self, project: str = None, domain: str = None, name: str = None, version: str = None + ) -> LazyEntity[FlyteWorkflow]: + """ + Similar to fetch_workflow, just that it returns a LazyEntity, which will fetch the workflow lazily. + """ + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") + + def _fetch(): + return self.fetch_workflow(project, domain, name, version) + + return LazyEntity(name=name, getter=_fetch) + def fetch_workflow( self, project: str = None, domain: str = None, name: str = None, version: str = None ) -> FlyteWorkflow: - """Fetch a workflow entity from flyte admin. - + """ + Fetch a workflow entity from flyte admin. :param project: fetch entity from this project. If None, uses the default_project attribute. :param domain: fetch entity from this domain. If None, uses the default_domain attribute. :param name: fetch entity with matching name. @@ -237,6 +271,7 @@ def fetch_workflow( name, version, ) + admin_workflow = self.client.get_workflow(workflow_id) compiled_wf = admin_workflow.closure.compiled_workflow @@ -359,8 +394,8 @@ def list_tasks_by_version( def _resolve_identifier(self, t: int, name: str, version: str, ss: SerializationSettings) -> Identifier: ident = Identifier( resource_type=t, - project=ss.project or self.default_project if ss else self.default_project, - domain=ss.domain or self.default_domain if ss else self.default_domain, + project=ss.project if ss and ss.project else self.default_project, + domain=ss.domain if ss and ss.domain else self.default_domain, name=name, version=version or ss.version, ) @@ -374,7 +409,7 @@ def _resolve_identifier(self, t: int, name: str, version: str, ss: Serialization def raw_register( self, cp_entity: FlyteControlPlaneEntity, - settings: typing.Optional[SerializationSettings], + settings: SerializationSettings, version: str, create_default_launchplan: bool = True, options: Options = None, @@ -393,6 +428,15 @@ def raw_register( :param og_entity: Pass in the original workflow (flytekit type) if create_default_launchplan is true :return: Identifier of the created entity """ + if isinstance(cp_entity, RemoteEntity): + if isinstance(cp_entity, (FlyteWorkflow, FlyteTask)): + if not cp_entity.should_register: + remote_logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.") + else: + remote_logger.debug(f"Skipping registration of remote entity: {cp_entity.name}") + raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.") + if isinstance( cp_entity, ( @@ -410,6 +454,8 @@ def raw_register( return None if isinstance(cp_entity, task_models.TaskSpec): + if isinstance(cp_entity, FlyteTask): + version = cp_entity.id.version ident = self._resolve_identifier(ResourceType.TASK, cp_entity.template.id.name, version, settings) try: self.client.create_task(task_identifer=ident, task_spec=cp_entity) @@ -418,6 +464,8 @@ def raw_register( return ident if isinstance(cp_entity, admin_workflow_models.WorkflowSpec): + if isinstance(cp_entity, FlyteWorkflow): + version = cp_entity.id.version ident = self._resolve_identifier(ResourceType.WORKFLOW, cp_entity.template.id.name, version, settings) try: self.client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity) @@ -484,10 +532,6 @@ def _serialize_and_register( ident = None for entity, cp_entity in m.items(): - if isinstance(entity, RemoteEntity): - remote_logger.debug(f"Skipping registration of remote entity: {entity.name}") - continue - if not isinstance(cp_entity, admin_workflow_models.WorkflowSpec) and is_dummy_serialization_setting: # Only in the case of workflows can we use the dummy serialization settings. raise user_exceptions.FlyteValueException( @@ -495,14 +539,17 @@ def _serialize_and_register( f"No serialization settings set, but workflow contains entities that need to be registered. {cp_entity.id.name}", ) - ident = self.raw_register( - cp_entity, - settings=settings, - version=version, - create_default_launchplan=True, - options=options, - og_entity=entity, - ) + try: + ident = self.raw_register( + cp_entity, + settings=settings, + version=version, + create_default_launchplan=True, + options=options, + og_entity=entity, + ) + except RegistrationSkipped: + pass return ident @@ -602,7 +649,7 @@ def _upload_file( filename=to_upload.name, ) self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url) - remote_logger.warning( + remote_logger.debug( f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}" ) diff --git a/flytekit/remote/remote_callable.py b/flytekit/remote/remote_callable.py index c04ec75f66..9adfd4846f 100644 --- a/flytekit/remote/remote_callable.py +++ b/flytekit/remote/remote_callable.py @@ -63,10 +63,10 @@ def __call__(self, *args, **kwargs): return self.execute(**kwargs) def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: - raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.") + return self.execute(**kwargs) def execute(self, **kwargs) -> Any: - raise Exception("Remotely fetched entities cannot be run locally. You have to mock this out.") + raise AssertionError(f"Remotely fetched entities cannot be run locally. Please mock the {self.name}.execute.") @property def python_interface(self) -> Optional[Dict[str, Type]]: diff --git a/flytekit/remote/task.py b/flytekit/remote/task.py deleted file mode 100644 index 3c2c8f8d92..0000000000 --- a/flytekit/remote/task.py +++ /dev/null @@ -1,51 +0,0 @@ -from flytekit.core import hash as hash_mixin -from flytekit.models import task as _task_model -from flytekit.models.core import identifier as _identifier_model -from flytekit.remote import interface as _interfaces -from flytekit.remote.remote_callable import RemoteEntity - - -class FlyteTask(hash_mixin.HashOnReferenceMixin, RemoteEntity, _task_model.TaskTemplate): - """A class encapsulating a remote Flyte task.""" - - def __init__(self, id, type, metadata, interface, custom, container=None, task_type_version=0, config=None): - super(FlyteTask, self).__init__( - id, - type, - metadata, - interface, - custom, - container=container, - task_type_version=task_type_version, - config=config, - ) - self._name = id.name - - @property - def name(self) -> str: - return self._name - - @property - def resource_type(self) -> _identifier_model.ResourceType: - return _identifier_model.ResourceType.TASK - - @property - def entity_type_text(self) -> str: - return "Task" - - @classmethod - def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> "FlyteTask": - t = cls( - id=base_model.id, - type=base_model.type, - metadata=base_model.metadata, - interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), - custom=base_model.custom, - container=base_model.container, - task_type_version=base_model.task_type_version, - ) - # Override the newly generated name if one exists in the base model - if not base_model.id.is_empty: - t._id = base_model.id - - return t diff --git a/flytekit/remote/workflow.py b/flytekit/remote/workflow.py deleted file mode 100644 index 3133f8a1fe..0000000000 --- a/flytekit/remote/workflow.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -from typing import Dict, List, Optional - -from flytekit.core import constants as _constants -from flytekit.core import hash as _hash_mixin -from flytekit.exceptions import user as _user_exceptions -from flytekit.models import launch_plan as launch_plan_models -from flytekit.models import task as _task_models -from flytekit.models.core import compiler as compiler_models -from flytekit.models.core import identifier as id_models -from flytekit.models.core import workflow as _workflow_models -from flytekit.remote import interface as _interfaces -from flytekit.remote import nodes as _nodes -from flytekit.remote.remote_callable import RemoteEntity - - -class FlyteWorkflow(_hash_mixin.HashOnReferenceMixin, RemoteEntity, _workflow_models.WorkflowTemplate): - """A class encapsulating a remote Flyte workflow.""" - - def __init__( - self, - id: id_models.Identifier, - nodes: List[_nodes.FlyteNode], - interface, - output_bindings, - metadata, - metadata_defaults, - subworkflows: Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]] = None, - tasks: Optional[Dict[id_models.Identifier, _task_models.TaskTemplate]] = None, - launch_plans: Optional[Dict[id_models.Identifier, launch_plan_models.LaunchPlanSpec]] = None, - compiled_closure: Optional[compiler_models.CompiledWorkflowClosure] = None, - ): - # TODO: Remove check - for node in nodes: - for upstream in node.upstream_nodes: - if upstream.id is None: - raise _user_exceptions.FlyteAssertion( - "Some nodes contained in the workflow were not found in the workflow description. Please " - "ensure all nodes are either assigned to attributes within the class or an element in a " - "list, dict, or tuple which is stored as an attribute in the class." - ) - super(FlyteWorkflow, self).__init__( - id=id, - metadata=metadata, - metadata_defaults=metadata_defaults, - interface=interface, - nodes=nodes, - outputs=output_bindings, - ) - self._flyte_nodes = nodes - - # Optional things that we save for ease of access when promoting from a model or CompiledWorkflowClosure - self._subworkflows = subworkflows - self._tasks = tasks - self._launch_plans = launch_plans - self._compiled_closure = compiled_closure - self._node_map = None - self._name = id.name - - @property - def name(self) -> str: - return self._name - - @property - def sub_workflows(self) -> Optional[Dict[id_models.Identifier, _workflow_models.WorkflowTemplate]]: - return self._subworkflows - - @property - def entity_type_text(self) -> str: - return "Workflow" - - @property - def resource_type(self): - return id_models.ResourceType.WORKFLOW - - @property - def flyte_nodes(self) -> List[_nodes.FlyteNode]: - return self._flyte_nodes - - @classmethod - def get_non_system_nodes(cls, nodes: List[_workflow_models.Node]) -> List[_workflow_models.Node]: - return [n for n in nodes if n.id not in {_constants.START_NODE_ID, _constants.END_NODE_ID}] - - @classmethod - def promote_from_model( - cls, - base_model: _workflow_models.WorkflowTemplate, - sub_workflows: Optional[Dict[id_models, _workflow_models.WorkflowTemplate]] = None, - node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, - tasks: Optional[Dict[id_models, _task_models.TaskTemplate]] = None, - ) -> FlyteWorkflow: - base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) - sub_workflows = sub_workflows or {} - tasks = tasks or {} - node_map = { - node.id: _nodes.FlyteNode.promote_from_model(node, sub_workflows, node_launch_plans, tasks) - for node in base_model_non_system_nodes - } - - # Set upstream nodes for each node - for n in base_model_non_system_nodes: - current = node_map[n.id] - for upstream_id in n.upstream_node_ids: - upstream_node = node_map[upstream_id] - current._upstream.append(upstream_node) - - # No inputs/outputs specified, see the constructor for more information on the overrides. - wf = cls( - id=base_model.id, - nodes=list(node_map.values()), - metadata=base_model.metadata, - metadata_defaults=base_model.metadata_defaults, - interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), - output_bindings=base_model.outputs, - subworkflows=sub_workflows, - tasks=tasks, - launch_plans=node_launch_plans, - ) - - wf._node_map = node_map - - return wf - - @classmethod - def promote_from_closure( - cls, - closure: compiler_models.CompiledWorkflowClosure, - node_launch_plans: Optional[Dict[id_models, launch_plan_models.LaunchPlanSpec]] = None, - ): - """ - Extracts out the relevant portions of a FlyteWorkflow from a closure from the control plane. - - :param closure: This is the closure returned by Admin - :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans. - It only has subworkflows and tasks. Why this is is unclear. If supplied, this map of launch plans will be - :return: - """ - sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows} - tasks = {t.template.id: t.template for t in closure.tasks} - - flyte_wf = FlyteWorkflow.promote_from_model( - base_model=closure.primary.template, - sub_workflows=sub_workflows, - node_launch_plans=node_launch_plans, - tasks=tasks, - ) - flyte_wf._compiled_closure = closure - return flyte_wf diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index 870299e5ad..50bac67844 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -10,7 +10,9 @@ from flytekit.core.context_manager import FlyteContextManager from flytekit.loggers import logger from flytekit.models import launch_plan +from flytekit.models.core.identifier import Identifier from flytekit.remote import FlyteRemote +from flytekit.remote.remote import RegistrationSkipped from flytekit.tools import fast_registration, module_loader from flytekit.tools.script_mode import _find_project_root from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities @@ -179,6 +181,27 @@ def load_packages_and_modules( return registrable_entities +def secho(i: Identifier, state: str = "success", reason: str = None): + state_ind = "[ ]" + fg = "white" + nl = False + if state == "success": + state_ind = "\r[✔]" + fg = "green" + nl = True + reason = f"successful with version {i.version}" if not reason else reason + elif state == "failed": + state_ind = "\r[x]" + fg = "red" + nl = True + reason = "skipped!" + click.secho( + click.style(f"{state_ind}", fg=fg) + f" Registration {i.name} type {i.resource_type_name()} {reason}", + dim=True, + nl=nl, + ) + + def register( project: str, domain: str, @@ -192,6 +215,7 @@ def register( fast: bool, package_or_module: typing.Tuple[str], remote: FlyteRemote, + dry_run: bool = False, ): detected_root = find_common_root(package_or_module) click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") @@ -234,11 +258,18 @@ def register( if len(serializable_entities) == 0: click.secho("No Flyte entities were detected. Aborting!", fg="red") return - click.secho(f"Found and serialized {len(serializable_entities)} entities") for cp_entity in serializable_entities: - name = cp_entity.id.name if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id.name - click.secho(f" Registering {name}....", dim=True, nl=False) - i = remote.raw_register(cp_entity, serialization_settings, version=version, create_default_launchplan=False) - click.secho(f"done, {i.resource_type_name()} with version {i.version}.", dim=True) + og_id = cp_entity.id if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id + secho(og_id, "") + try: + if not dry_run: + i = remote.raw_register( + cp_entity, serialization_settings, version=version, create_default_launchplan=False + ) + secho(i) + else: + secho(og_id, reason="Dry run Mode!") + except RegistrationSkipped: + secho(og_id, "failed") click.secho(f"Successfully registered {len(serializable_entities)} entities", fg="green") diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index abea7019f1..f0ad5e96c6 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -22,7 +22,6 @@ from flytekit.models import interface as interface_models from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import security -from flytekit.models import task as task_models from flytekit.models.admin import workflow as admin_workflow_models from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _core_wf @@ -30,6 +29,7 @@ from flytekit.models.core.workflow import ApproveCondition from flytekit.models.core.workflow import BranchNode as BranchNodeModel from flytekit.models.core.workflow import GateNode, SignalCondition, SleepCondition, TaskNodeOverrides +from flytekit.models.task import TaskSpec, TaskTemplate FlyteLocalEntity = Union[ PythonTask, @@ -43,7 +43,7 @@ ReferenceEntity, ] FlyteControlPlaneEntity = Union[ - task_models.TaskSpec, + TaskSpec, _launch_plan_models.LaunchPlan, admin_workflow_models.WorkflowSpec, workflow_model.Node, @@ -154,10 +154,9 @@ def fn(settings: SerializationSettings) -> List[str]: def get_serializable_task( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, -) -> task_models.TaskSpec: +) -> TaskSpec: task_id = _identifier_model.Identifier( _identifier_model.ResourceType.TASK, settings.project, @@ -197,7 +196,7 @@ def get_serializable_task( pod = entity.get_k8s_pod(settings) entity.reset_command_fn() - tt = task_models.TaskTemplate( + tt = TaskTemplate( id=task_id, type=entity.task_type, metadata=entity.metadata.to_taskmetadata_model(), @@ -212,7 +211,7 @@ def get_serializable_task( ) if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask): entity.reset_command_fn() - return task_models.TaskSpec(template=tt) + return TaskSpec(template=tt) def get_serializable_workflow( @@ -221,18 +220,19 @@ def get_serializable_workflow( entity: WorkflowBase, options: Optional[Options] = None, ) -> admin_workflow_models.WorkflowSpec: - # TODO: Try to move up following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.workflow import FlyteWorkflow - - # Get node models - upstream_node_models = [ - get_serializable(entity_mapping, settings, n, options) - for n in entity.nodes - if n.id != _common_constants.GLOBAL_INPUT_NODE_ID - ] - + # Serialize all nodes + serialized_nodes = [] sub_wfs = [] for n in entity.nodes: + # Ignore start nodes + if n.id == _common_constants.GLOBAL_INPUT_NODE_ID: + continue + + # Recursively serialize the node + serialized_nodes.append(get_serializable(entity_mapping, settings, n, options)) + + # If the node is workflow Node or Branch node, we need to handle it specially, to extract all subworkflows, + # so that they can be added to the workflow being serialized if isinstance(n.flyte_entity, WorkflowBase): # We are currently not supporting reference workflows since these will # require a network call to flyteadmin to populate the WorkflowTemplate @@ -249,10 +249,14 @@ def get_serializable_workflow( sub_wfs.append(sub_wf_spec.template) sub_wfs.extend(sub_wf_spec.sub_workflows) + from flytekit.remote import FlyteWorkflow + if isinstance(n.flyte_entity, FlyteWorkflow): - get_serializable(entity_mapping, settings, n.flyte_entity, options) - sub_wfs.append(n.flyte_entity) - sub_wfs.extend([s for s in n.flyte_entity.sub_workflows.values()]) + for swf in n.flyte_entity.flyte_sub_workflows: + sub_wf = get_serializable(entity_mapping, settings, swf, options) + sub_wfs.append(sub_wf.template) + main_wf = get_serializable(entity_mapping, settings, n.flyte_entity, options) + sub_wfs.append(main_wf.template) if isinstance(n.flyte_entity, BranchNode): if_else: workflow_model.IfElseBlock = n.flyte_entity._ifelse_block @@ -288,7 +292,7 @@ def get_serializable_workflow( metadata=entity.workflow_metadata.to_flyte_model(), metadata_defaults=entity.workflow_metadata_defaults.to_flyte_model(), interface=entity.interface, - nodes=upstream_node_models, + nodes=serialized_nodes, outputs=entity.output_bindings, ) return admin_workflow_models.WorkflowSpec( @@ -376,12 +380,7 @@ def get_serializable_node( if entity.flyte_entity is None: raise Exception(f"Node {entity.id} has no flyte entity") - # TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.launch_plan import FlyteLaunchPlan - from flytekit.remote.task import FlyteTask - from flytekit.remote.workflow import FlyteWorkflow - - upstream_sdk_nodes = [ + upstream_nodes = [ get_serializable(entity_mapping, settings, n, options=options) for n in entity.upstream_nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID @@ -395,7 +394,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], ) if ref_template.resource_type == _identifier_model.ResourceType.TASK: @@ -410,13 +409,15 @@ def get_serializable_node( ) return node_model + from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow + if isinstance(entity.flyte_entity, PythonTask): task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) node_model = workflow_model.Node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], task_node=workflow_model.TaskNode( reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources) @@ -431,7 +432,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.template.id), ) @@ -441,7 +442,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], branch_node=get_serializable(entity_mapping, settings, entity.flyte_entity, options=options), ) @@ -459,7 +460,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=node_input, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id), ) @@ -480,7 +481,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], gate_node=gn, ) @@ -492,23 +493,23 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], task_node=workflow_model.TaskNode( reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources) ), ) elif isinstance(entity.flyte_entity, FlyteWorkflow): - wf_template = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) - for _, sub_wf in entity.flyte_entity.sub_workflows.items(): + wf_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) + for sub_wf in entity.flyte_entity.flyte_sub_workflows: get_serializable(entity_mapping, settings, sub_wf, options=options) node_model = workflow_model.Node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=entity.bindings, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], - workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_template.id), + workflow_node=workflow_model.WorkflowNode(sub_workflow_ref=wf_spec.id), ) elif isinstance(entity.flyte_entity, FlyteLaunchPlan): # Recursive call doesn't do anything except put the entity on the map. @@ -523,7 +524,7 @@ def get_serializable_node( id=_dnsify(entity.id), metadata=entity.metadata, inputs=node_input, - upstream_node_ids=[n.id for n in upstream_sdk_nodes], + upstream_node_ids=[n.id for n in upstream_nodes], output_aliases=[], workflow_node=workflow_model.WorkflowNode(launchplan_ref=entity.flyte_entity.id), ) @@ -563,6 +564,54 @@ def get_reference_spec( return ReferenceSpec(template) +def get_serializable_flyte_workflow( + entity: "FlyteWorkflow", settings: SerializationSettings +) -> FlyteControlPlaneEntity: + """ + TODO replace with deep copy + """ + + def _mutate_task_node(tn: workflow_model.TaskNode): + tn.reference_id._project = settings.project + tn.reference_id._domain = settings.domain + + def _mutate_branch_node_task_ids(bn: workflow_model.BranchNode): + _mutate_node(bn.if_else.case.then_node) + for c in bn.if_else.other: + _mutate_node(c.then_node) + if bn.if_else.else_node: + _mutate_node(bn.if_else.else_node) + + def _mutate_workflow_node(wn: workflow_model.WorkflowNode): + wn.sub_workflow_ref._project = settings.project + wn.sub_workflow_ref._domain = settings.domain + + def _mutate_node(n: workflow_model.Node): + if n.task_node: + _mutate_task_node(n.task_node) + elif n.branch_node: + _mutate_branch_node_task_ids(n.branch_node) + elif n.workflow_node: + _mutate_workflow_node(n.workflow_node) + + for n in entity.flyte_nodes: + _mutate_node(n) + + entity.id._project = settings.project + entity.id._domain = settings.domain + + return entity + + +def get_serializable_flyte_task(entity: "FlyteTask", settings: SerializationSettings) -> FlyteControlPlaneEntity: + """ + TODO replace with deep copy + """ + entity.id._project = settings.project + entity.id._domain = settings.domain + return entity + + def get_serializable( entity_mapping: OrderedDict, settings: SerializationSettings, @@ -586,19 +635,16 @@ def get_serializable( :return: The resulting control plane entity, in addition to being added to the mutable entity_mapping parameter is also returned. """ - # TODO: Try to replace following config refactor - https://github.com/flyteorg/flyte/issues/2214 - from flytekit.remote.launch_plan import FlyteLaunchPlan - from flytekit.remote.task import FlyteTask - from flytekit.remote.workflow import FlyteWorkflow - if entity in entity_mapping: return entity_mapping[entity] + from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow + if isinstance(entity, ReferenceEntity): cp_entity = get_reference_spec(entity_mapping, settings, entity) elif isinstance(entity, PythonTask): - cp_entity = get_serializable_task(entity_mapping, settings, entity) + cp_entity = get_serializable_task(settings, entity) elif isinstance(entity, WorkflowBase): cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options) @@ -612,7 +658,21 @@ def get_serializable( elif isinstance(entity, BranchNode): cp_entity = get_serializable_branch_node(entity_mapping, settings, entity, options) - elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow) or isinstance(entity, FlyteLaunchPlan): + elif isinstance(entity, FlyteTask) or isinstance(entity, FlyteWorkflow): + if entity.should_register: + if isinstance(entity, FlyteTask): + cp_entity = get_serializable_flyte_task(entity, settings) + else: + if entity.should_register: + # We only add the tasks if the should register flag is set. This is to avoid adding + # unnecessary tasks to the registrable list. + for t in entity.flyte_tasks: + get_serializable(entity_mapping, settings, t, options) + cp_entity = get_serializable_flyte_workflow(entity, settings) + else: + cp_entity = entity + + elif isinstance(entity, FlyteLaunchPlan): cp_entity = entity else: @@ -626,7 +686,7 @@ def get_serializable( def gather_dependent_entities( serialized: OrderedDict, ) -> Tuple[ - Dict[_identifier_model.Identifier, task_models.TaskTemplate], + Dict[_identifier_model.Identifier, TaskTemplate], Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec], Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec], ]: @@ -639,12 +699,12 @@ def gather_dependent_entities( :param serialized: This should be the filled in OrderedDict used in the get_serializable function above. :return: """ - task_templates: Dict[_identifier_model.Identifier, task_models.TaskTemplate] = {} + task_templates: Dict[_identifier_model.Identifier, TaskTemplate] = {} workflow_specs: Dict[_identifier_model.Identifier, admin_workflow_models.WorkflowSpec] = {} launch_plan_specs: Dict[_identifier_model.Identifier, _launch_plan_models.LaunchPlanSpec] = {} for cp_entity in serialized.values(): - if isinstance(cp_entity, task_models.TaskSpec): + if isinstance(cp_entity, TaskSpec): task_templates[cp_entity.template.id] = cp_entity.template elif isinstance(cp_entity, _launch_plan_models.LaunchPlan): launch_plan_specs[cp_entity.id] = cp_entity.spec diff --git a/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb b/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb new file mode 100644 index 0000000000..1f3ce5c79a Binary files /dev/null and b/tests/flytekit/unit/remote/responses/CompiledWorkflowClosure.pb differ diff --git a/tests/flytekit/unit/remote/test_calling.py b/tests/flytekit/unit/remote/test_calling.py index 00d80464c3..34e4f8e8b8 100644 --- a/tests/flytekit/unit/remote/test_calling.py +++ b/tests/flytekit/unit/remote/test_calling.py @@ -12,11 +12,10 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion -from flytekit.models.core.workflow import WorkflowTemplate -from flytekit.models.task import TaskTemplate -from flytekit.remote import FlyteLaunchPlan, FlyteTask +from flytekit.models.admin.workflow import WorkflowSpec +from flytekit.models.task import TaskSpec +from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow from flytekit.remote.interface import TypedInterface -from flytekit.remote.workflow import FlyteWorkflow from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") @@ -63,7 +62,7 @@ def wf(a: int) -> int: serialized = OrderedDict() wf_spec = get_serializable(serialized, serialization_settings, wf) vals = [v for v in serialized.values()] - tts = [f for f in filter(lambda x: isinstance(x, TaskTemplate), vals)] + tts = [f for f in filter(lambda x: isinstance(x, TaskSpec), vals)] assert len(tts) == 1 assert wf_spec.template.nodes[0].id == "foobar" assert wf_spec.template.outputs[0].binding.promise.node_id == "foobar" @@ -143,9 +142,11 @@ def my_subwf(a: int) -> typing.List[int]: def test_calling_wf(): # No way to fetch from Admin in unit tests so we serialize and then promote back serialized = OrderedDict() - wf_spec = get_serializable(serialized, serialization_settings, sub_wf) + wf_spec: WorkflowSpec = get_serializable(serialized, serialization_settings, sub_wf) task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) - fwf = FlyteWorkflow.promote_from_model(wf_spec.template, tasks=task_templates) + fwf = FlyteWorkflow.promote_from_model( + wf_spec.template, tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()} + ) @workflow def parent_1(a: int, b: str) -> typing.Tuple[int, str]: @@ -162,8 +163,14 @@ def parent_1(a: int, b: str) -> typing.Tuple[int, str]: # Pick out the subworkflow templates from the ordereddict. We can't use the output of the gather_dependent_entities # function because that only looks for WorkflowSpecs - subwf_templates = {x.id: x for x in list(filter(lambda x: isinstance(x, WorkflowTemplate), serialized.values()))} - fwf_p1 = FlyteWorkflow.promote_from_model(wf_spec.template, sub_workflows=subwf_templates, tasks=task_templates_p1) + subwf_templates = { + x.template.id: x.template for x in list(filter(lambda x: isinstance(x, WorkflowSpec), serialized.values())) + } + fwf_p1 = FlyteWorkflow.promote_from_model( + wf_spec.template, + sub_workflows=subwf_templates, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates_p1.items()}, + ) @workflow def parent_2(a: int, b: str) -> typing.Tuple[int, str]: diff --git a/tests/flytekit/unit/remote/test_lazy_entity.py b/tests/flytekit/unit/remote/test_lazy_entity.py new file mode 100644 index 0000000000..1ed191aea4 --- /dev/null +++ b/tests/flytekit/unit/remote/test_lazy_entity.py @@ -0,0 +1,65 @@ +import pytest +from mock import patch + +from flytekit import TaskMetadata +from flytekit.core import context_manager +from flytekit.models.core.identifier import Identifier, ResourceType +from flytekit.models.interface import TypedInterface +from flytekit.remote import FlyteTask +from flytekit.remote.lazy_entity import LazyEntity + + +def test_missing_getter(): + with pytest.raises(ValueError): + LazyEntity("x", None) + + +dummy_task = FlyteTask( + id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), + type="t", + metadata=TaskMetadata().to_taskmetadata_model(), + interface=TypedInterface(inputs={}, outputs={}), + custom=None, +) + + +def test_lazy_loading(): + once = True + + def _getter(): + nonlocal once + if not once: + raise ValueError("Should be called once only") + once = False + return dummy_task + + e = LazyEntity("x", _getter) + assert e.__repr__() == "Promise for entity [x]" + assert e.name == "x" + assert e._entity is None + assert not e.entity_fetched() + v = e.entity + assert e._entity is not None + assert v == dummy_task + assert e.entity == dummy_task + assert e.entity_fetched() + + +@patch("flytekit.remote.remote_callable.create_and_link_node_from_remote") +def test_lazy_loading_compile(create_and_link_node_from_remote_mock): + once = True + + def _getter(): + nonlocal once + if not once: + raise ValueError("Should be called once only") + once = False + return dummy_task + + e = LazyEntity("x", _getter) + assert e.name == "x" + assert e._entity is None + ctx = context_manager.FlyteContext.current_context() + e.compile(ctx) + assert e._entity is not None + assert e.entity == dummy_task diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index dd37b97f87..01688ea825 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -3,6 +3,7 @@ import tempfile import pytest +from flyteidl.core import compiler_pb2 as _compiler_pb2 from mock import MagicMock, patch import flytekit.configuration @@ -10,10 +11,15 @@ from flytekit.exceptions import user as user_exceptions from flytekit.models import common as common_models from flytekit.models import security -from flytekit.models.core.identifier import ResourceType, WorkflowExecutionIdentifier +from flytekit.models.admin.workflow import Workflow, WorkflowClosure +from flytekit.models.core.compiler import CompiledWorkflowClosure +from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier from flytekit.models.execution import Execution +from flytekit.models.task import Task +from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote import FlyteRemote from flytekit.tools.translator import Options +from tests.flytekit.common.parameterizers import LIST_OF_TASK_CLOSURES CLIENT_METHODS = { ResourceType.WORKFLOW: "list_workflows_paginated", @@ -247,3 +253,43 @@ def test_generate_console_http_domain_sandbox_rewrite(mock_client): os.remove(temp_filename) except OSError: pass + + +def get_compiled_workflow_closure(): + """ + :rtype: flytekit.models.core.compiler.CompiledWorkflowClosure + """ + cwc_pb = _compiler_pb2.CompiledWorkflowClosure() + # So that tests that use this work when run from any directory + basepath = os.path.dirname(__file__) + filepath = os.path.abspath(os.path.join(basepath, "responses", "CompiledWorkflowClosure.pb")) + with open(filepath, "rb") as fh: + cwc_pb.ParseFromString(fh.read()) + + return CompiledWorkflowClosure.from_flyte_idl(cwc_pb) + + +@patch("flytekit.remote.remote.SynchronousFlyteClient") +def test_fetch_lazy(mock_client): + mock_client.get_task.return_value = Task( + id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), closure=LIST_OF_TASK_CLOSURES[0] + ) + + mock_client.get_workflow.return_value = Workflow( + id=Identifier(ResourceType.TASK, "p", "d", "n", "v"), + closure=WorkflowClosure(compiled_workflow=get_compiled_workflow_closure()), + ) + + remote = FlyteRemote(config=Config.auto(), default_project="p1", default_domain="d1") + lw = remote.fetch_workflow_lazy(name="wn", version="v") + assert isinstance(lw, LazyEntity) + assert lw._getter + assert lw._entity is None + assert lw.entity + + lt = remote.fetch_task_lazy(name="n", version="v") + assert isinstance(lw, LazyEntity) + assert lt._getter + assert lt._entity is None + tk = lt.entity + assert tk.name == "n" diff --git a/tests/flytekit/unit/remote/test_with_responses.py b/tests/flytekit/unit/remote/test_with_responses.py index ee3fbb4d8a..7dd7b97910 100644 --- a/tests/flytekit/unit/remote/test_with_responses.py +++ b/tests/flytekit/unit/remote/test_with_responses.py @@ -66,11 +66,11 @@ def test_normal_task(mock_client): ) admin_task = task_models.Task.from_flyte_idl(merge_sort_remotely) mock_client.get_task.return_value = admin_task - ft = rr.fetch_task(name="merge_sort_remotely", version="tst") + remote_task = rr.fetch_task(name="merge_sort_remotely", version="tst") @workflow def my_wf(numbers: typing.List[int], run_local_at_count: int) -> typing.List[int]: - t1_node = create_node(ft, numbers=numbers, run_local_at_count=run_local_at_count) + t1_node = create_node(remote_task, numbers=numbers, run_local_at_count=run_local_at_count) return t1_node.o0 serialization_settings = flytekit.configuration.SerializationSettings( diff --git a/tests/flytekit/unit/remote/test_wrapper_classes.py b/tests/flytekit/unit/remote/test_wrapper_classes.py index 4a08cb7724..82ba538883 100644 --- a/tests/flytekit/unit/remote/test_wrapper_classes.py +++ b/tests/flytekit/unit/remote/test_wrapper_classes.py @@ -9,7 +9,7 @@ from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.workflow import workflow -from flytekit.remote import FlyteWorkflow +from flytekit.remote import FlyteTask, FlyteWorkflow from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") @@ -58,11 +58,14 @@ def wf(b: int) -> int: serialized = OrderedDict() wf_spec = get_serializable(serialized, serialization_settings, wf) - sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) + sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows=sub_wf_dict, node_launch_plans=lp_specs, tasks=task_templates + wf_spec.template, + sub_workflows=sub_wf_dict, + node_launch_plans=lp_specs, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()}, ) assert len(fwf.outputs) == 1 assert list(fwf.interface.inputs.keys()) == ["b"] @@ -79,7 +82,10 @@ def wf2(b: int) -> int: task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows={}, node_launch_plans=lp_specs, tasks=task_templates + wf_spec.template, + sub_workflows={}, + node_launch_plans=lp_specs, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()}, ) assert len(fwf.outputs) == 1 assert list(fwf.interface.inputs.keys()) == ["b"] @@ -111,7 +117,10 @@ def my_wf(a: int) -> str: task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows={}, node_launch_plans={}, tasks=task_templates + wf_spec.template, + sub_workflows={}, + node_launch_plans={}, + tasks={k: FlyteTask.promote_from_model(t) for k, t in task_templates.items()}, ) assert len(fwf.flyte_nodes[0].upstream_nodes) == 0 @@ -125,11 +134,14 @@ def parent(a: int) -> (str, str): serialized = OrderedDict() wf_spec = get_serializable(serialized, serialization_settings, parent) - sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) + sub_wf_dict = {s.id: s for s in wf_spec.sub_workflows} fwf = FlyteWorkflow.promote_from_model( - wf_spec.template, sub_workflows=sub_wf_dict, node_launch_plans={}, tasks=task_templates + wf_spec.template, + sub_workflows=sub_wf_dict, + node_launch_plans={}, + tasks={k: FlyteTask.promote_from_model(v) for k, v in task_templates.items()}, ) # Test upstream nodes don't get confused by subworkflows assert len(fwf.flyte_nodes[0].upstream_nodes) == 0