Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Throw warning for nested @Task functions #1727

Merged
merged 4 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
"""
return None

def local_execution_mode(self) -> ExecutionState.Mode:
""" """
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

def sandbox_execute(
self,
ctx: FlyteContext,
Expand Down Expand Up @@ -602,7 +606,7 @@ def dispatch_execute(
for k, v in native_outputs_as_map.items():
output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v)))

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# When we run the workflow remotely, flytekit outputs decks at the end of _dispatch_execute
_output_deck(self.name.split(".")[-1], new_user_params)

Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing
from typing import Optional, Tuple, Union, cast

from flytekit.core.context_manager import ExecutionState, FlyteContextManager
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.node import Node
from flytekit.core.promise import (
ComparisonExpression,
Expand Down Expand Up @@ -488,7 +488,7 @@ def conditional(name: str) -> ConditionalSection:
if ctx.compilation_state:
return ConditionalSection(name)
elif ctx.execution_state:
if ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state.is_local_execution():
# In case of Local workflow execution, we will actually evaluate the expression and based on the result
# make the branch to be active using `take_branch` method
from flytekit.core.context_manager import BranchEvalMode
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,12 @@ def with_params(
user_space_params=user_space_params if user_space_params else self.user_space_params,
)

def is_local_execution(self):
return (
self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
)


@dataclass(frozen=True)
class FlyteContext(object):
Expand Down Expand Up @@ -690,7 +696,7 @@ def enter_conditional_section(self) -> FlyteContext.Builder:
self.compilation_state = self.compilation_state.with_params(prefix=self.compilation_state.prefix)

if self.execution_state:
if self.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if self.execution_state.is_local_execution():
if self.in_a_condition:
if self.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
self.execution_state = self.execution_state.with_params()
Expand Down
5 changes: 4 additions & 1 deletion flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import click

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.promise import Promise, VoidPromise, flyte_entity_call_handler
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteDisapprovalException
Expand Down Expand Up @@ -116,6 +116,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
raise FlyteDisapprovalException(f"User did not approve the transaction for gate node {self.name}")

def local_execution_mode(self):
return ExecutionState.Mode.LOCAL_TASK_EXECUTION


def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type):
"""Create a Gate object that waits for user input of the specified type.
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]:
"""

ctx = FlyteContextManager.current_context()
if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# In workflow execution mode we actually need to use the parent (mapper) task output interface.
return self.interface.outputs
return self._run_task.interface.outputs
Expand All @@ -232,7 +232,7 @@ def get_type_for_output_var(self, k: str, v: Any) -> type:
from these individual outputs as the final output value.
"""
ctx = FlyteContextManager.current_context()
if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# In workflow execution mode we actually need to use the parent (mapper) task output interface.
return self._python_interface.outputs[k]
return self._run_task._python_interface.outputs[k]
Expand Down
7 changes: 2 additions & 5 deletions flytekit/core/node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Union

from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext
from flytekit.core.context_manager import BranchEvalMode, FlyteContext
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import VoidPromise
Expand Down Expand Up @@ -144,10 +144,7 @@ def sub_wf():
# Handling local execution
# Note: execution state is set to TASK_EXECUTION when running dynamic task locally
# https://github.com/flyteorg/flytekit/blob/0815345faf0fae5dc26746a43d4bda4cc2cdf830/flytekit/core/python_function_task.py#L262
elif ctx.execution_state is not None and (
ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
or ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
):
elif ctx.execution_state and ctx.execution_state.is_local_execution():
if isinstance(entity, RemoteEntity):
raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}")

Expand Down
44 changes: 29 additions & 15 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,9 @@ class LocallyExecutable(Protocol):
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
...

def local_execution_mode(self) -> ExecutionState.Mode:
...


def flyte_entity_call_handler(
entity: SupportsNodeCreation, *args, **kwargs
Expand Down Expand Up @@ -996,27 +999,38 @@ def flyte_entity_call_handler(
)

ctx = FlyteContextManager.current_context()
if ctx.execution_state and (
ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
or ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
):
logger.error("You are not supposed to nest @Task/@Workflow inside a @Task!")
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return create_and_link_node(ctx, entity=entity, **kwargs)
elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
if ctx.execution_state and ctx.execution_state.is_local_execution():
mode = cast(LocallyExecutable, entity).local_execution_mode()
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(mode=mode))
) as child_ctx:
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
child_ctx.execution_state
and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)
else:
mode = cast(LocallyExecutable, entity).local_execution_mode()
with FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)
)
ctx.with_execution_state(ctx.new_execution_state().with_params(mode=mode))
) as child_ctx:
cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
representing that newly generated workflow, instead of executing it.
"""
ctx = FlyteContextManager.current_context()
if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# The rest of this function mimics the local_execute of the workflow. We can't use the workflow
# local_execute directly though since that converts inputs into Promises.
logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}")
Expand Down
7 changes: 4 additions & 3 deletions flytekit/core/reference_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Pro
vals = [Promise(var, outputs_literals[var]) for var in output_names]
return create_task_output(vals, self.python_interface)

def local_execution_mode(self):
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
return _workflow_model.NodeMetadata(name=extract_obj_name(self.name))

Expand All @@ -207,9 +210,7 @@ def __call__(self, *args, **kwargs):
ctx = FlyteContext.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return self.compile(ctx, *args, **kwargs)
elif (
ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
):
elif ctx.execution_state and ctx.execution_state.is_local_execution():
if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
return
return self.local_execute(ctx, **kwargs)
Expand Down
12 changes: 11 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from flytekit.core.base_task import PythonTask
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.condition import ConditionalSection
from flytekit.core.context_manager import CompilationState, FlyteContext, FlyteContextManager, FlyteEntities
from flytekit.core.context_manager import (
CompilationState,
ExecutionState,
FlyteContext,
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.docstring import Docstring
from flytekit.core.interface import (
Interface,
Expand Down Expand Up @@ -334,6 +340,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr

return create_task_output(new_promises, self.python_interface)

def local_execution_mode(self) -> ExecutionState.Mode:
""" """
return ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION


class ImperativeWorkflow(WorkflowBase):
"""
Expand Down
3 changes: 3 additions & 0 deletions flytekit/remote/remote_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __call__(self, *args, **kwargs):
def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
return self.execute(**kwargs)

def local_execution_mode(self) -> ExecutionState.Mode:
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

def execute(self, **kwargs) -> Any:
raise AssertionError(f"Remotely fetched entities cannot be run locally. Please mock the {self.name}.execute.")

Expand Down