diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 2cdaa50365..562099c641 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,6 +1,6 @@ import datetime as _datetime from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface @@ -75,9 +75,64 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction return PythonFunctionTask +T = TypeVar("T") + + +@overload +def task( + _task_function: None = ..., + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]: + ... + + +@overload +def task( + _task_function: Callable[..., Any], + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> PythonFunctionTask[T]: + ... + + def task( - _task_function: Optional[Callable] = None, - task_config: Optional[Any] = None, + _task_function: Optional[Callable[..., Any]] = None, + task_config: Optional[T] = None, cache: bool = False, cache_serialize: bool = False, cache_version: str = "", @@ -96,7 +151,7 @@ def task( disable_deck: bool = True, pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, -) -> Union[Callable, PythonFunctionTask]: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]: """ This is the core decorator to use for any task type in flytekit. @@ -190,7 +245,7 @@ def foo2(): :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ - def wrapper(fn) -> PythonFunctionTask: + def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: _metadata = TaskMetadata( cache=cache, cache_serialize=cache_serialize,