Skip to content

Commit

Permalink
Improve task decorator type hints with overload
Browse files Browse the repository at this point in the history
Without the overload, the decorated function does not have the proper type of PythonFunctionTask, leading to spurious type errors when attempting to register the task on a FlyteRemote object
  • Loading branch information
ringohoffman committed May 11, 2023
1 parent e44b802 commit f1a7085
Showing 1 changed file with 60 additions and 5 deletions.
65 changes: 60 additions & 5 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime as _datetime
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload

from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
Expand Down Expand Up @@ -75,9 +75,64 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
return PythonFunctionTask


T = TypeVar("T")


@overload
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
environment: Optional[Dict[str, str]] = ...,
requests: Optional[Resources] = ...,
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]:
...


@overload
def task(
_task_function: Callable[..., Any],
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[_datetime.timedelta, int] = ...,
container_image: Optional[Union[str, ImageSpec]] = ...,
environment: Optional[Dict[str, str]] = ...,
requests: Optional[Resources] = ...,
limits: Optional[Resources] = ...,
secret_requests: Optional[List[Secret]] = ...,
execution_mode: PythonFunctionTask.ExecutionBehavior = ...,
task_resolver: Optional[TaskResolverMixin] = ...,
docs: Optional[Documentation] = ...,
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> PythonFunctionTask[T]:
...


def task(
_task_function: Optional[Callable] = None,
task_config: Optional[Any] = None,
_task_function: Optional[Callable[..., Any]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache_serialize: bool = False,
cache_version: str = "",
Expand All @@ -96,7 +151,7 @@ def task(
disable_deck: bool = True,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -190,7 +245,7 @@ def foo2():
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
"""

def wrapper(fn) -> PythonFunctionTask:
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
Expand Down

0 comments on commit f1a7085

Please sign in to comment.