diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 49a05ce654..5e1d8f8166 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -107,7 +107,7 @@ def task( @overload def task( - _task_function: Callable[..., Any], + _task_function: Callable[..., FuncOut], task_config: Optional[T] = ..., cache: bool = ..., cache_serialize: bool = ..., @@ -127,7 +127,33 @@ def task( disable_deck: bool = ..., pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., -) -> Union[PythonFunctionTask[T], FuncOut]: +) -> PythonFunctionTask[T]: + ... + + +@overload +def task( + _task_function: Callable[..., FuncOut], + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> Callable[..., FuncOut]: ... @@ -152,7 +178,7 @@ def task( disable_deck: bool = True, pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, -) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T], FuncOut]: +) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]: """ This is the core decorator to use for any task type in flytekit. diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 0fcab3dda7..e2b32490dc 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -55,6 +55,7 @@ ) T = typing.TypeVar("T") +FuncOut = typing.TypeVar("FuncOut") class WorkflowFailurePolicy(Enum): @@ -789,7 +790,7 @@ def workflow( @overload def workflow( - _workflow_function: Callable[..., Any], + _workflow_function: Callable[..., FuncOut], failure_policy: Optional[WorkflowFailurePolicy] = ..., interruptible: bool = ..., docs: Optional[Documentation] = ..., @@ -797,12 +798,32 @@ def workflow( ... +@overload +def workflow( + _workflow_function: Callable[..., FuncOut], + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> PythonFunctionWorkflow: + ... + + +@overload +def workflow( + _workflow_function: Callable[..., FuncOut], + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> Callable[..., FuncOut]: + ... + + def workflow( _workflow_function: Optional[Callable[..., Any]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, docs: Optional[Documentation] = None, -) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks.