Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

node_dependency_hints for dynamic tasks #2015

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@

"""

from __future__ import annotations

from abc import ABC
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar, Union, cast
from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union, cast

from flytekit.core import launch_plan as _annotated_launch_plan
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
from flytekit.core.base_task import Task, TaskResolverMixin
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.docstring import Docstring
Expand All @@ -27,6 +30,7 @@
from flytekit.core.tracker import extract_task_module, is_functools_wrapped_module_level, isnested, istestfunction
from flytekit.core.workflow import (
PythonFunctionWorkflow,
WorkflowBase,
WorkflowFailurePolicy,
WorkflowMetadata,
WorkflowMetadataDefaults,
Expand Down Expand Up @@ -102,6 +106,9 @@ def __init__(
ignore_input_vars: Optional[List[str]] = None,
execution_mode: ExecutionBehavior = ExecutionBehavior.DEFAULT,
task_resolver: Optional[TaskResolverMixin] = None,
node_dependency_hints: Optional[
Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]
] = None,
**kwargs,
):
"""
Expand All @@ -112,6 +119,9 @@ def __init__(
:param Optional[ExecutionBehavior] execution_mode: Defines how the execution should behave, for example
executing normally or specially handling a dynamic case.
:param str task_type: String task type to be associated with this Task
:param Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]] node_dependency_hints:
A list of tasks, launchplans, or workflows that this task depends on. This is only
for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime.
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
Expand Down Expand Up @@ -145,12 +155,24 @@ def __init__(
)
self._task_function = task_function
self._execution_mode = execution_mode
self._node_dependency_hints = node_dependency_hints
if self._node_dependency_hints is not None and self._execution_mode != self.ExecutionBehavior.DYNAMIC:
raise ValueError(
"node_dependency_hints should only be used on dynamic tasks. On static tasks and "
"workflows its redundant because flyte can find the node dependencies automatically"
)
self._wf = None # For dynamic tasks

@property
def execution_mode(self) -> ExecutionBehavior:
    """Return the execution behavior that was stored on this task at construction time."""
    return self._execution_mode

@property
def node_dependency_hints(
    self,
) -> Optional[Iterable[Union["PythonFunctionTask", "_annotated_launch_plan.LaunchPlan", WorkflowBase]]]:
    """Return the dependency hints supplied at construction time, or ``None`` when none were given."""
    return self._node_dependency_hints

@property
def task_function(self):
    """Return the function object this task was constructed with."""
    return self._task_function
Expand Down
38 changes: 37 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import datetime as _datetime
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union, overload

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
Expand Down Expand Up @@ -97,6 +101,9 @@ def task(
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
node_dependency_hints: Optional[
Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]]
] = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: Optional[bool] = ...,
Expand Down Expand Up @@ -125,6 +132,9 @@ def task(
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
node_dependency_hints: Optional[
Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]]
] = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: Optional[bool] = ...,
Expand Down Expand Up @@ -152,6 +162,9 @@ def task(
limits: Optional[Resources] = None,
secret_requests: Optional[List[Secret]] = None,
execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT,
node_dependency_hints: Optional[
Iterable[Union[PythonFunctionTask, _annotated_launchplan.LaunchPlan, _annotated_workflow.WorkflowBase]]
] = None,
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[Documentation] = None,
disable_deck: Optional[bool] = None,
Expand Down Expand Up @@ -246,6 +259,28 @@ def foo2():
Refer to :py:class:`Secret` to understand how to specify the request for a secret. It
may change based on the backend provider.
:param execution_mode: This is mainly for internal use. Please ignore. It is filled in automatically.
:param node_dependency_hints: A list of tasks, launchplans, or workflows that this task depends on. This is only
for dynamic tasks/workflows, where flyte cannot automatically determine the dependencies prior to runtime.
Even on dynamic tasks this is optional, but in some scenarios it will make registering the workflow easier,
because it allows registration to be done the same as for static tasks/workflows.

For example, this is useful for running launchplans dynamically, because launchplans must be registered on flyteadmin
before they can be run. Tasks and workflows do not have this requirement.

.. code-block:: python

@workflow
def workflow0():
...

launchplan0 = LaunchPlan.get_or_create(workflow0)

# Specify node_dependency_hints so that launchplan0 will be registered on flyteadmin, despite this being a
# dynamic task.
@dynamic(node_dependency_hints=[launchplan0])
def launch_dynamically():
# To run a sub-launchplan it must have previously been registered on flyteadmin.
return [launchplan0]*10
:param task_resolver: Provide a custom task resolver.
:param disable_deck: (deprecated) If true, this task will not output deck html file
:param enable_deck: If true, this task will output deck html file
Expand Down Expand Up @@ -276,6 +311,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
limits=limits,
secret_requests=secret_requests,
execution_mode=execution_mode,
node_dependency_hints=node_dependency_hints,
task_resolver=task_resolver,
disable_deck=disable_deck,
enable_deck=enable_deck,
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload

from flytekit.core import constants as _common_constants
from flytekit.core import launch_plan as _annotated_launch_plan
from flytekit.core.base_task import PythonTask, Task
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.condition import ConditionalSection, conditional
Expand All @@ -26,7 +27,6 @@
transform_inputs_to_parameters,
transform_interface_to_typed_interface,
)
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import (
NodeOutput,
Expand Down Expand Up @@ -528,7 +528,7 @@ def create_conditional(self, name: str) -> ConditionalSection:
FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state))
return conditional(name=name)

