Skip to content

Commit

Permalink
Improve workflow decorator type hints with overload (#1635)
Browse files Browse the repository at this point in the history
Previously, the workflow decorator is hinted as always returning a WorkflowBase, which is not true when _workflow_function is None; similar to #1631, we propose using typing.overload to differentiate the return type of workflow based on the value of _workflow_function

Signed-off-by: Matthew Hoffman <[email protected]>
  • Loading branch information
ringohoffman authored May 12, 2023
1 parent 0a1f289 commit 71d7898
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from enum import Enum
from functools import update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, overload

from typing_extensions import get_args

Expand Down Expand Up @@ -653,7 +653,7 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver):

def __init__(
self,
workflow_function: Callable,
workflow_function: Callable[..., Any],
metadata: WorkflowMetadata,
default_metadata: WorkflowMetadataDefaults,
docstring: Optional[Docstring] = None,
Expand Down Expand Up @@ -777,12 +777,32 @@ def execute(self, **kwargs):
return exception_scopes.user_entry_point(self._workflow_function)(**kwargs)


@overload
def workflow(
_workflow_function=None,
_workflow_function: None = ...,
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]:
...


@overload
def workflow(
_workflow_function: Callable[..., Any],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> PythonFunctionWorkflow:
...


def workflow(
_workflow_function: Optional[Callable[..., Any]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
docs: Optional[Documentation] = None,
) -> WorkflowBase:
) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -813,7 +833,7 @@ def workflow(
:param docs: Description entity for the workflow
"""

def wrapper(fn):
def wrapper(fn: Callable[..., Any]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand All @@ -828,10 +848,10 @@ def wrapper(fn):
update_wrapper(workflow_instance, fn)
return workflow_instance

if _workflow_function:
if _workflow_function is not None:
return wrapper(_workflow_function)
else:
return wrapper # type: ignore
return wrapper


class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore
Expand Down

0 comments on commit 71d7898

Please sign in to comment.