From 366280363d2a4d12ec1f8b34a0324c2de6021f5e Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 27 Mar 2023 17:30:48 -0700 Subject: [PATCH] Fixed mypy Signed-off-by: Ketan Umare --- flytekit/bin/entrypoint.py | 3 +-- flytekit/core/interface.py | 11 +++++------ flytekit/core/map_task.py | 36 ++++++++++++++++++------------------ 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 659101704f..a9b7c313f0 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -10,7 +10,6 @@ import click as _click from flyteidl.core import literals_pb2 as _literals_pb2 -from flytekit import PythonFunctionTask from flytekit.configuration import ( SERIALIZED_CONTEXT_ENV_VAR, FastSerializationSettings, @@ -23,7 +22,7 @@ from flytekit.core.checkpointer import SyncCheckpoint from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider -from flytekit.core.map_task import MapPythonTask, MapTaskResolver +from flytekit.core.map_task import MapTaskResolver from flytekit.core.promise import VoidPromise from flytekit.exceptions import scopes as _scoped_exceptions from flytekit.exceptions import scopes as _scopes diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index dcd362488d..eae7a8e0cf 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -21,11 +21,11 @@ T = typing.TypeVar("T") -def repr_kv(k: str, v: Union[str, Tuple[Type, Any]]): +def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str: if isinstance(v, tuple): if v[1]: return f"{k}: {v[0]}={v[1]}" - v = v[0] + return f"{k}: {v[0]}" return f"{k}: {v}" @@ -79,8 +79,8 @@ def __init__( variables = [k for k in outputs.keys()] # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one. - class Output( - collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables) + class Output( # type: ignore + collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables) # type: ignore ): # type: ignore """ This class can be used in two different places. For multivariate-return entities this class is used @@ -450,8 +450,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... "Tuples should be used to indicate multiple return values, found only one return variable." ) return OrderedDict( - zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__) - # type: ignore + zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__) # type: ignore ) elif isinstance(return_annotation, tuple): if len(return_annotation) == 1: diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 1c4367de97..83b2542fe3 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -55,34 +55,34 @@ def __init__( self._partial = None if isinstance(python_function_task, functools.partial): self._partial = python_function_task - python_function_task = self._partial.func + actual_task = self._partial.func + else: + actual_task = python_function_task - if not isinstance(python_function_task, PythonFunctionTask): + if not isinstance(actual_task, PythonFunctionTask): raise ValueError("Map tasks can only compose of Python Functon Tasks currently") - if len(python_function_task.python_interface.outputs.keys()) > 1: + if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") - self._bound_inputs = set(bound_inputs) if bound_inputs else set() + self._bound_inputs: typing.Set[str] = set(bound_inputs) if bound_inputs else set() if self._partial: self._bound_inputs = set(self._partial.keywords.keys()) - collection_interface = transform_interface_to_list_interface( - python_function_task.python_interface, self._bound_inputs - ) - self._run_task = python_function_task - _, mod, f, _ = tracker.extract_task_module(python_function_task.task_function) + collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) + self._run_task: PythonFunctionTask = actual_task + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() name = f"{mod}.map_{f}_{h}" - self._cmd_prefix = None - self._max_concurrency = concurrency - self._min_success_ratio = min_success_ratio - self._array_task_interface = python_function_task.python_interface - if "metadata" not in kwargs and python_function_task.metadata: - kwargs["metadata"] = python_function_task.metadata - if "security_ctx" not in kwargs and python_function_task.security_context: - kwargs["security_ctx"] = python_function_task.security_context + self._cmd_prefix: typing.Optional[typing.List[str]] = None + self._max_concurrency: typing.Optional[int] = concurrency + self._min_success_ratio: typing.Optional[float] = min_success_ratio + self._array_task_interface = actual_task.python_interface + if "metadata" not in kwargs and actual_task.metadata: + kwargs["metadata"] = actual_task.metadata + if "security_ctx" not in kwargs and actual_task.security_context: + kwargs["security_ctx"] = actual_task.security_context super().__init__( name=name, interface=collection_interface, @@ -124,7 +124,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return container_args def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]): - self._cmd_prefix = cmd # type: ignore + self._cmd_prefix = cmd @contextmanager def prepare_target(self):