From 6e9bcbe25871a983aba5615c40855afb328b903a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 25 Jun 2023 00:49:40 -0700 Subject: [PATCH 1/5] Improve task type hint Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 2 +- flytekit/core/map_task.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 6556ac1469..76788823ab 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -300,7 +300,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) - def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + def __call__(self, *args, **kwargs) -> Any: return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore def compile(self, ctx: FlyteContext, *args, **kwargs): diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 52325ecb59..8005125034 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -78,12 +78,12 @@ def __init__( self._bound_inputs = set(self._partial.keywords.keys()) collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) - self._run_task: PythonFunctionTask = actual_task + self._run_task: typing.Union[PythonFunctionTask, PythonInstanceTask] = actual_task # type: ignore if isinstance(actual_task, PythonInstanceTask): mod = actual_task.task_type f = actual_task.lhs else: - _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + _, mod, f, _ = tracker.extract_task_module(typing.cast(PythonFunctionTask, actual_task).task_function) h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() name = f"{mod}.map_{f}_{h}" @@ -168,7 +168,7 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] return self._run_task.get_config(settings) @property - def run_task(self) -> PythonFunctionTask: + def run_task(self) -> typing.Union[PythonFunctionTask, PythonInstanceTask]: return self._run_task def __call__(self, *args, **kwargs): From f9e3ca7fdaba4fd82bd06b5ec7b796077eab95a5 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 26 Jun 2023 12:32:52 -0700 Subject: [PATCH 2/5] Use func out Signed-off-by: Kevin Su --- flytekit/core/base_task.py | 2 +- flytekit/core/task.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 76788823ab..6556ac1469 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -300,7 +300,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) - def __call__(self, *args, **kwargs) -> Any: + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore def compile(self, ctx: FlyteContext, *args, **kwargs): diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 562099c641..49a05ce654 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -76,6 +76,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction T = TypeVar("T") +FuncOut = TypeVar("FuncOut") @overload @@ -126,7 +127,7 @@ def task( disable_deck: bool = ..., pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., -) -> PythonFunctionTask[T]: +) -> Union[PythonFunctionTask[T], FuncOut]: ... @@ -151,7 +152,7 @@ def task( disable_deck: bool = True, pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, -) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T]]: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T], FuncOut]: """ This is the core decorator to use for any task type in flytekit. From 7739361769354d411454597112be180b213ae9aa Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 26 Jun 2023 12:51:37 -0700 Subject: [PATCH 3/5] update Signed-off-by: Kevin Su --- flytekit/core/task.py | 32 +++++++++++++++++++++++++++++--- flytekit/core/workflow.py | 25 +++++++++++++++++++++++-- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 49a05ce654..5e1d8f8166 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -107,7 +107,7 @@ def task( @overload def task( - _task_function: Callable[..., Any], + _task_function: Callable[..., FuncOut], task_config: Optional[T] = ..., cache: bool = ..., cache_serialize: bool = ..., @@ -127,7 +127,33 @@ def task( disable_deck: bool = ..., pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., -) -> Union[PythonFunctionTask[T], FuncOut]: +) -> PythonFunctionTask[T]: + ... + + +@overload +def task( + _task_function: Callable[..., FuncOut], + task_config: Optional[T] = ..., + cache: bool = ..., + cache_serialize: bool = ..., + cache_version: str = ..., + retries: int = ..., + interruptible: Optional[bool] = ..., + deprecated: str = ..., + timeout: Union[_datetime.timedelta, int] = ..., + container_image: Optional[Union[str, ImageSpec]] = ..., + environment: Optional[Dict[str, str]] = ..., + requests: Optional[Resources] = ..., + limits: Optional[Resources] = ..., + secret_requests: Optional[List[Secret]] = ..., + execution_mode: PythonFunctionTask.ExecutionBehavior = ..., + task_resolver: Optional[TaskResolverMixin] = ..., + docs: Optional[Documentation] = ..., + disable_deck: bool = ..., + pod_template: Optional["PodTemplate"] = ..., + pod_template_name: Optional[str] = ..., +) -> Callable[..., FuncOut]: ... @@ -152,7 +178,7 @@ def task( disable_deck: bool = True, pod_template: Optional["PodTemplate"] = None, pod_template_name: Optional[str] = None, -) -> Union[Callable[[Callable[..., Any]], PythonFunctionTask[T]], PythonFunctionTask[T], FuncOut]: +) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]: """ This is the core decorator to use for any task type in flytekit. diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 0fcab3dda7..e2b32490dc 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -55,6 +55,7 @@ ) T = typing.TypeVar("T") +FuncOut = typing.TypeVar("FuncOut") class WorkflowFailurePolicy(Enum): @@ -789,7 +790,7 @@ def workflow( @overload def workflow( - _workflow_function: Callable[..., Any], + _workflow_function: Callable[..., FuncOut], failure_policy: Optional[WorkflowFailurePolicy] = ..., interruptible: bool = ..., docs: Optional[Documentation] = ..., @@ -797,12 +798,32 @@ def workflow( ... +@overload +def workflow( + _workflow_function: Callable[..., FuncOut], + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> PythonFunctionWorkflow: + ... + + +@overload +def workflow( + _workflow_function: Callable[..., FuncOut], + failure_policy: Optional[WorkflowFailurePolicy] = ..., + interruptible: bool = ..., + docs: Optional[Documentation] = ..., +) -> Callable[..., FuncOut]: + ... + + def workflow( _workflow_function: Optional[Callable[..., Any]] = None, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, docs: Optional[Documentation] = None, -) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow]: +) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. From 289045e694085235867829a379ab213aa5869bbb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 26 Jun 2023 13:57:01 -0700 Subject: [PATCH 4/5] update Signed-off-by: Kevin Su --- flytekit/core/task.py | 4 ++-- flytekit/core/workflow.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 5e1d8f8166..6c911ae0ad 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -101,7 +101,7 @@ def task( disable_deck: bool = ..., pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., -) -> Callable[[Callable[..., Any]], PythonFunctionTask[T]]: +) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]: ... @@ -158,7 +158,7 @@ def task( def task( - _task_function: Optional[Callable[..., Any]] = None, + _task_function: Optional[Callable[..., FuncOut]] = None, task_config: Optional[T] = None, cache: bool = False, cache_serialize: bool = False, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index e2b32490dc..76f28911fb 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -784,7 +784,7 @@ def workflow( failure_policy: Optional[WorkflowFailurePolicy] = ..., interruptible: bool = ..., docs: Optional[Documentation] = ..., -) -> Callable[[Callable[..., Any]], PythonFunctionWorkflow]: +) -> Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow]: ... @@ -823,7 +823,7 @@ def workflow( failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: bool = False, docs: Optional[Documentation] = None, -) -> Union[Callable[[Callable[..., Any]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]: +) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow, Callable[..., FuncOut]]: """ This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG of tasks using the data flow between tasks. From 4353132baf4e6156d32431c2b3a9e355a340655d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 26 Jun 2023 17:46:37 -0700 Subject: [PATCH 5/5] lint Signed-off-by: Kevin Su --- flytekit/core/task.py | 28 +--------------------------- flytekit/core/workflow.py | 22 +--------------------- 2 files changed, 2 insertions(+), 48 deletions(-) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 6c911ae0ad..ec9fbb057f 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -127,33 +127,7 @@ def task( disable_deck: bool = ..., pod_template: Optional["PodTemplate"] = ..., pod_template_name: Optional[str] = ..., -) -> PythonFunctionTask[T]: - ... - - -@overload -def task( - _task_function: Callable[..., FuncOut], - task_config: Optional[T] = ..., - cache: bool = ..., - cache_serialize: bool = ..., - cache_version: str = ..., - retries: int = ..., - interruptible: Optional[bool] = ..., - deprecated: str = ..., - timeout: Union[_datetime.timedelta, int] = ..., - container_image: Optional[Union[str, ImageSpec]] = ..., - environment: Optional[Dict[str, str]] = ..., - requests: Optional[Resources] = ..., - limits: Optional[Resources] = ..., - secret_requests: Optional[List[Secret]] = ..., - execution_mode: PythonFunctionTask.ExecutionBehavior = ..., - task_resolver: Optional[TaskResolverMixin] = ..., - docs: Optional[Documentation] = ..., - disable_deck: bool = ..., - pod_template: Optional["PodTemplate"] = ..., - pod_template_name: Optional[str] = ..., -) -> Callable[..., FuncOut]: +) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]: ... diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 76f28911fb..415842717c 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -794,27 +794,7 @@ def workflow( failure_policy: Optional[WorkflowFailurePolicy] = ..., interruptible: bool = ..., docs: Optional[Documentation] = ..., -) -> PythonFunctionWorkflow: - ... - - -@overload -def workflow( - _workflow_function: Callable[..., FuncOut], - failure_policy: Optional[WorkflowFailurePolicy] = ..., - interruptible: bool = ..., - docs: Optional[Documentation] = ..., -) -> PythonFunctionWorkflow: - ... - - -@overload -def workflow( - _workflow_function: Callable[..., FuncOut], - failure_policy: Optional[WorkflowFailurePolicy] = ..., - interruptible: bool = ..., - docs: Optional[Documentation] = ..., -) -> Callable[..., FuncOut]: +) -> Union[PythonFunctionWorkflow, Callable[..., FuncOut]]: ...