Skip to content

Commit

Permalink
Improve task decorator type hints with overload (#1631)
Browse files Browse the repository at this point in the history
Without the overload, the decorated function does not have the proper type of PythonFunctionTask, leading to spurious type errors when attempting to register the task on a FlyteRemote object

Signed-off-by: Matthew Hoffman <[email protected]>
  • Loading branch information
ringohoffman authored May 11, 2023
1 parent 4037fa0 commit 993201f
Showing 1 changed file with 60 additions and 5 deletions.
65 changes: 60 additions & 5 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime as _datetime
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload

from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
Expand Down Expand Up @@ -75,9 +75,64 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
return PythonFunctionTask


T = TypeVar("T")


@overload
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
environment: Optional[Dict[str, str]] = ...,
requests: Optional[Resources] = ...,
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]:
...


@overload
def task(
_task_function: Callable[..., Any],
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
environment: Optional[Dict[str, str]] = ...,
requests: Optional[Resources] = ...,
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> PythonFunctionTask[T]:
...


def task(
_task_function: Optional[Callable] = None,
task_config: Optional[Any] = None,
_task_function: Optional[Callable[..., Any]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache_serialize: bool = False,
cache_version: str = "",
Expand All @@ -96,7 +151,7 @@ def task(
disable_deck: bool = True,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -190,7 +245,7 @@ def foo2():
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""

def wrapper(fn) -> PythonFunctionTask:
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
Expand Down

0 comments on commit 993201f

Please sign in to comment.