Skip to content

Commit

Permalink
Fixed mypy
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed Mar 28, 2023
1 parent adaf6da commit 3662803
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 26 deletions.
3 changes: 1 addition & 2 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import click as _click
from flyteidl.core import literals_pb2 as _literals_pb2

from flytekit import PythonFunctionTask
from flytekit.configuration import (
SERIALIZED_CONTEXT_ENV_VAR,
FastSerializationSettings,
Expand All @@ -23,7 +22,7 @@
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.map_task import MapPythonTask, MapTaskResolver
from flytekit.core.map_task import MapTaskResolver
from flytekit.core.promise import VoidPromise
from flytekit.exceptions import scopes as _scoped_exceptions
from flytekit.exceptions import scopes as _scopes
Expand Down
11 changes: 5 additions & 6 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
T = typing.TypeVar("T")


def repr_kv(k: str, v: Union[str, Tuple[Type, Any]]):
def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str:
if isinstance(v, tuple):
if v[1]:
return f"{k}: {v[0]}={v[1]}"
v = v[0]
return f"{k}: {v[0]}"
return f"{k}: {v}"


Expand Down Expand Up @@ -79,8 +79,8 @@ def __init__(
variables = [k for k in outputs.keys()]

# TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one.
class Output(
collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)
class Output( # type: ignore
collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables) # type: ignore
): # type: ignore
"""
This class can be used in two different places. For multivariate-return entities this class is used
Expand Down Expand Up @@ -450,8 +450,7 @@ def t(a: int, b: str) -> Dict[str, int]: ...
"Tuples should be used to indicate multiple return values, found only one return variable."
)
return OrderedDict(
zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__)
# type: ignore
zip(list(output_name_generator(len(return_annotation.__args__))), return_annotation.__args__) # type: ignore
)
elif isinstance(return_annotation, tuple):
if len(return_annotation) == 1:
Expand Down
36 changes: 18 additions & 18 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,34 +55,34 @@ def __init__(
self._partial = None
if isinstance(python_function_task, functools.partial):
self._partial = python_function_task
python_function_task = self._partial.func
actual_task = self._partial.func
else:
actual_task = python_function_task

if not isinstance(python_function_task, PythonFunctionTask):
if not isinstance(actual_task, PythonFunctionTask):
raise ValueError("Map tasks can only compose of Python Functon Tasks currently")

if len(python_function_task.python_interface.outputs.keys()) > 1:
if len(actual_task.python_interface.outputs.keys()) > 1:
raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs")

self._bound_inputs = set(bound_inputs) if bound_inputs else set()
self._bound_inputs: typing.Set[str] = set(bound_inputs) if bound_inputs else set()
if self._partial:
self._bound_inputs = set(self._partial.keywords.keys())

collection_interface = transform_interface_to_list_interface(
python_function_task.python_interface, self._bound_inputs
)
self._run_task = python_function_task
_, mod, f, _ = tracker.extract_task_module(python_function_task.task_function)
collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs)
self._run_task: PythonFunctionTask = actual_task
_, mod, f, _ = tracker.extract_task_module(actual_task.task_function)
h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest()
name = f"{mod}.map_{f}_{h}"

self._cmd_prefix = None
self._max_concurrency = concurrency
self._min_success_ratio = min_success_ratio
self._array_task_interface = python_function_task.python_interface
if "metadata" not in kwargs and python_function_task.metadata:
kwargs["metadata"] = python_function_task.metadata
if "security_ctx" not in kwargs and python_function_task.security_context:
kwargs["security_ctx"] = python_function_task.security_context
self._cmd_prefix: typing.Optional[typing.List[str]] = None
self._max_concurrency: typing.Optional[int] = concurrency
self._min_success_ratio: typing.Optional[float] = min_success_ratio
self._array_task_interface = actual_task.python_interface
if "metadata" not in kwargs and actual_task.metadata:
kwargs["metadata"] = actual_task.metadata
if "security_ctx" not in kwargs and actual_task.security_context:
kwargs["security_ctx"] = actual_task.security_context
super().__init__(
name=name,
interface=collection_interface,
Expand Down Expand Up @@ -124,7 +124,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
return container_args

def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]):
self._cmd_prefix = cmd # type: ignore
self._cmd_prefix = cmd

@contextmanager
def prepare_target(self):
Expand Down

0 comments on commit 3662803

Please sign in to comment.