def add_entity(self, entity: Union[PythonTask, LaunchPlan, WorkflowBase], **kwargs) -> Node:
def add_entity(self, entity: Union[PythonTask, _annotated_launch_plan.LaunchPlan, WorkflowBase], **kwargs) -> Node:
"""
Anytime you add an entity, all the inputs to the entity must be bound.
"""
Expand Down Expand Up @@ -611,7 +611,7 @@ def add_workflow_output(
def add_task(self, task: PythonTask, **kwargs) -> Node:
    """Register ``task`` by delegating to ``add_entity``, forwarding ``kwargs`` unchanged."""
    node = self.add_entity(task, **kwargs)
    return node

def add_launch_plan(self, launch_plan: LaunchPlan, **kwargs) -> Node:
def add_launch_plan(self, launch_plan: _annotated_launch_plan.LaunchPlan, **kwargs) -> Node:
return self.add_entity(launch_plan, **kwargs)

def add_subwf(self, sub_wf: WorkflowBase, **kwargs) -> Node:
Expand Down
8 changes: 7 additions & 1 deletion flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,10 @@ def fn(settings: SerializationSettings) -> List[str]:


def get_serializable_task(
entity_mapping: OrderedDict,
settings: SerializationSettings,
entity: FlyteLocalEntity,
options: Optional[Options] = None,
) -> TaskSpec:
task_id = _identifier_model.Identifier(
_identifier_model.ResourceType.TASK,
Expand All @@ -176,6 +178,10 @@ def get_serializable_task(
# during dynamic serialization
settings = settings.with_serialized_context()

if entity.node_dependency_hints is not None:
for entity_hint in entity.node_dependency_hints:
get_serializable(entity_mapping, settings, entity_hint, options)

container = entity.get_container(settings)
# This pod will be incorrect when doing fast serialize
pod = entity.get_k8s_pod(settings)
Expand Down Expand Up @@ -713,7 +719,7 @@ def get_serializable(
cp_entity = get_reference_spec(entity_mapping, settings, entity)

elif isinstance(entity, PythonTask):
cp_entity = get_serializable_task(settings, entity)
cp_entity = get_serializable_task(entity_mapping, settings, entity)

elif isinstance(entity, WorkflowBase):
cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options)
Expand Down
4 changes: 3 additions & 1 deletion plugins/flytekit-flyin/tests/test_flyin_plugin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import OrderedDict

import mock
import pytest
from flytekitplugins.flyin import (
Expand Down Expand Up @@ -343,7 +345,7 @@ def t():
project="p", domain="d", version="v", image_config=default_image_config
)

serialized_task = get_serializable_task(default_serialization_settings, t)
serialized_task = get_serializable_task(OrderedDict(), default_serialization_settings, t)
assert serialized_task.template.config == {"link_type": "vscode", "port": "8081"}


Expand Down
4 changes: 3 additions & 1 deletion tests/flytekit/unit/core/test_container_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import OrderedDict

import pytest
from kubernetes.client.models import (
V1Affinity,
Expand Down Expand Up @@ -72,7 +74,7 @@ def test_pod_template():
#################
# Test Serialization
#################
ts = get_serializable_task(default_serialization_settings, ct)
ts = get_serializable_task(OrderedDict(), default_serialization_settings, ct)
assert ts.template.metadata.pod_template_name == "my-base-template"
assert ts.template.container is None
assert ts.template.k8s_pod is not None
Expand Down
28 changes: 28 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing
from collections import OrderedDict

import pytest

Expand All @@ -13,6 +14,7 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.literals import LiteralMap
from flytekit.tools.translator import get_serializable_task

settings = flytekit.configuration.SerializationSettings(
project="test_proj",
Expand Down Expand Up @@ -262,3 +264,29 @@ def wf(wf_in: str) -> typing.List[str]:

res = dt(ss="hello")
assert res == ["In t2 string is hello", "In t3 string is In t2 string is hello"]


def test_node_dependency_hints_are_serialized():
    """Serializing a dynamic task must also serialize the entities listed in its node_dependency_hints."""

    @task
    def t1() -> int:
        return 0

    @task
    def t2() -> int:
        return 0

    @dynamic(node_dependency_hints=[t1, t2])
    def dt(mode: int) -> int:
        if mode == 1:
            return t1()
        if mode == 2:
            return t2()

        raise ValueError("Invalid mode")

    serialized = OrderedDict()
    get_serializable_task(serialized, settings, dt)

    # The hinted entities are expected in the mapping in declaration order: t1 first, then t2.
    specs = iter(serialized.values())
    first_spec = next(specs)
    assert "t1" in first_spec.template.id.name
    second_spec = next(specs)
    assert "t2" in second_spec.template.id.name
11 changes: 6 additions & 5 deletions tests/flytekit/unit/core/test_python_auto_container.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import OrderedDict
from typing import Any

import pytest
Expand Down Expand Up @@ -73,7 +74,7 @@ def test_get_container(default_serialization_settings):
assert c.image == "docker.io/xyz:some-git-hash"
assert c.env == {"FOO": "bar"}

ts = get_serializable_task(default_serialization_settings, task)
ts = get_serializable_task(OrderedDict(), default_serialization_settings, task)
assert ts.template.container.image == "docker.io/xyz:some-git-hash"
assert ts.template.container.env == {"FOO": "bar"}

Expand All @@ -86,7 +87,7 @@ def test_get_container_with_task_envvars(default_serialization_settings):
assert c.image == "docker.io/xyz:some-git-hash"
assert c.env == {"FOO": "bar", "HAM": "spam"}

ts = get_serializable_task(default_serialization_settings, task_with_env_vars)
ts = get_serializable_task(OrderedDict(), default_serialization_settings, task_with_env_vars)
assert ts.template.container.image == "docker.io/xyz:some-git-hash"
assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"}

Expand All @@ -96,7 +97,7 @@ def test_get_container_without_serialization_settings_envvars(minimal_serializat
assert c.image == "docker.io/xyz:some-git-hash"
assert c.env == {"HAM": "spam"}

ts = get_serializable_task(minimal_serialization_settings, task_with_env_vars)
ts = get_serializable_task(OrderedDict(), minimal_serialization_settings, task_with_env_vars)
assert ts.template.container.image == "docker.io/xyz:some-git-hash"
assert ts.template.container.env == {"HAM": "spam"}

Expand Down Expand Up @@ -215,7 +216,7 @@ def test_pod_template(default_serialization_settings):
#################
# Test Serialization
#################
ts = get_serializable_task(default_serialization_settings, task_with_pod_template)
ts = get_serializable_task(OrderedDict(), default_serialization_settings, task_with_pod_template)
assert ts.template.container is None
# k8s_pod content is already verified above, so only check the existence here
assert ts.template.k8s_pod is not None
Expand Down Expand Up @@ -290,7 +291,7 @@ def test_minimum_pod_template(default_serialization_settings):
#################
# Test Serialization
#################
ts = get_serializable_task(default_serialization_settings, task_with_minimum_pod_template)
ts = get_serializable_task(OrderedDict(), default_serialization_settings, task_with_minimum_pod_template)
assert ts.template.container is None
# k8s_pod content is already verified above, so only check the existence here
assert ts.template.k8s_pod is not None
Expand Down
16 changes: 15 additions & 1 deletion tests/flytekit/unit/core/test_python_function_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import OrderedDict

import pytest
from kubernetes.client.models import V1Container, V1PodSpec

Expand Down Expand Up @@ -205,8 +207,20 @@ def func_with_pod_template(i: str):
#################
# Test Serialization
#################
ts = get_serializable_task(default_serialization_settings, func_with_pod_template)
ts = get_serializable_task(OrderedDict(), default_serialization_settings, func_with_pod_template)
assert ts.template.container is None
# k8s_pod content is already verified above, so only check the existence here
assert ts.template.k8s_pod is not None
assert ts.template.metadata.pod_template_name == "A"


def test_node_dependency_hints_are_not_allowed():
    """A plain (non-dynamic) ``@task`` must reject node_dependency_hints with a ValueError."""

    @task
    def t1(i: str):
        pass

    expected_message = "node_dependency_hints should only be used on dynamic tasks"
    with pytest.raises(ValueError, match=expected_message):

        @task(node_dependency_hints=[t1])
        def t2(i: str):
            pass
6 changes: 4 additions & 2 deletions tests/flytekit/unit/core/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import OrderedDict

import pytest

import flytekit
Expand Down Expand Up @@ -90,13 +92,13 @@ def t() -> str:
assert t() == "hello world"
assert t.get_config(settings=ss) == {}

ts = get_serializable_task(ss, t)
ts = get_serializable_task(OrderedDict(), ss, t)
assert ts.template.config == {"foo": "bar"}

@task
@my_decorator
def t() -> str:
return "hello world"

ts = get_serializable_task(ss, t)
ts = get_serializable_task(OrderedDict(), ss, t)
assert ts.template.config == {"foo": "baz"}
Loading