diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index e1e80a4227..c26bdd6c6e 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -13,11 +13,14 @@ """ +from __future__ import annotations + from abc import ABC from collections import OrderedDict from enum import Enum -from typing import Any, Callable, List, Optional, TypeVar, Union, cast +from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union, cast +from flytekit.core import launch_plan as _annotated_launch_plan from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.docstring import Docstring @@ -27,6 +30,7 @@ from flytekit.core.tracker import extract_task_module, is_functools_wrapped_module_level, isnested, istestfunction from flytekit.core.workflow import ( PythonFunctionWorkflow, + WorkflowBase, WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults, @@ -102,6 +106,9 @@ def __init__( ignore_input_vars: Optional[List[str]] = None, execution_mode: ExecutionBehavior = ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, + node_dependency_hints: Optional[ + Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]] + ] = None, **kwargs, ): """ @@ -112,6 +119,9 @@ def __init__( :param Optional[ExecutionBehavior] execution_mode: Defines how the execution should behave, for example executing normally or specially handling a dynamic case. :param str task_type: String task type to be associated with this Task + :param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] node_dependency_hints: + A list of tasks, launchplans, or workflows that this task depends on. This is only + for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime. 
""" if task_function is None: raise ValueError("TaskFunction is a required parameter for PythonFunctionTask") @@ -145,12 +155,24 @@ def __init__( ) self._task_function = task_function self._execution_mode = execution_mode + self._node_dependency_hints = node_dependency_hints + if self._node_dependency_hints is not None and self._execution_mode != self.ExecutionBehavior.DYNAMIC: + raise ValueError( + "node_dependency_hints should only be used on dynamic tasks. On static tasks and " + "workflows it's redundant because flyte can find the node dependencies automatically" + ) self._wf = None # For dynamic tasks @property def execution_mode(self) -> ExecutionBehavior: return self._execution_mode + @property + def node_dependency_hints( + self, + ) -> Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]]: + return self._node_dependency_hints + @property def task_function(self): return self._task_function diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 547abd41fa..a99fbf599e 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import datetime as _datetime from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union, overload +from flytekit.core import launch_plan as _annotated_launchplan +from flytekit.core import workflow as _annotated_workflow from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface from flytekit.core.pod_template import PodTemplate @@ -97,6 +101,9 @@ def task( limits: Optional[Resources] = ..., secret_requests: Optional[List[Secret]] = ..., execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + node_dependency_hints: Optional[ + Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, 
_annotated_workflow.WorkflowBase]] + ] = ..., task_resolver: Optional[TaskResolverMixin] = ..., docs: Optional[Documentation] = ..., disable_deck: Optional[bool] = ..., @@ -125,6 +132,9 @@ def task( limits: Optional[Resources] = ..., secret_requests: Optional[List[Secret]] = ..., execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + node_dependency_hints: Optional[ + Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]] + ] = ..., task_resolver: Optional[TaskResolverMixin] = ..., docs: Optional[Documentation] = ..., disable_deck: Optional[bool] = ..., @@ -152,6 +162,9 @@ def task( limits: Optional[Resources] = None, secret_requests: Optional[List[Secret]] = None, execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT, + node_dependency_hints: Optional[ + Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]] + ] = None, task_resolver: Optional[TaskResolverMixin] = None, docs: Optional[Documentation] = None, disable_deck: Optional[bool] = None, @@ -246,6 +259,28 @@ def foo2(): Refer to :py:class:`Secret` to understand how to specify the request for a secret. It may change based on the backend provider. :param execution_mode: This is mainly for internal use. Please ignore. It is filled in automatically. + :param node_dependency_hints: A list of tasks, launchplans, or workflows that this task depends on. This is only + for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime. + Even on dynamic tasks this is optional, but in some scenarios it will make registering the workflow easier, + because it allows registration to be done the same way as for static tasks/workflows. + + For example, this is useful for running launchplans dynamically, because launchplans must be registered on + flyteadmin before they can be run. Tasks and workflows do not have this requirement. + + .. 
code-block:: python + + @workflow + def workflow0(): + ... + + launchplan0 = LaunchPlan.get_or_create(workflow0) + + # Specify node_dependency_hints so that launchplan0 will be registered on flyteadmin, despite this being a + # dynamic task. + @dynamic(node_dependency_hints=[launchplan0]) + def launch_dynamically(): + # To run a sub-launchplan it must have previously been registered on flyteadmin. + return [launchplan0]*10 :param task_resolver: Provide a custom task resolver. :param disable_deck: (deprecated) If true, this task will not output deck html file :param enable_deck: If true, this task will output deck html file @@ -276,6 +311,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: limits=limits, secret_requests=secret_requests, execution_mode=execution_mode, + node_dependency_hints=node_dependency_hints, task_resolver=task_resolver, disable_deck=disable_deck, enable_deck=enable_deck, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 7c5e0c65f7..61a18c7b64 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload from flytekit.core import constants as _common_constants +from flytekit.core import launch_plan as _annotated_launch_plan from flytekit.core.base_task import PythonTask, Task from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.condition import ConditionalSection, conditional @@ -26,7 +27,6 @@ transform_inputs_to_parameters, transform_interface_to_typed_interface, ) -from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node from flytekit.core.promise import ( NodeOutput, @@ -528,7 +528,7 @@ def create_conditional(self, name: str) -> ConditionalSection: FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) return conditional(name=name) - def add_entity(self, entity: Union[PythonTask, 
LaunchPlan, WorkflowBase], **kwargs) -> Node: + def add_entity(self, entity: Union[PythonTask, _annotated_launch_plan.LaunchPlan, WorkflowBase], **kwargs) -> Node: """ Anytime you add an entity, all the inputs to the entity must be bound. """ @@ -611,7 +611,7 @@ def add_workflow_output( def add_task(self, task: PythonTask, **kwargs) -> Node: return self.add_entity(task, **kwargs) - def add_launch_plan(self, launch_plan: LaunchPlan, **kwargs) -> Node: + def add_launch_plan(self, launch_plan: _annotated_launch_plan.LaunchPlan, **kwargs) -> Node: return self.add_entity(launch_plan, **kwargs) def add_subwf(self, sub_wf: WorkflowBase, **kwargs) -> Node: diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 21f9e1b376..5d182c89b8 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -159,8 +159,10 @@ def fn(settings: SerializationSettings) -> List[str]: def get_serializable_task( + entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, + options: Optional[Options] = None, ) -> TaskSpec: task_id = _identifier_model.Identifier( _identifier_model.ResourceType.TASK, @@ -176,6 +178,10 @@ def get_serializable_task( # during dynamic serialization settings = settings.with_serialized_context() + if entity.node_dependency_hints is not None: + for entity_hint in entity.node_dependency_hints: + get_serializable(entity_mapping, settings, entity_hint, options) + container = entity.get_container(settings) # This pod will be incorrect when doing fast serialize pod = entity.get_k8s_pod(settings) @@ -713,7 +719,7 @@ def get_serializable( cp_entity = get_reference_spec(entity_mapping, settings, entity) elif isinstance(entity, PythonTask): - cp_entity = get_serializable_task(settings, entity) + cp_entity = get_serializable_task(entity_mapping, settings, entity) elif isinstance(entity, WorkflowBase): cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options) diff --git 
a/plugins/flytekit-flyin/tests/test_flyin_plugin.py b/plugins/flytekit-flyin/tests/test_flyin_plugin.py index 23cde516ce..ce2758e595 100644 --- a/plugins/flytekit-flyin/tests/test_flyin_plugin.py +++ b/plugins/flytekit-flyin/tests/test_flyin_plugin.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import mock import pytest from flytekitplugins.flyin import ( @@ -343,7 +345,7 @@ def t(): project="p", domain="d", version="v", image_config=default_image_config ) - serialized_task = get_serializable_task(default_serialization_settings, t) + serialized_task = get_serializable_task(OrderedDict(), default_serialization_settings, t) assert serialized_task.template.config == {"link_type": "vscode", "port": "8081"} diff --git a/tests/flytekit/unit/core/test_container_task.py b/tests/flytekit/unit/core/test_container_task.py index 3ac4a47cc4..c89ec11345 100644 --- a/tests/flytekit/unit/core/test_container_task.py +++ b/tests/flytekit/unit/core/test_container_task.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import pytest from kubernetes.client.models import ( V1Affinity, @@ -72,7 +74,7 @@ def test_pod_template(): ################# # Test Serialization ################# - ts = get_serializable_task(default_serialization_settings, ct) + ts = get_serializable_task(OrderedDict(), default_serialization_settings, ct) assert ts.template.metadata.pod_template_name == "my-base-template" assert ts.template.container is None assert ts.template.k8s_pod is not None diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index b9b0ebd3fa..44cc19b5c3 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -1,4 +1,5 @@ import typing +from collections import OrderedDict import pytest @@ -13,6 +14,7 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow from flytekit.models.literals import LiteralMap +from flytekit.tools.translator import 
get_serializable_task settings = flytekit.configuration.SerializationSettings( project="test_proj", @@ -262,3 +264,29 @@ def wf(wf_in: str) -> typing.List[str]: res = dt(ss="hello") assert res == ["In t2 string is hello", "In t3 string is In t2 string is hello"] + + +def test_node_dependency_hints_are_serialized(): + @task + def t1() -> int: + return 0 + + @task + def t2() -> int: + return 0 + + @dynamic(node_dependency_hints=[t1, t2]) + def dt(mode: int) -> int: + if mode == 1: + return t1() + if mode == 2: + return t2() + + raise ValueError("Invalid mode") + + entity_mapping = OrderedDict() + get_serializable_task(entity_mapping, settings, dt) + + serialised_entities_iterator = iter(entity_mapping.values()) + assert "t1" in next(serialised_entities_iterator).template.id.name + assert "t2" in next(serialised_entities_iterator).template.id.name diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index fed612c98c..f5c7136b0b 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from typing import Any import pytest @@ -73,7 +74,7 @@ def test_get_container(default_serialization_settings): assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar"} - ts = get_serializable_task(default_serialization_settings, task) + ts = get_serializable_task(OrderedDict(), default_serialization_settings, task) assert ts.template.container.image == "docker.io/xyz:some-git-hash" assert ts.template.container.env == {"FOO": "bar"} @@ -86,7 +87,7 @@ def test_get_container_with_task_envvars(default_serialization_settings): assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"FOO": "bar", "HAM": "spam"} - ts = get_serializable_task(default_serialization_settings, task_with_env_vars) + ts = get_serializable_task(OrderedDict(), default_serialization_settings, 
task_with_env_vars) assert ts.template.container.image == "docker.io/xyz:some-git-hash" assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"} @@ -96,7 +97,7 @@ def test_get_container_without_serialization_settings_envvars(minimal_serializat assert c.image == "docker.io/xyz:some-git-hash" assert c.env == {"HAM": "spam"} - ts = get_serializable_task(minimal_serialization_settings, task_with_env_vars) + ts = get_serializable_task(OrderedDict(), minimal_serialization_settings, task_with_env_vars) assert ts.template.container.image == "docker.io/xyz:some-git-hash" assert ts.template.container.env == {"HAM": "spam"} @@ -215,7 +216,7 @@ def test_pod_template(default_serialization_settings): ################# # Test Serialization ################# - ts = get_serializable_task(default_serialization_settings, task_with_pod_template) + ts = get_serializable_task(OrderedDict(), default_serialization_settings, task_with_pod_template) assert ts.template.container is None # k8s_pod content is already verified above, so only check the existence here assert ts.template.k8s_pod is not None @@ -290,7 +291,7 @@ def test_minimum_pod_template(default_serialization_settings): ################# # Test Serialization ################# - ts = get_serializable_task(default_serialization_settings, task_with_minimum_pod_template) + ts = get_serializable_task(OrderedDict(), default_serialization_settings, task_with_minimum_pod_template) assert ts.template.container is None # k8s_pod content is already verified above, so only check the existence here assert ts.template.k8s_pod is not None diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 498228f3fe..bec160cd7e 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import pytest from kubernetes.client.models import V1Container, 
V1PodSpec @@ -205,8 +207,20 @@ def func_with_pod_template(i: str): ################# # Test Serialization ################# - ts = get_serializable_task(default_serialization_settings, func_with_pod_template) + ts = get_serializable_task(OrderedDict(), default_serialization_settings, func_with_pod_template) assert ts.template.container is None # k8s_pod content is already verified above, so only check the existence here assert ts.template.k8s_pod is not None assert ts.template.metadata.pod_template_name == "A" + + +def test_node_dependency_hints_are_not_allowed(): + @task + def t1(i: str): + pass + + with pytest.raises(ValueError, match="node_dependency_hints should only be used on dynamic tasks"): + + @task(node_dependency_hints=[t1]) + def t2(i: str): + pass diff --git a/tests/flytekit/unit/core/test_utils.py b/tests/flytekit/unit/core/test_utils.py index 910156f50a..ca0d07565d 100644 --- a/tests/flytekit/unit/core/test_utils.py +++ b/tests/flytekit/unit/core/test_utils.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import pytest import flytekit @@ -90,7 +92,7 @@ def t() -> str: assert t() == "hello world" assert t.get_config(settings=ss) == {} - ts = get_serializable_task(ss, t) + ts = get_serializable_task(OrderedDict(), ss, t) assert ts.template.config == {"foo": "bar"} @task @@ -98,5 +100,5 @@ def t() -> str: def t() -> str: return "hello world" - ts = get_serializable_task(ss, t) + ts = get_serializable_task(OrderedDict(), ss, t) assert ts.template.config == {"foo": "baz"}