Skip to content

Commit

Permalink
Improve task type hint (#1711)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Jun 27, 2023
1 parent 3b064dc commit 1052af8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
6 changes: 3 additions & 3 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def __init__(
self._bound_inputs = set(self._partial.keywords.keys())

collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs)
self._run_task: PythonFunctionTask = actual_task
self._run_task: typing.Union[PythonFunctionTask, PythonInstanceTask] = actual_task # type: ignore
if isinstance(actual_task, PythonInstanceTask):
mod = actual_task.task_type
f = actual_task.lhs
else:
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
_, mod, f, _ = tracker.extract_task_module(typing.cast(PythonFunctionTask, actual_task).task_function)
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest()
name = f"{mod}.map_{f}_{h}"

Expand Down Expand Up @@ -176,7 +176,7 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
return self._run_task.get_config(settings)

@property
def run_task(self) -> PythonFunctionTask:
def run_task(self) -> typing.Union[PythonFunctionTask, PythonInstanceTask]:
return self._run_task

def __call__(self, *args, **kwargs):
Expand Down
11 changes: 6 additions & 5 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction


T = TypeVar("T")
FuncOut = TypeVar("FuncOut")


@overload
Expand All @@ -100,13 +101,13 @@ def task(
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]:
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]:
...


@overload
def task(
_task_function: Callable[..., Any],
_task_function: Callable[..., FuncOut],
task_config: Optional[T] = ...,
cache: bool = ...,
cache_serialize: bool = ...,
Expand All @@ -126,12 +127,12 @@ def task(
disable_deck: bool = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
) -> PythonFunctionTask[T]:
) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]:
...


def task(
_task_function: Optional[Callable[..., Any]] = None,
_task_function: Optional[Callable[..., FuncOut]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache_serialize: bool = False,
Expand All @@ -151,7 +152,7 @@ def task(
disable_deck: bool = True,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]:
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down
9 changes: 5 additions & 4 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
)

T = typing.TypeVar("T")
FuncOut = typing.TypeVar("FuncOut")


class WorkflowFailurePolicy(Enum):
Expand Down Expand Up @@ -783,17 +784,17 @@ def workflow(
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]:
) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]:
...


@overload
def workflow(
_workflow_function: Callable[..., Any],
_workflow_function: Callable[..., FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
docs: Optional[Documentation] = ...,
) -> PythonFunctionWorkflow:
) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]:
...


Expand All @@ -802,7 +803,7 @@ def workflow(
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
docs: Optional[Documentation] = None,
) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down

0 comments on commit 1052af8

Please sign in to comment.