From 95975e20c8e774da00380c7a89473f0711369769 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Thu, 11 May 2023 00:12:18 -0700 Subject: [PATCH] Improve task decorator type hints with overload Without the overload, the decorated function does not have the proper type of PythonFunctionTask, leading to spurious type errors when attempting to register the task on a FlyteRemote object Signed-off-by: Matthew Hoffman --- flytekit/core/task.py | 65 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 2cdaa50365..562099c641 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -1,6 +1,6 @@ import datetime as _datetime from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload from flytekit.core.base_task import TaskMetadata, TaskResolverMixin from flytekit.core.interface import transform_function_to_interface @@ -75,9 +75,64 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction return PythonFunctionTask +T = TypeVar("T") + + +@overload +def task( + _task_function: None = ..., + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]: + ... + + +@overload +def task( + _task_function: Callable[..., Any], + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> PythonFunctionTask[T]: + ... + + def task( - _task_function: Optional[Callable] = None, - task_config: Optional[Any] = None, + _task_function: Optional[Callable[..., Any]] = None, + task_config: Optional[T] = None, cache: bool = False, cache_serialize: bool = False, cache_version: str = "", @@ -96,7 +151,7 @@ def task( disable_deck: bool = True, pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, -) -> Union[Callable, PythonFunctionTask]: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]: """ This is the core decorator to use for any task type in flytekit. @@ -190,7 +245,7 @@ def foo2(): :param pod_template_name: The name of the existing PodTemplate resource which will be used in this task. """ - def wrapper(fn) -> PythonFunctionTask: + def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]: _metadata = TaskMetadata( cache=cache, cache_serialize=cache_serialize,