From 3798450a43605db5d61d2c9694cc9a22e2da8c1f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 16 Feb 2023 08:49:58 +0800 Subject: [PATCH] Fix mypy errors (#1313) * wip Signed-off-by: Kevin Su * Fix mypy errors Signed-off-by: Kevin Su * Fix mypy errors Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * wip Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix tests Signed-off-by: Kevin Su * fix test Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * Update type Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * Fix tests Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * update dev-requirements.txt Signed-off-by: Kevin Su * Address comment Signed-off-by: Kevin Su * upgrade torch Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su * lint Signed-off-by: Kevin Su --------- Signed-off-by: Kevin Su Signed-off-by: Kevin Su Co-authored-by: Yee Hing Tong --- Makefile | 11 +- dev-requirements.in | 3 + dev-requirements.txt | 14 ++- flytekit/core/base_sql_task.py | 8 +- flytekit/core/base_task.py | 52 ++++---- flytekit/core/class_based_resolver.py | 4 +- flytekit/core/condition.py | 6 +- flytekit/core/container_task.py | 14 +-- flytekit/core/context_manager.py | 20 ++-- flytekit/core/data_persistence.py | 6 +- flytekit/core/docstring.py | 2 +- flytekit/core/gate.py | 5 +- flytekit/core/interface.py | 87 +++++++------- flytekit/core/launch_plan.py | 84 +++++++------ flytekit/core/map_task.py | 14 +-- flytekit/core/node.py | 2 +- flytekit/core/node_creation.py | 9 +- flytekit/core/promise.py | 44 ++++--- flytekit/core/python_auto_container.py | 25 ++-- .../core/python_customized_container_task.py | 6 +- flytekit/core/python_function_task.py | 14 +-- flytekit/core/reference_entity.py | 4 +- flytekit/core/resources.py | 18 +-- flytekit/core/schedule.py | 9 +- flytekit/core/shim_task.py | 14 ++- flytekit/core/task.py | 6 +- flytekit/core/testing.py | 7 +- flytekit/core/tracked_abc.py | 2 +- flytekit/core/tracker.py | 2 +- flytekit/core/type_engine.py | 113 ++++++++++-------- flytekit/core/utils.py | 2 +- flytekit/core/workflow.py | 35 +++--- flytekit/models/literals.py | 2 +- flytekit/types/directory/__init__.py | 2 +- flytekit/types/directory/types.py | 7 +- flytekit/types/file/file.py | 4 +- flytekit/types/numpy/ndarray.py | 2 +- flytekit/types/schema/types.py | 48 ++++---- flytekit/types/schema/types_pandas.py | 8 +- .../types/structured/structured_dataset.py | 30 ++--- .../core/flyte_functools/decorator_source.py | 5 +- .../core/flyte_functools/nested_function.py | 2 +- .../core/flyte_functools/simple_decorator.py | 2 +- .../flyte_functools/stacked_decorators.py | 2 +- .../flyte_functools/unwrapped_decorator.py | 2 +- tests/flytekit/unit/core/test_composition.py | 6 +- tests/flytekit/unit/core/test_conditions.py | 8 +- tests/flytekit/unit/core/test_gate.py | 2 +- tests/flytekit/unit/core/test_imperative.py | 6 +- tests/flytekit/unit/core/test_interface.py | 2 +- tests/flytekit/unit/core/test_launch_plan.py | 2 +- .../flytekit/unit/core/test_node_creation.py | 10 +- .../unit/core/test_python_function_task.py | 2 +- .../unit/core/test_realworld_examples.py | 2 +- tests/flytekit/unit/core/test_references.py | 2 +- .../flytekit/unit/core/test_serialization.py | 12 +- tests/flytekit/unit/core/test_type_engine.py | 6 +- tests/flytekit/unit/core/test_type_hints.py | 12 +- .../unit/core/test_typing_annotation.py | 2 +- tests/flytekit/unit/core/test_workflows.py | 14 +-- 60 files changed, 454 insertions(+), 392 deletions(-) diff --git a/Makefile b/Makefile index 4b3278bec0..f53312a5e6 100644 --- a/Makefile +++ b/Makefile @@ -35,11 +35,12 @@ fmt: ## Format code with black and isort .PHONY: lint lint: ## Run linters - mypy flytekit/core || true - mypy flytekit/types || true - mypy tests/flytekit/unit/core || true - # Exclude setup.py to fix error: Duplicate module named "setup" - mypy plugins --exclude setup.py || true + mypy flytekit/core + mypy flytekit/types + # allow-empty-bodies: Allow empty body in function. + # disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked". + # Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass. + mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core pre-commit run --all-files .PHONY: spellcheck diff --git a/dev-requirements.in b/dev-requirements.in index c2a0a9bdd5..bc98d5dcfb 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -18,3 +18,6 @@ grpcio-status<1.49.0 # we put this constraint while we do not have per-environment requirements files torch<=1.12.1 scikit-learn +types-protobuf +types-croniter +types-mock diff --git a/dev-requirements.txt b/dev-requirements.txt index 9e4eba39fd..98560938ce 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,8 +1,8 @@ # -# This file is autogenerated by pip-compile with python 3.9 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: # -# make dev-requirements.txt +# pip-compile dev-requirements.in # -e file:.#egg=flytekit # via @@ -564,8 +564,12 @@ traitlets==5.9.0 # via # ipython # matplotlib-inline -typed-ast==1.5.4 - # via mypy +types-croniter==1.3.2.2 + # via -r dev-requirements.in +types-mock==5.0.0.2 + # via -r dev-requirements.in +types-protobuf==4.21.0.3 + # via -r dev-requirements.in types-toml==0.10.8.1 # via # -c requirements.txt diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 7fcdc15a50..30b73223a9 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -1,5 +1,5 @@ import re -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Optional, Tuple, Type, TypeVar from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface @@ -22,11 +22,11 @@ def __init__( self, name: str, query_template: str, + task_config: Optional[T] = None, task_type="sql_task", - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - task_config: Optional[T] = None, - outputs: Dict[str, Type] = None, + outputs: Optional[Dict[str, Type]] = None, **kwargs, ): """ diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 2cf8032a6f..f163e891e1 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -21,10 +21,16 @@ import datetime from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast from flytekit.configuration import SerializationSettings -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities +from flytekit.core.context_manager import ( + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, + FlyteEntities, +) from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.local_cache import LocalTaskCache from flytekit.core.promise import ( @@ -156,7 +162,7 @@ def __init__( self, task_type: str, name: str, - interface: Optional[_interface_models.TypedInterface] = None, + interface: _interface_models.TypedInterface, metadata: Optional[TaskMetadata] = None, task_type_version=0, security_ctx: Optional[SecurityContext] = None, @@ -174,7 +180,7 @@ def __init__( FlyteEntities.entities.append(self) @property - def interface(self) -> Optional[_interface_models.TypedInterface]: + def interface(self) -> _interface_models.TypedInterface: return self._interface @property @@ -242,8 +248,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr kwargs = translate_inputs_to_literals( ctx, incoming_values=kwargs, - flyte_interface_types=self.interface.inputs, # type: ignore - native_types=self.get_input_types(), + flyte_interface_types=self.interface.inputs, + native_types=self.get_input_types(), # type: ignore ) input_literal_map = _literal_models.LiteralMap(literals=kwargs) @@ -289,8 +295,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) - def __call__(self, *args, **kwargs): - return flyte_entity_call_handler(self, *args, **kwargs) + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore def compile(self, ctx: FlyteContext, *args, **kwargs): raise Exception("not implemented") @@ -334,8 +340,8 @@ def sandbox_execute( """ Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime. """ - es = ctx.execution_state - b = es.user_space_params.with_task_sandbox() + es = cast(ExecutionState, ctx.execution_state) + b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox() ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() return self.dispatch_execute(ctx, input_literal_map) @@ -384,7 +390,7 @@ def __init__( self, task_type: str, name: str, - task_config: T, + task_config: Optional[T], interface: Optional[Interface] = None, environment: Optional[Dict[str, str]] = None, disable_deck: bool = True, @@ -421,9 +427,13 @@ def __init__( ) else: if self._python_interface.docstring.short_description: - self._docs.short_description = self._python_interface.docstring.short_description + cast( + Documentation, self._docs + ).short_description = self._python_interface.docstring.short_description if self._python_interface.docstring.long_description: - self._docs.long_description = Description(value=self._python_interface.docstring.long_description) + cast(Documentation, self._docs).long_description = Description( + value=self._python_interface.docstring.long_description + ) # TODO lets call this interface and the other as flyte_interface? @property @@ -434,25 +444,25 @@ def python_interface(self) -> Interface: return self._python_interface @property - def task_config(self) -> T: + def task_config(self) -> Optional[T]: """ Returns the user-specified task config which is used for plugin-specific handling of the task. """ return self._task_config - def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for an input variable by name. """ return self._python_interface.inputs[k] - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]: """ Returns the python type for the specified output variable by name. """ return self._python_interface.outputs[k] - def get_input_types(self) -> Optional[Dict[str, type]]: + def get_input_types(self) -> Dict[str, type]: """ Returns the names and python types as a dictionary for the inputs of this task. """ @@ -498,7 +508,9 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) # type: ignore ) as exec_ctx: # TODO We could support default values here too - but not part of the plan right now @@ -579,7 +591,7 @@ def dispatch_execute( # After the execute has been successfully completed return outputs_literal_map - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore """ This is the method that will be invoked directly before executing the task method and before all the inputs are converted. One particular case where this is useful is if the context is to be modified for the user process @@ -597,7 +609,7 @@ def execute(self, **kwargs) -> Any: """ pass - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any: """ Post execute is called after the execution has completed, with the user_params and can be used to clean-up, or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op diff --git a/flytekit/core/class_based_resolver.py b/flytekit/core/class_based_resolver.py index d47820f811..49970d5623 100644 --- a/flytekit/core/class_based_resolver.py +++ b/flytekit/core/class_based_resolver.py @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): def name(self) -> str: return "ClassStorageTaskResolver" - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore return self.mapping def add(self, t: PythonAutoContainerTask): @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: idx = int(loader_args[0]) return self.mapping[idx] - def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore """ This is responsible for turning an instance of a task into args that the load_task function can reconstitute. """ diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index b5cae86923..76553db702 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP return self._compute_outputs(n) return self._condition - def if_(self, expr: bool) -> Case: + def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case: return self._condition._if(expr) def compute_output_vars(self) -> typing.Optional[typing.List[str]]: @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str: return f"{node_id}.{var}" -def merge_promises(*args: Promise) -> typing.List[Promise]: +def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]: node_vars: typing.Set[typing.Tuple[str, str]] = set() merged_promises: typing.List[Promise] = [] for p in args: @@ -414,7 +414,7 @@ def transform_to_boolexpr( def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]: - expr, promises = transform_to_boolexpr(c.expr) + expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr)) n = c.output_promise.ref.node # type: ignore return _core_wf.IfBlock(condition=expr, then_node=n), promises diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index d470fb54fe..677142736c 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask, TaskMetadata @@ -36,16 +36,16 @@ def __init__( name: str, image: str, command: List[str], - inputs: Optional[Dict[str, Type]] = None, + inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, - arguments: List[str] = None, - outputs: Dict[str, Type] = None, + arguments: Optional[List[str]] = None, + outputs: Optional[Dict[str, Type]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, - input_data_dir: str = None, - output_data_dir: str = None, + input_data_dir: Optional[str] = None, + output_data_dir: Optional[str] = None, metadata_format: MetadataFormat = MetadataFormat.JSON, - io_strategy: IOStrategy = None, + io_strategy: Optional[IOStrategy] = None, secret_requests: Optional[List[Secret]] = None, **kwargs, ): diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 7e4600b3bb..fc8915e338 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -48,7 +48,7 @@ flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[]) if typing.TYPE_CHECKING: - from flytekit.core.base_task import TaskResolverMixin + from flytekit.core.base_task import Task, TaskResolverMixin # Identifier fields use placeholders for registration-time substitution. @@ -108,7 +108,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder: def build(self) -> ExecutionParameters: if not isinstance(self.working_dir, utils.AutoDeletingTempDir): - pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True) + pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True) return ExecutionParameters( execution_date=self.execution_date, stats=self.stats, @@ -123,14 +123,14 @@ def build(self) -> ExecutionParameters: ) @staticmethod - def new_builder(current: ExecutionParameters = None) -> Builder: + def new_builder(current: Optional[ExecutionParameters] = None) -> Builder: return ExecutionParameters.Builder(current=current) def with_task_sandbox(self) -> Builder: prefix = self.working_directory if isinstance(self.working_directory, utils.AutoDeletingTempDir): prefix = self.working_directory.name - task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) + task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) # type: ignore p = pathlib.Path(task_sandbox_dir) cp_dir = p.joinpath("__cp") cp_dir.mkdir(exist_ok=True) @@ -287,7 +287,7 @@ def get(self, key: str) -> typing.Any: """ Returns task specific context if present else raise an error. The returned context will match the key """ - return self.__getattr__(attr_name=key) + return self.__getattr__(attr_name=key) # type: ignore class SecretsManager(object): @@ -467,14 +467,14 @@ class Mode(Enum): LOCAL_TASK_EXECUTION = 3 mode: Optional[ExecutionState.Mode] - working_dir: os.PathLike + working_dir: Union[os.PathLike, str] engine_dir: Optional[Union[os.PathLike, str]] branch_eval_mode: Optional[BranchEvalMode] user_space_params: Optional[ExecutionParameters] def __init__( self, - working_dir: os.PathLike, + working_dir: Union[os.PathLike, str], mode: Optional[ExecutionState.Mode] = None, engine_dir: Optional[Union[os.PathLike, str]] = None, branch_eval_mode: Optional[BranchEvalMode] = None, @@ -607,7 +607,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params) @staticmethod - def current_context() -> Optional[FlyteContext]: + def current_context() -> FlyteContext: """ This method exists only to maintain backwards compatibility. Please use ``FlyteContextManager.current_context()`` instead. @@ -639,7 +639,7 @@ def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ig """ from flytekit.deck.deck import _get_deck - return _get_deck(self.execution_state.user_space_params) + return _get_deck(typing.cast(ExecutionState, self.execution_state).user_space_params) @dataclass class Builder(object): @@ -852,7 +852,7 @@ class FlyteEntities(object): registration process """ - entities = [] + entities: List[Union["LaunchPlan", Task, "WorkflowBase"]] = [] # type: ignore FlyteContextManager.initialize() diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index d407b3528b..d48ce45ce1 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -58,7 +58,7 @@ class DataPersistence(object): Base abstract type for all DataPersistence operations. This can be extended using the flytekitplugins architecture """ - def __init__(self, name: str, default_prefix: typing.Optional[str] = None, **kwargs): + def __init__(self, name: str = "", default_prefix: typing.Optional[str] = None, **kwargs): self._name = name self._default_prefix = default_prefix @@ -142,7 +142,7 @@ def register_plugin(cls, protocol: str, plugin: typing.Type[DataPersistence], fo cls._PLUGINS[protocol] = plugin @staticmethod - def get_protocol(url: str): + def get_protocol(url: str) -> str: # copy from fsspec https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391 parts = re.split(r"(\:\:|\://)", url, 1) if len(parts) > 1: @@ -458,7 +458,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False): f"Original exception: {str(ex)}" ) - def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart=False): + def put_data(self, local_path: str, remote_path: str, is_multipart=False): """ The implication here is that we're always going to put data to the remote location, so we .remote to ensure we don't use the true local proxy if the remote path is a file:// diff --git a/flytekit/core/docstring.py b/flytekit/core/docstring.py index 420f26f8f5..fa9d9caec2 100644 --- a/flytekit/core/docstring.py +++ b/flytekit/core/docstring.py @@ -4,7 +4,7 @@ class Docstring(object): - def __init__(self, docstring: str = None, callable_: Callable = None): + def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None): if docstring is not None: self._parsed_docstring = parse(docstring) else: diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index b6cb7ca2b6..bc3ab1d3fd 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -53,7 +53,7 @@ def __init__( ) else: # We don't know how to find the python interface here, approve() sets it below, See the code. - self._python_interface = None + self._python_interface = None # type: ignore @property def name(self) -> str: @@ -105,7 +105,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return p # Assume this is an approval operation since that's the only remaining option. - msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?" + msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?" proceed = click.confirm(msg, default=True) if proceed: # We need to return a promise here, and a promise is what should've been passed in by the call in approve() @@ -167,6 +167,7 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st raise ValueError("You can't use approval on a task that doesn't return anything.") ctx = FlyteContextManager.current_context() + upstream_item = typing.cast(Promise, upstream_item) if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: if not upstream_item.ref.node.flyte_entity.python_interface: raise ValueError( diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 954c1ae409..3c24e65db2 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -5,7 +5,7 @@ import inspect import typing from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union, cast from typing_extensions import Annotated, get_args, get_origin, get_type_hints @@ -28,8 +28,8 @@ class Interface(object): def __init__( self, - inputs: typing.Optional[typing.Dict[str, Union[Type, Tuple[Type, Any]], None]] = None, - outputs: typing.Optional[typing.Dict[str, Type]] = None, + inputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Tuple[Type, Any]]]] = None, + outputs: Union[Optional[Dict[str, Type]], Optional[Dict[str, Optional[Type]]]] = None, output_tuple_name: Optional[str] = None, docstring: Optional[Docstring] = None, ): @@ -43,21 +43,21 @@ def __init__( primarily used when handling one-element NamedTuples. :param docstring: Docstring of the annotated @task or @workflow from which the interface derives from. """ - self._inputs = {} + self._inputs: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]] = {} # type: ignore if inputs: for k, v in inputs.items(): - if isinstance(v, Tuple) and len(v) > 1: - self._inputs[k] = v + if type(v) is tuple and len(cast(Tuple, v)) > 1: + self._inputs[k] = v # type: ignore else: - self._inputs[k] = (v, None) - self._outputs = outputs if outputs else {} + self._inputs[k] = (v, None) # type: ignore + self._outputs = outputs if outputs else {} # type: ignore self._output_tuple_name = output_tuple_name if outputs: variables = [k for k in outputs.keys()] # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one. - class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): + class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)): # type: ignore """ This class can be used in two different places. For multivariate-return entities this class is used to rewrap the outputs so that our with_overrides function can work. @@ -90,7 +90,7 @@ def __rshift__(self, *args, **kwargs): self._docstring = docstring @property - def output_tuple(self) -> Optional[Type[collections.namedtuple]]: + def output_tuple(self) -> Type[collections.namedtuple]: # type: ignore return self._output_tuple_class @property @@ -98,7 +98,7 @@ def output_tuple_name(self) -> Optional[str]: return self._output_tuple_name @property - def inputs(self) -> typing.Dict[str, Type]: + def inputs(self) -> Dict[str, type]: r = {} for k, v in self._inputs.items(): r[k] = v[0] @@ -111,8 +111,8 @@ def output_names(self) -> Optional[List[str]]: return None @property - def inputs_with_defaults(self) -> typing.Dict[str, Tuple[Type, Any]]: - return self._inputs + def inputs_with_defaults(self) -> Dict[str, Tuple[Type, Any]]: + return cast(Dict[str, Tuple[Type, Any]], self._inputs) @property def default_inputs_as_kwargs(self) -> Dict[str, Any]: @@ -120,13 +120,13 @@ def default_inputs_as_kwargs(self) -> Dict[str, Any]: @property def outputs(self) -> typing.Dict[str, type]: - return self._outputs + return self._outputs # type: ignore @property def docstring(self) -> Optional[Docstring]: return self._docstring - def remove_inputs(self, vars: List[str]) -> Interface: + def remove_inputs(self, vars: Optional[List[str]]) -> Interface: """ This method is useful in removing some variables from the Flyte backend inputs specification, as these are implicit local only inputs or will be supplied by the library at runtime. For example, spark-session etc @@ -151,7 +151,7 @@ def with_inputs(self, extra_inputs: Dict[str, Type]) -> Interface: for k, v in extra_inputs.items(): if k in new_inputs: raise ValueError(f"Input {k} cannot be added as it already exists in the interface") - new_inputs[k] = v + cast(Dict[str, Type], new_inputs)[k] = v return Interface(new_inputs, self._outputs, docstring=self.docstring) def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface: @@ -240,7 +240,7 @@ def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]: om = {} for k, v in m.items(): - om[k] = typing.List[v] + om[k] = typing.List[v] # type: ignore return om # type: ignore @@ -255,18 +255,20 @@ def transform_interface_to_list_interface(interface: Interface) -> Interface: return Interface(inputs=map_inputs, outputs=map_outputs) -def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T], Annotated]: +def _change_unrecognized_type_to_pickle(t: Type[T]) -> typing.Union[Tuple[Type[T]], Type[T]]: try: if hasattr(t, "__origin__") and hasattr(t, "__args__"): - if get_origin(t) is list: - return typing.List[_change_unrecognized_type_to_pickle(t.__args__[0])] - elif get_origin(t) is dict and t.__args__[0] == str: - return typing.Dict[str, _change_unrecognized_type_to_pickle(t.__args__[1])] - elif get_origin(t) is typing.Union: - return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] - elif get_origin(t) is Annotated: + ot = get_origin(t) + args = getattr(t, "__args__") + if ot is list: + return typing.List[_change_unrecognized_type_to_pickle(args[0])] # type: ignore + elif ot is dict and args[0] == str: + return typing.Dict[str, _change_unrecognized_type_to_pickle(args[1])] # type: ignore + elif ot is typing.Union: + return typing.Union[tuple(_change_unrecognized_type_to_pickle(v) for v in get_args(t))] # type: ignore + elif ot is Annotated: base_type, *config = get_args(t) - return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] + return Annotated[(_change_unrecognized_type_to_pickle(base_type), *config)] # type: ignore TypeEngine.get_transformer(t) except ValueError: logger.warning( @@ -294,12 +296,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc outputs = extract_return_annotation(return_annotation) for k, v in outputs.items(): outputs[k] = _change_unrecognized_type_to_pickle(v) # type: ignore - inputs = OrderedDict() + inputs: Dict[str, Tuple[Type, Any]] = OrderedDict() for k, v in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) default = v.default if v.default is not inspect.Parameter.empty else None # Inputs with default values are currently ignored, we may want to look into that in the future - inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) + inputs[k] = (_change_unrecognized_type_to_pickle(annotation), default) # type: ignore # This is just for typing.NamedTuples - in those cases, the user can select a name to call the NamedTuple. We # would like to preserve that name in our custom collections.namedtuple. @@ -325,23 +327,24 @@ def transform_variable_map( if variable_map: for k, v in variable_map.items(): res[k] = transform_type(v, descriptions.get(k, k)) - sub_type: Type[T] = v + sub_type: type = v if hasattr(v, "__origin__") and hasattr(v, "__args__"): - if v.__origin__ is list: - sub_type = v.__args__[0] - elif v.__origin__ is dict: - sub_type = v.__args__[1] - if hasattr(sub_type, "__origin__") and sub_type.__origin__ is FlytePickle: - if hasattr(sub_type.python_type(), "__name__"): - res[k].type.metadata = {"python_class_name": sub_type.python_type().__name__} - elif hasattr(sub_type.python_type(), "_name"): + if getattr(v, "__origin__") is list: + sub_type = getattr(v, "__args__")[0] + elif getattr(v, "__origin__") is dict: + sub_type = getattr(v, "__args__")[1] + if hasattr(sub_type, "__origin__") and getattr(sub_type, "__origin__") is FlytePickle: + original_type = cast(FlytePickle, sub_type).python_type() + if hasattr(original_type, "__name__"): + res[k].type.metadata = {"python_class_name": original_type.__name__} + elif hasattr(original_type, "_name"): # If the class doesn't have the __name__ attribute, like typing.Sequence, use _name instead. - res[k].type.metadata = {"python_class_name": sub_type.python_type()._name} + res[k].type.metadata = {"python_class_name": original_type._name} return res -def transform_type(x: type, description: str = None) -> _interface_models.Variable: +def transform_type(x: type, description: Optional[str] = None) -> _interface_models.Variable: return _interface_models.Variable(type=TypeEngine.to_literal_type(x), description=description) @@ -393,13 +396,13 @@ def t(a: int, b: str) -> Dict[str, int]: ... # This statement results in true for typing.Namedtuple, single and void return types, so this # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python - if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): + if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): # type: ignore # isinstance / issubclass does not work for Namedtuple. # Options 1 and 2 bases = return_annotation.__bases__ # type: ignore if len(bases) == 1 and bases[0] == tuple and hasattr(return_annotation, "_fields"): logger.debug(f"Task returns named tuple {return_annotation}") - return dict(get_type_hints(return_annotation, include_extras=True)) + return dict(get_type_hints(cast(Type, return_annotation), include_extras=True)) if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore # Handle option 3 @@ -419,7 +422,7 @@ def t(a: int, b: str) -> Dict[str, int]: ... else: # Handle all other single return types logger.debug(f"Task returns unnamed native tuple {return_annotation}") - return {default_output_name(): return_annotation} + return {default_output_name(): cast(Type, return_annotation)} def remap_shared_output_descriptions(output_descriptions: Dict[str, str], outputs: Dict[str, Type]) -> Dict[str, str]: diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py index 550dc1919e..86011f1253 100644 --- a/flytekit/core/launch_plan.py +++ b/flytekit/core/launch_plan.py @@ -74,7 +74,7 @@ def wf(a: int, c: str) -> str: # The reason we cache is simply because users may get the default launch plan twice for a single Workflow. We # don't want to create two defaults, could be confusing. - CACHE = {} + CACHE: typing.Dict[str, LaunchPlan] = {} @staticmethod def get_default_launch_plan(ctx: FlyteContext, workflow: _annotated_workflow.WorkflowBase) -> LaunchPlan: @@ -107,16 +107,16 @@ def create( cls, name: str, workflow: _annotated_workflow.WorkflowBase, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: ctx = FlyteContextManager.current_context() default_inputs = default_inputs or {} @@ -130,7 +130,7 @@ def create( temp_inputs = {} for k, v in default_inputs.items(): temp_inputs[k] = (workflow.python_interface.inputs[k], v) - temp_interface = Interface(inputs=temp_inputs, outputs={}) + temp_interface = Interface(inputs=temp_inputs, outputs={}) # type: ignore temp_signature = transform_inputs_to_parameters(ctx, temp_interface) wf_signature_parameters._parameters.update(temp_signature.parameters) @@ -185,16 +185,16 @@ def get_or_create( cls, workflow: _annotated_workflow.WorkflowBase, name: Optional[str] = None, - default_inputs: Dict[str, Any] = None, - fixed_inputs: Dict[str, Any] = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, - auth_role: _common_models.AuthRole = None, + default_inputs: Optional[Dict[str, Any]] = None, + fixed_inputs: Optional[Dict[str, Any]] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, + auth_role: Optional[_common_models.AuthRole] = None, ) -> LaunchPlan: """ This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not @@ -298,13 +298,13 @@ def __init__( workflow: _annotated_workflow.WorkflowBase, parameters: _interface_models.ParameterMap, fixed_inputs: _literal_models.LiteralMap, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - max_parallelism: typing.Optional[int] = None, - security_context: typing.Optional[security.SecurityContext] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ): self._name = name self._workflow = workflow @@ -313,7 +313,7 @@ def __init__( self._parameters = _interface_models.ParameterMap(parameters=parameters) self._fixed_inputs = fixed_inputs # See create() for additional information - self._saved_inputs = {} + self._saved_inputs: Dict[str, Any] = {} self._schedule = schedule self._notifications = notifications or [] @@ -328,16 +328,15 @@ def __init__( def clone_with( self, name: str, - parameters: _interface_models.ParameterMap = None, - fixed_inputs: _literal_models.LiteralMap = None, - schedule: _schedule_model.Schedule = None, - notifications: List[_common_models.Notification] = None, - labels: _common_models.Labels = None, - annotations: _common_models.Annotations = None, - raw_output_data_config: _common_models.RawOutputDataConfig = None, - auth_role: _common_models.AuthRole = None, - max_parallelism: int = None, - security_context: typing.Optional[security.SecurityContext] = None, + parameters: Optional[_interface_models.ParameterMap] = None, + fixed_inputs: Optional[_literal_models.LiteralMap] = None, + schedule: Optional[_schedule_model.Schedule] = None, + notifications: Optional[List[_common_models.Notification]] = None, + labels: Optional[_common_models.Labels] = None, + annotations: Optional[_common_models.Annotations] = None, + raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, + max_parallelism: Optional[int] = None, + security_context: Optional[security.SecurityContext] = None, ) -> LaunchPlan: return LaunchPlan( name=name, @@ -349,7 +348,6 @@ def clone_with( labels=labels or self.labels, annotations=annotations or self.annotations, raw_output_data_config=raw_output_data_config or self.raw_output_data_config, - auth_role=auth_role or self._auth_role, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, ) @@ -407,11 +405,11 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig] return self._raw_output_data_config @property - def max_parallelism(self) -> typing.Optional[int]: + def max_parallelism(self) -> Optional[int]: return self._max_parallelism @property - def security_context(self) -> typing.Optional[security.SecurityContext]: + def security_context(self) -> Optional[security.SecurityContext]: return self._security_context def construct_node_metadata(self) -> _workflow_model.NodeMetadata: diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 3b5c0a09ca..48d0f0b335 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -7,7 +7,7 @@ import typing from contextlib import contextmanager from itertools import count -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional from flytekit.configuration import SerializationSettings from flytekit.core import tracker @@ -36,8 +36,8 @@ class MapPythonTask(PythonTask): def __init__( self, python_function_task: PythonFunctionTask, - concurrency: int = None, - min_success_ratio: float = None, + concurrency: Optional[int] = None, + min_success_ratio: Optional[float] = None, **kwargs, ): """ @@ -149,8 +149,8 @@ def _compute_array_job_index() -> int: environment variable and the offset (if one's set). The offset will be set and used when the user request that the job runs in a number of slots less than the size of the input. """ - return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", 0)) + int( - os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME")) + return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", "0")) + int( + os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "0"), "0") ) @property @@ -168,7 +168,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]: return self.interface.outputs return self._run_task.interface.outputs - def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: + def get_type_for_output_var(self, k: str, v: Any) -> type: """ We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this interface to construct outputs. Each instance of an container_array task will however produce outputs @@ -181,7 +181,7 @@ def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]: return self._python_interface.outputs[k] return self._run_task._python_interface.outputs[k] - def _execute_map_task(self, ctx: FlyteContext, **kwargs) -> Any: + def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any: """ This is called during ExecutionState.Mode.TASK_EXECUTION executions, that is executions orchestrated by the Flyte platform. Individual instances of the map task, aka array task jobs are passed the full set of inputs but diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 220301c402..73f951d721 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -131,7 +131,7 @@ def with_overrides(self, *args, **kwargs): def _convert_resource_overrides( resources: typing.Optional[Resources], resource_name: str -) -> [_resources_model.ResourceEntry]: +) -> typing.List[_resources_model.ResourceEntry]: if resources is None: return [] diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index de33393c13..62065f6869 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -1,7 +1,6 @@ from __future__ import annotations -import collections -from typing import TYPE_CHECKING, Type, Union +from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext @@ -21,7 +20,7 @@ def create_node( entity: Union[PythonTask, LaunchPlan, WorkflowBase, RemoteEntity], *args, **kwargs -) -> Union[Node, VoidPromise, Type[collections.namedtuple]]: +) -> Union[Node, VoidPromise]: """ This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or don't produce outputs. For example, if you have t1() and t2(), both of which do not take in nor produce any @@ -173,9 +172,9 @@ def sub_wf(): if len(output_names) == 1: # See explanation above for why we still tupletize a single element. - return entity.python_interface.output_tuple(results) + return entity.python_interface.output_tuple(results) # type: ignore - return entity.python_interface.output_tuple(*results) + return entity.python_interface.output_tuple(*results) # type: ignore else: raise Exception(f"Cannot use explicit run to call Flyte entities {entity.name}") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 5b9ea93656..7236286f29 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -10,7 +10,13 @@ from flytekit.core import context_manager as _flyte_context from flytekit.core import interface as flyte_interface from flytekit.core import type_engine -from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ( + BranchEvalMode, + ExecutionParameters, + ExecutionState, + FlyteContext, + FlyteContextManager, +) from flytekit.core.interface import Interface from flytekit.core.node import Node from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine @@ -81,7 +87,7 @@ def extract_value( if lt.collection_type is None: raise TypeError(f"Not a collection type {flyte_literal_type} but got a list {input_val}") try: - sub_type = ListTransformer.get_sub_type(python_type) + sub_type: type = ListTransformer.get_sub_type(python_type) except ValueError: if len(input_val) == 0: raise @@ -348,7 +354,7 @@ def __hash__(self): return hash(id(self)) def __rshift__(self, other: Union[Promise, VoidPromise]): - if not self.is_ready: + if not self.is_ready and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -408,10 +414,10 @@ def is_false(self) -> ComparisonExpression: def is_true(self): return self.is_(True) - def __eq__(self, other) -> ComparisonExpression: + def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) - def __ne__(self, other) -> ComparisonExpression: + def __ne__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.NE, other) def __gt__(self, other) -> ComparisonExpression: @@ -455,7 +461,7 @@ def __str__(self): def create_native_named_tuple( ctx: FlyteContext, - promises: Optional[Union[Promise, List[Promise]]], + promises: Union[Tuple[Promise], Promise, VoidPromise, None], entity_interface: Interface, ) -> Optional[Tuple]: """ @@ -476,7 +482,7 @@ def create_native_named_tuple( except Exception as e: raise AssertionError(f"Failed to convert value of output {k}, expected type {v}.") from e - if len(promises) == 0: + if len(cast(Tuple[Promise], promises)) == 0: return None named_tuple_name = "DefaultNamedTupleOutput" @@ -484,7 +490,7 @@ def create_native_named_tuple( named_tuple_name = entity_interface.output_tuple_name outputs = {} - for p in promises: + for p in cast(Tuple[Promise], promises): if not isinstance(p, Promise): raise AssertionError( "Workflow outputs can only be promises that are returned by tasks. Found a value of" @@ -497,8 +503,8 @@ def create_native_named_tuple( raise AssertionError(f"Failed to convert value of output {p.var}, expected type {t}.") from e # Should this class be part of the Interface? - t = collections.namedtuple(named_tuple_name, list(outputs.keys())) - return t(**outputs) + nt = collections.namedtuple(named_tuple_name, list(outputs.keys())) # type: ignore + return nt(**outputs) # To create a class that is a named tuple, we might have to create namedtuplemeta and manipulate the tuple @@ -542,7 +548,7 @@ def create_task_output( named_tuple_name = entity_interface.output_tuple_name # Should this class be part of the Interface? - class Output(collections.namedtuple(named_tuple_name, variables)): + class Output(collections.namedtuple(named_tuple_name, variables)): # type: ignore def with_overrides(self, *args, **kwargs): val = self.__getattribute__(self._fields[0]) val.with_overrides(*args, **kwargs) @@ -601,7 +607,7 @@ def binding_data_from_python_std( if expected_literal_type.collection_type is None: raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}") - sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None + sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None collection = _literals_models.BindingDataCollection( bindings=[ binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value @@ -683,7 +689,7 @@ def ref(self) -> Optional[NodeOutput]: return self._ref def __rshift__(self, other: Union[Promise, VoidPromise]): - if self.ref: + if self.ref and other.ref: self.ref.node.runs_before(other.ref.node) return other @@ -1019,11 +1025,13 @@ def create_and_link_node( class LocallyExecutable(Protocol): - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: ... -def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): +def flyte_entity_call_handler( + entity: SupportsNodeCreation, *args, **kwargs +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow). The logic is the same for all three, but we did not want to create base class, hence this separate @@ -1076,7 +1084,7 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) ) ) as child_ctx: - cast(FlyteContext, child_ctx).user_space_params._decks = [] + cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs) @@ -1086,7 +1094,9 @@ def flyte_entity_call_handler(entity: SupportsNodeCreation, *args, **kwargs): else: raise Exception(f"Received an output when workflow local execution expected None. Received: {result}") - if (1 < expected_outputs == len(result)) or (result is not None and expected_outputs == 1): + if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( + result is not None and expected_outputs == 1 + ): return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface) raise ValueError( diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 2d05df3c3d..113f94a998 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -3,8 +3,7 @@ import importlib import re from abc import ABC -from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, cast from flyteidl.core import tasks_pb2 as _core_task from kubernetes.client import ApiClient @@ -120,7 +119,7 @@ def __init__( self.pod_template = pod_template @property - def task_resolver(self) -> Optional[TaskResolverMixin]: + def task_resolver(self) -> TaskResolverMixin: return self._task_resolver @property @@ -208,23 +207,23 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain ) def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any]: - containers = self.pod_template.pod_spec.containers + containers = cast(PodTemplate, self.pod_template).pod_spec.containers primary_exists = False for container in containers: - if container.name == self.pod_template.primary_container_name: + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: primary_exists = True break if not primary_exists: # insert a placeholder primary container if it is not defined in the pod spec. - containers.append(V1Container(name=self.pod_template.primary_container_name)) + containers.append(V1Container(name=cast(PodTemplate, self.pod_template).primary_container_name)) final_containers = [] for container in containers: # In the case of the primary container, we overwrite specific container attributes # with the default values used in the regular Python task. # The attributes include: image, command, args, resource, and env (env is unioned) - if container.name == self.pod_template.primary_container_name: + if container.name == cast(PodTemplate, self.pod_template).primary_container_name: sdk_default_container = self._get_container(settings) container.image = sdk_default_container.image # clear existing commands @@ -244,9 +243,9 @@ def _serialize_pod_spec(self, settings: SerializationSettings) -> Dict[str, Any] container.env or [] ) final_containers.append(container) - self.pod_template.pod_spec.containers = final_containers + cast(PodTemplate, self.pod_template).pod_spec.containers = final_containers - return ApiClient().sanitize_for_serialization(self.pod_template.pod_spec) + return ApiClient().sanitize_for_serialization(cast(PodTemplate, self.pod_template).pod_spec) def get_k8s_pod(self, settings: SerializationSettings) -> _task_model.K8sPod: if self.pod_template is None: @@ -274,14 +273,14 @@ class DefaultTaskResolver(TrackedInstance, TaskResolverMixin): def name(self) -> str: return "DefaultTaskResolver" - def load_task(self, loader_args: List[Union[T, ModuleType]]) -> PythonAutoContainerTask: + def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: _, task_module, _, task_name, *_ = loader_args - task_module = importlib.import_module(task_module) + task_module = importlib.import_module(name=task_module) # type: ignore task_def = getattr(task_module, task_name) return task_def - def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore from flytekit.core.python_function_task import PythonFunctionTask if isinstance(task, PythonFunctionTask): @@ -291,7 +290,7 @@ def loader_args(self, settings: SerializationSettings, task: PythonAutoContainer _, m, t, _ = extract_task_module(task) return ["task-module", m, "task-name", t] - def get_all_tasks(self) -> List[PythonAutoContainerTask]: + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore raise Exception("should not be needed") diff --git a/flytekit/core/python_customized_container_task.py b/flytekit/core/python_customized_container_task.py index eee0dce9b8..07493886a2 100644 --- a/flytekit/core/python_customized_container_task.py +++ b/flytekit/core/python_customized_container_task.py @@ -21,7 +21,7 @@ TC = TypeVar("TC") -class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): +class PythonCustomizedContainerTask(ExecutableTemplateShimTask, PythonTask[TC]): # type: ignore """ Please take a look at the comments for :py:class`flytekit.extend.ExecutableTemplateShimTask` as well. This class should be subclassed and a custom Executor provided as a default to this parent class constructor @@ -229,7 +229,7 @@ def name(self) -> str: # The return type of this function is different, it should be a Task, but it's not because it doesn't make # sense for ExecutableTemplateShimTask to inherit from Task. - def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: + def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: # type: ignore logger.info(f"Task template loader args: {loader_args}") ctx = FlyteContext.current_context() task_template_local_path = os.path.join(ctx.execution_state.working_dir, "task_template.pb") # type: ignore @@ -240,7 +240,7 @@ def load_task(self, loader_args: List[str]) -> ExecutableTemplateShimTask: executor_class = load_object_from_module(loader_args[1]) return ExecutableTemplateShimTask(task_template_model, executor_class) - def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: + def loader_args(self, settings: SerializationSettings, t: PythonCustomizedContainerTask) -> List[str]: # type: ignore return ["{{.taskTemplatePath}}", f"{t.executor_type.__module__}.{t.executor_type.__name__}"] def get_all_tasks(self) -> List[Task]: diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 81f6739a39..90b10cbc36 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -17,7 +17,7 @@ from abc import ABC from collections import OrderedDict from enum import Enum -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, TypeVar, Union, cast from flytekit.core.base_task import Task, TaskResolverMixin from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager @@ -43,7 +43,7 @@ T = TypeVar("T") -class PythonInstanceTask(PythonAutoContainerTask[T], ABC): +class PythonInstanceTask(PythonAutoContainerTask[T], ABC): # type: ignore """ This class should be used as the base class for all Tasks that do not have a user defined function body, but have a platform defined execute method. (Execute needs to be overridden). This base class ensures that the module loader @@ -72,7 +72,7 @@ def __init__( super().__init__(name=name, task_config=task_config, task_type=task_type, task_resolver=task_resolver, **kwargs) -class PythonFunctionTask(PythonAutoContainerTask[T]): +class PythonFunctionTask(PythonAutoContainerTask[T]): # type: ignore """ A Python Function task should be used as the base for all extensions that have a python function. It will automatically detect interface of the python function and when serialized on the hosted Flyte platform handles the @@ -193,10 +193,10 @@ def compile_into_workflow( from flytekit.tools.translator import get_serializable self._create_and_cache_dynamic_workflow() - self._wf.compile(**kwargs) + cast(PythonFunctionWorkflow, self._wf).compile(**kwargs) wf = self._wf - model_entities = OrderedDict() + model_entities: OrderedDict = OrderedDict() # See comment on reference entity checking a bit down below in this function. # This is the only circular dependency between the translator.py module and the rest of the flytekit # authoring experience. @@ -263,12 +263,12 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: # local_execute directly though since that converts inputs into Promises. logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") self._create_and_cache_dynamic_workflow() - function_outputs = self._wf.execute(**kwargs) + function_outputs = cast(PythonFunctionWorkflow, self._wf).execute(**kwargs) if isinstance(function_outputs, VoidPromise) or function_outputs is None: return VoidPromise(self.name) - if len(self._wf.python_interface.outputs) == 0: + if len(cast(PythonFunctionWorkflow, self._wf).python_interface.outputs) == 0: raise FlyteValueException(function_outputs, "Interface output should've been VoidPromise or None.") # TODO: This will need to be cleaned up when we revisit top-level tuple support. diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 7247457d86..de386fa159 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -21,7 +21,7 @@ from flytekit.models.core import workflow as _workflow_model -@dataclass +@dataclass # type: ignore class Reference(ABC): project: str domain: str @@ -72,7 +72,7 @@ class ReferenceEntity(object): def __init__( self, reference: Union[WorkflowReference, TaskReference, LaunchPlanReference], - inputs: Optional[Dict[str, Union[Type[Any], Tuple[Type[Any], Any]]]], + inputs: Dict[str, Type], outputs: Dict[str, Type], ): if ( diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 6280604246..4cf2523f6a 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -35,26 +35,26 @@ class Resources(object): @dataclass class ResourceSpec(object): - requests: Optional[Resources] = None - limits: Optional[Resources] = None + requests: Resources + limits: Resources -_ResouceName = task_models.Resources.ResourceName +_ResourceName = task_models.Resources.ResourceName _ResourceEntry = task_models.Resources.ResourceEntry -def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: +def _convert_resources_to_resource_entries(resources: Resources) -> List[_ResourceEntry]: # type: ignore resource_entries = [] if resources.cpu is not None: - resource_entries.append(_ResourceEntry(name=_ResouceName.CPU, value=resources.cpu)) + resource_entries.append(_ResourceEntry(name=_ResourceName.CPU, value=resources.cpu)) if resources.mem is not None: - resource_entries.append(_ResourceEntry(name=_ResouceName.MEMORY, value=resources.mem)) + resource_entries.append(_ResourceEntry(name=_ResourceName.MEMORY, value=resources.mem)) if resources.gpu is not None: - resource_entries.append(_ResourceEntry(name=_ResouceName.GPU, value=resources.gpu)) + resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=resources.gpu)) if resources.storage is not None: - resource_entries.append(_ResourceEntry(name=_ResouceName.STORAGE, value=resources.storage)) + resource_entries.append(_ResourceEntry(name=_ResourceName.STORAGE, value=resources.storage)) if resources.ephemeral_storage is not None: - resource_entries.append(_ResourceEntry(name=_ResouceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage)) + resource_entries.append(_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=resources.ephemeral_storage)) return resource_entries diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 7addc89197..93116d0720 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -6,6 +6,7 @@ import datetime import re as _re +from typing import Optional import croniter as _croniter @@ -52,7 +53,11 @@ class CronSchedule(_schedule_models.Schedule): _OFFSET_PATTERN = _re.compile("([-+]?)P([-+0-9YMWD]+)?(T([-+0-9HMS.,]+)?)?") def __init__( - self, cron_expression: str = None, schedule: str = None, offset: str = None, kickoff_time_input_arg: str = None + self, + cron_expression: Optional[str] = None, + schedule: Optional[str] = None, + offset: Optional[str] = None, + kickoff_time_input_arg: Optional[str] = None, ): """ :param str cron_expression: This should be a cron expression in AWS style.Shouldn't be used in case of native scheduler. @@ -161,7 +166,7 @@ class FixedRate(_schedule_models.Schedule): See the :std:ref:`fixed rate intervals` chapter in the cookbook for additional usage examples. """ - def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: str = None): + def __init__(self, duration: datetime.timedelta, kickoff_time_input_arg: Optional[str] = None): """ :param datetime.timedelta duration: :param str kickoff_time_input_arg: diff --git a/flytekit/core/shim_task.py b/flytekit/core/shim_task.py index d8d18293c5..f96db3e49c 100644 --- a/flytekit/core/shim_task.py +++ b/flytekit/core/shim_task.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Any, Generic, Type, TypeVar, Union +from typing import Any, Generic, Optional, Type, TypeVar, Union, cast -from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine from flytekit.loggers import logger @@ -47,7 +47,7 @@ def name(self) -> str: if self._task_template is not None: return self._task_template.id.name # if not access the subclass's name - return self._name + return self._name # type: ignore @property def task_template(self) -> _task_model.TaskTemplate: @@ -67,13 +67,13 @@ def execute(self, **kwargs) -> Any: """ return self.executor.execute_from_model(self.task_template, **kwargs) - def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: + def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. """ return user_params - def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: + def post_execute(self, _: Optional[ExecutionParameters], rval: Any) -> Any: """ This function is a stub, just here to keep dispatch_execute compatibility between this class and PythonTask. """ @@ -92,7 +92,9 @@ def dispatch_execute( # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( - ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params)) + ctx.with_execution_state( + cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params) + ) ) as exec_ctx: # Added: Have to reverse the Python interface from the task template Flyte interface # See docstring for more details. diff --git a/flytekit/core/task.py b/flytekit/core/task.py index 28c5b5def7..b107aafe12 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -89,7 +89,7 @@ def task( requests: Optional[Resources] = None, limits: Optional[Resources] = None, secret_requests: Optional[List[Secret]] = None, - execution_mode: Optional[PythonFunctionTask.ExecutionBehavior] = PythonFunctionTask.ExecutionBehavior.DEFAULT, + execution_mode: PythonFunctionTask.ExecutionBehavior = PythonFunctionTask.ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, docs: Optional[Documentation] = None, disable_deck: bool = True, @@ -225,7 +225,7 @@ def wrapper(fn) -> PythonFunctionTask: return wrapper -class ReferenceTask(ReferenceEntity, PythonFunctionTask): +class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore """ This is a reference task, the body of the function passed in through the constructor will never be used, only the signature of the function will be. The signature should also match the signature of the task you're referencing, @@ -233,7 +233,7 @@ class ReferenceTask(ReferenceEntity, PythonFunctionTask): """ def __init__( - self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type] + self, project: str, domain: str, name: str, version: str, inputs: Dict[str, type], outputs: Dict[str, Type] ): super().__init__(TaskReference(project, domain, name, version), inputs, outputs) diff --git a/flytekit/core/testing.py b/flytekit/core/testing.py index 772a4b6df6..f1a0fec7de 100644 --- a/flytekit/core/testing.py +++ b/flytekit/core/testing.py @@ -1,3 +1,4 @@ +import typing from contextlib import contextmanager from typing import Union from unittest.mock import MagicMock @@ -9,7 +10,7 @@ @contextmanager -def task_mock(t: PythonTask) -> MagicMock: +def task_mock(t: PythonTask) -> typing.Generator[MagicMock, None, None]: """ Use this method to mock a task declaration. It can mock any Task in Flytekit as long as it has a python native interface associated with it. @@ -41,9 +42,9 @@ def _log(*args, **kwargs): return m(*args, **kwargs) _captured_fn = t.execute - t.execute = _log + t.execute = _log # type: ignore yield m - t.execute = _captured_fn + t.execute = _captured_fn # type: ignore def patch(target: Union[PythonTask, WorkflowBase, ReferenceEntity]): diff --git a/flytekit/core/tracked_abc.py b/flytekit/core/tracked_abc.py index bad4f8c555..3c39d3725c 100644 --- a/flytekit/core/tracked_abc.py +++ b/flytekit/core/tracked_abc.py @@ -3,7 +3,7 @@ from flytekit.core.tracker import TrackedInstance -class FlyteTrackedABC(type(TrackedInstance), type(ABC)): +class FlyteTrackedABC(type(TrackedInstance), type(ABC)): # type: ignore """ This class exists because if you try to inherit from abc.ABC and TrackedInstance by itself, you'll get the well-known ``TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 2a203d4861..23ff7c9222 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -179,7 +179,7 @@ class _ModuleSanitizer(object): def __init__(self): self._module_cache = {} - def _resolve_abs_module_name(self, path: str, package_root: str) -> str: + def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str] = None) -> str: """ Recursively finds the root python package under-which basename exists """ diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 23b83abdc3..6ad8cebc3b 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -117,7 +117,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp raise NotImplementedError(f"Conversion to Literal for python type {python_type} not implemented") @abstractmethod - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: """ Converts the given Literal to a Python Type. If the conversion cannot be done an AssertionError should be raised :param ctx: FlyteContext @@ -161,7 +161,7 @@ def __init__( self._to_literal_transformer = to_literal_transformer self._from_literal_transformer = from_literal_transformer - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: return LiteralType.from_flyte_idl(self._lt.to_flyte_idl()) def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: @@ -206,7 +206,7 @@ class RestrictedTypeTransformer(TypeTransformer[T], ABC): def __init__(self, name: str, t: Type[T]): super().__init__(name, t) - def get_literal_type(self, t: Type[T] = None) -> LiteralType: + def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently") def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: @@ -367,11 +367,13 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A return None return self._serialize_flyte_type(python_val, get_args(python_type)[0]) - if hasattr(python_type, "__origin__") and python_type.__origin__ is list: - return [self._serialize_flyte_type(v, python_type.__args__[0]) for v in python_val] + if hasattr(python_type, "__origin__") and get_origin(python_type) is list: + return [self._serialize_flyte_type(v, get_args(python_type)[0]) for v in cast(list, python_val)] - if hasattr(python_type, "__origin__") and python_type.__origin__ is dict: - return {k: self._serialize_flyte_type(v, python_type.__args__[1]) for k, v in python_val.items()} + if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: + return { + k: self._serialize_flyte_type(v, get_args(python_type)[1]) for k, v in cast(dict, python_val).items() + } if not dataclasses.is_dataclass(python_type): return python_val @@ -431,7 +433,13 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> t = FlyteSchemaTransformer() return t.to_python_value( FlyteContext.current_context(), - Literal(scalar=Scalar(schema=Schema(python_val.remote_path, t._get_schema_type(expected_python_type)))), + Literal( + scalar=Scalar( + schema=Schema( + cast(FlyteSchema, python_val).remote_path, t._get_schema_type(expected_python_type) + ) + ) + ), expected_python_type, ) elif issubclass(expected_python_type, FlyteFile): @@ -445,7 +453,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE ) ), - uri=python_val.path, + uri=cast(FlyteFile, python_val).path, ) ) ), @@ -462,7 +470,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART ) ), - uri=python_val.path, + uri=cast(FlyteDirectory, python_val).path, ) ) ), @@ -475,9 +483,11 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) -> scalar=Scalar( structured_dataset=StructuredDataset( metadata=StructuredDatasetMetadata( - structured_dataset_type=StructuredDatasetType(format=python_val.file_format) + structured_dataset_type=StructuredDatasetType( + format=cast(StructuredDataset, python_val).file_format + ) ), - uri=python_val.uri, + uri=cast(StructuredDataset, python_val).uri, ) ) ), @@ -516,7 +526,9 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: if isinstance(val, dict): ktype, vtype = DictTransformer.get_dict_types(t) # Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}}) - return {self._fix_val_int(ktype, k): self._fix_val_int(vtype, v) for k, v in val.items()} + return { + self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items() + } if dataclasses.is_dataclass(t): return self._fix_dataclass_int(t, val) # type: ignore @@ -557,7 +569,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: # calls to guess_python_type would result in a logically equivalent (but new) dataclass, which # TypeEngine.assert_type would not be happy about. @lru_cache(typed=True) - def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore if literal_type.simple == SimpleType.STRUCT: if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata: schema_name = literal_type.metadata["$ref"].split("/")[-1] @@ -582,7 +594,7 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: struct = Struct() try: - struct.update(_MessageToDict(python_val)) + struct.update(_MessageToDict(cast(Message, python_val))) except Exception: raise TypeTransformerFailedError("Failed to convert to generic protobuf struct") return Literal(scalar=Scalar(generic=struct)) @@ -593,7 +605,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: pb_obj = expected_python_type() dictionary = _MessageToDict(lv.scalar.generic) - pb_obj = _ParseDict(dictionary, pb_obj) + pb_obj = _ParseDict(dictionary, pb_obj) # type: ignore return pb_obj def guess_python_type(self, literal_type: LiteralType) -> Type[T]: @@ -616,7 +628,7 @@ class TypeEngine(typing.Generic[T]): _REGISTRY: typing.Dict[type, TypeTransformer[T]] = {} _RESTRICTED_TYPES: typing.List[type] = [] - _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() + _DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore @classmethod def register( @@ -641,10 +653,10 @@ def register( def register_restricted_type( cls, name: str, - type: Type, + type: Type[T], ): cls._RESTRICTED_TYPES.append(type) - cls.register(RestrictedTypeTransformer(name, type)) + cls.register(RestrictedTypeTransformer(name, type)) # type: ignore @classmethod def register_additional_type(cls, transformer: TypeTransformer, additional_type: Type, override=False): @@ -901,8 +913,8 @@ def get_sub_type(t: Type[T]) -> Type[T]: if get_origin(t) is Annotated: return ListTransformer.get_sub_type(get_args(t)[0]) - if t.__origin__ is list and hasattr(t, "__args__"): - return t.__args__[0] + if getattr(t, "__origin__") is list and hasattr(t, "__args__"): + return getattr(t, "__args__")[0] raise ValueError("Only generic univariate typing.List[T] type is supported.") @@ -924,7 +936,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore return Literal(collection=LiteralCollection(literals=lit_list)) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[T]: + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore try: lits = lv.collection.literals except AttributeError: @@ -933,10 +945,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: st = self.get_sub_type(expected_python_type) return [TypeEngine.to_python_value(ctx, x, st) for x in lits] - def guess_python_type(self, literal_type: LiteralType) -> Type[list]: + def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: - ct = TypeEngine.guess_python_type(literal_type.collection_type) - return typing.List[ct] + ct: Type = TypeEngine.guess_python_type(literal_type.collection_type) + return typing.List[ct] # type: ignore raise ValueError(f"List transformer cannot reverse {literal_type}") @@ -1049,7 +1061,9 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: t = get_args(t)[0] try: - trans = [(TypeEngine.get_transformer(x), x) for x in get_args(t)] + trans: typing.List[typing.Tuple[TypeTransformer, typing.Any]] = [ + (TypeEngine.get_transformer(x), x) for x in get_args(t) + ] # must go through TypeEngine.to_literal_type instead of trans.get_literal_type # to handle Annotated variants = [_add_tag_to_type(TypeEngine.to_literal_type(x), t.name) for (t, x) in trans] @@ -1066,7 +1080,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp res_type = None for t in get_args(python_type): try: - trans = TypeEngine.get_transformer(t) + trans: TypeTransformer[T] = TypeEngine.get_transformer(t) res = trans.to_literal(ctx, python_val, t, expected) res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) @@ -1099,7 +1113,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: res_tag = None for v in get_args(expected_python_type): try: - trans = TypeEngine.get_transformer(v) + trans: TypeTransformer[T] = TypeEngine.get_transformer(v) if union_tag is not None: if trans.name != union_tag: continue @@ -1138,7 +1152,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: def guess_python_type(self, literal_type: LiteralType) -> type: if literal_type.union_type is not None: - return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] + return typing.Union[tuple(TypeEngine.guess_python_type(v) for v in literal_type.union_type.variants)] # type: ignore raise ValueError(f"Union transformer cannot reverse {literal_type}") @@ -1185,7 +1199,7 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: if tp: if tp[0] == str: try: - sub_type = TypeEngine.to_literal_type(tp[1]) + sub_type = TypeEngine.to_literal_type(cast(type, tp[1])) return _type_models.LiteralType(map_value_type=sub_type) except Exception as e: raise ValueError(f"Type of Generic List type is not supported, {e}") @@ -1206,7 +1220,7 @@ def to_literal( raise ValueError("Flyte MapType expects all keys to be strings") # TODO: log a warning for Annotated objects that contain HashMethod k_type, v_type = self.get_dict_types(python_type) - lit_map[k] = TypeEngine.to_literal(ctx, v, v_type, expected.map_value_type) + lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type) return Literal(map=LiteralMap(literals=lit_map)) def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: @@ -1222,7 +1236,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key") py_map = {} for k, v in lv.map.literals.items(): - py_map[k] = TypeEngine.to_python_value(ctx, v, tp[1]) + py_map[k] = TypeEngine.to_python_value(ctx, v, cast(Type, tp[1])) return py_map # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict @@ -1260,10 +1274,8 @@ def _blob_type(self) -> _core_types.BlobType: dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) - def get_literal_type(self, t: typing.TextIO) -> LiteralType: - return _type_models.LiteralType( - blob=self._blob_type(), - ) + def get_literal_type(self, t: typing.TextIO) -> LiteralType: # type: ignore + return _type_models.LiteralType(blob=self._blob_type()) def to_literal( self, ctx: FlyteContext, python_val: typing.TextIO, python_type: Type[typing.TextIO], expected: LiteralType @@ -1334,7 +1346,9 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: raise TypeTransformerFailedError("Only EnumTypes with value of string are supported") return LiteralType(enum_type=_core_types.EnumType(values=values)) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + def to_literal( + self, ctx: FlyteContext, python_val: enum.Enum, python_type: Type[T], expected: LiteralType + ) -> Literal: if type(python_val).__class__ != enum.EnumMeta: raise TypeTransformerFailedError("Expected an enum") if type(python_val.value) != str: @@ -1343,11 +1357,12 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: - return expected_python_type(lv.scalar.primitive.string_value) + return expected_python_type(lv.scalar.primitive.string_value) # type: ignore -def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: - """Generate a model class based on the provided JSON Schema +def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema :param schema: dict representing valid JSON schema :param schema_name: dataclass name of return type """ @@ -1356,7 +1371,7 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac property_type = property_val["type"] # Handle list if property_val["type"] == "array": - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore # Handle dataclass and dict elif property_type == "object": if property_val.get("$ref"): @@ -1364,13 +1379,13 @@ def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[datac attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name))) elif property_val.get("additionalProperties"): attribute_list.append( - (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore ) else: - attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) + attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) # type: ignore # Handle int, float, bool or str else: - attribute_list.append([property_key, _get_element_type(property_val)]) + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) @@ -1544,8 +1559,8 @@ def __init__( raise ValueError("Cannot instantiate LiteralsResolver without a map of Literals.") self._literals = literals self._variable_map = variable_map - self._native_values = {} - self._type_hints = {} + self._native_values: Dict[str, type] = {} + self._type_hints: Dict[str, type] = {} self._ctx = ctx def __str__(self) -> str: @@ -1598,7 +1613,7 @@ def __getitem__(self, key: str): return self.get(key) - def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: + def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: # type: ignore """ This will get the ``attr`` value from the Literal map, and invoke the TypeEngine to convert it into a Python native value. A Python type can optionally be supplied. If successful, the native value will be cached and @@ -1625,7 +1640,9 @@ def get(self, attr: str, as_type: Optional[typing.Type] = None) -> typing.Any: raise e else: ValueError("as_type argument not supplied and Variable map not specified in LiteralsResolver") - val = TypeEngine.to_python_value(self._ctx or FlyteContext.current_context(), self._literals[attr], as_type) + val = TypeEngine.to_python_value( + self._ctx or FlyteContext.current_context(), self._literals[attr], cast(Type, as_type) + ) self._native_values[attr] = val return val diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py index ae8b89a109..ee2c841465 100644 --- a/flytekit/core/utils.py +++ b/flytekit/core/utils.py @@ -51,7 +51,7 @@ def _dnsify(value: str) -> str: def _get_container_definition( image: str, command: List[str], - args: List[str], + args: Optional[List[str]] = None, data_loading_config: Optional[task_models.DataLoadingConfig] = None, storage_request: Optional[str] = None, ephemeral_storage_request: Optional[str] = None, diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 6eac2e2a3c..f8ba257d7e 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -177,9 +177,9 @@ def __init__( self._workflow_metadata_defaults = workflow_metadata_defaults self._python_interface = python_interface self._interface = transform_interface_to_typed_interface(python_interface) - self._inputs = {} - self._unbound_inputs = set() - self._nodes = [] + self._inputs: Dict[str, Promise] = {} + self._unbound_inputs: set = set() + self._nodes: List[Node] = [] self._output_bindings: List[_literal_models.Binding] = [] self._docs = docs @@ -191,7 +191,9 @@ def __init__( ) else: if self._python_interface.docstring.short_description: - self._docs.short_description = self._python_interface.docstring.short_description + cast( + Documentation, self._docs + ).short_description = self._python_interface.docstring.short_description if self._python_interface.docstring.long_description: self._docs = Description(value=self._python_interface.docstring.long_description) @@ -211,11 +213,11 @@ def short_name(self) -> str: return extract_obj_name(self._name) @property - def workflow_metadata(self) -> Optional[WorkflowMetadata]: + def workflow_metadata(self) -> WorkflowMetadata: return self._workflow_metadata @property - def workflow_metadata_defaults(self): + def workflow_metadata_defaults(self) -> WorkflowMetadataDefaults: return self._workflow_metadata_defaults @property @@ -248,7 +250,7 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: interruptible=self.workflow_metadata_defaults.interruptible, ) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: """ Workflow needs to fill in default arguments before invoking the call handler. """ @@ -412,7 +414,7 @@ def execute(self, **kwargs): raise FlyteValidationException(f"Workflow not ready, wf is currently {self}") # Create a map that holds the outputs of each node. - intermediate_node_outputs = {GLOBAL_START_NODE: {}} # type: Dict[Node, Dict[str, Promise]] + intermediate_node_outputs: Dict[Node, Dict[str, Promise]] = {GLOBAL_START_NODE: {}} # Start things off with the outputs of the global input node, i.e. the inputs to the workflow. # local_execute should've already ensured that all the values in kwargs are Promise objects @@ -509,7 +511,7 @@ def get_input_values(input_value): self._unbound_inputs.remove(input_value) return n # type: ignore - def add_workflow_input(self, input_name: str, python_type: Type) -> Interface: + def add_workflow_input(self, input_name: str, python_type: Type) -> Promise: """ Adds an input to the workflow. """ @@ -536,7 +538,8 @@ def add_workflow_output( f"If specifying a list or dict of Promises, you must specify the python_type type for {output_name}" f" starting with the container type (e.g. List[int]" ) - python_type = p.ref.node.flyte_entity.python_interface.outputs[p.var] + promise = cast(Promise, p) + python_type = promise.ref.node.flyte_entity.python_interface.outputs[promise.var] logger.debug(f"Inferring python type for wf output {output_name} from Promise provided {python_type}") flyte_type = TypeEngine.to_literal_type(python_type=python_type) @@ -589,8 +592,8 @@ class PythonFunctionWorkflow(WorkflowBase, ClassStorageTaskResolver): def __init__( self, workflow_function: Callable, - metadata: Optional[WorkflowMetadata], - default_metadata: Optional[WorkflowMetadataDefaults], + metadata: WorkflowMetadata, + default_metadata: WorkflowMetadataDefaults, docstring: Optional[Docstring] = None, docs: Optional[Documentation] = None, ): @@ -614,7 +617,7 @@ def __init__( def function(self): return self._workflow_function - def task_name(self, t: PythonAutoContainerTask) -> str: + def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore return f"{self.name}.{t.__module__}.{t.name}" def compile(self, **kwargs): @@ -763,10 +766,10 @@ def wrapper(fn): if _workflow_function: return wrapper(_workflow_function) else: - return wrapper + return wrapper # type: ignore -class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): +class ReferenceWorkflow(ReferenceEntity, PythonFunctionWorkflow): # type: ignore """ A reference workflow is a pointer to a workflow that already exists on your Flyte installation. This object will not initiate a network call to Admin, which is why the user is asked to provide the expected interface. diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 4f06c3d3c6..e0a864e31e 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -628,7 +628,7 @@ def uri(self) -> str: return self._uri @property - def metadata(self) -> StructuredDatasetMetadata: + def metadata(self) -> Optional[StructuredDatasetMetadata]: return self._metadata def to_flyte_idl(self) -> _literals_pb2.StructuredDataset: diff --git a/flytekit/types/directory/__init__.py b/flytekit/types/directory/__init__.py index c2ab8fd438..87b494d0ae 100644 --- a/flytekit/types/directory/__init__.py +++ b/flytekit/types/directory/__init__.py @@ -28,7 +28,7 @@ TensorBoard. """ -tfrecords_dir = typing.TypeVar("tfrecord") +tfrecords_dir = typing.TypeVar("tfrecords_dir") TFRecordsDirectory = FlyteDirectory[tfrecords_dir] """ This type can be used to denote that the output is a folder that contains tensorflow record files. diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index afb59d58d0..7d576f9353 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -115,7 +115,12 @@ def t1(in1: FlyteDirectory["svg"]): field in the ``BlobType``. """ - def __init__(self, path: typing.Union[str, os.PathLike], downloader: typing.Callable = None, remote_directory=None): + def __init__( + self, + path: typing.Union[str, os.PathLike], + downloader: typing.Optional[typing.Callable] = None, + remote_directory: typing.Optional[str] = None, + ): """ :param path: The source path that users are expected to call open() on :param downloader: Optional function that can be passed that used to delay downloading of the actual fil diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 9fc55f76ce..6537f85cae 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -346,13 +346,13 @@ def to_python_value( return FlyteFile(uri) # The rest of the logic is only for FlyteFile types. - if not issubclass(expected_python_type, FlyteFile): + if not issubclass(expected_python_type, FlyteFile): # type: ignore raise TypeError(f"Neither os.PathLike nor FlyteFile specified {expected_python_type}") # This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't # make any sense. if not ctx.file_access.is_remote(uri): - return expected_python_type(uri) + return expected_python_type(uri) # type: ignore # For the remote case, return an FlyteFile object that can download local_path = ctx.file_access.get_random_local_path(uri) diff --git a/flytekit/types/numpy/ndarray.py b/flytekit/types/numpy/ndarray.py index 38fedfacca..d766818bfd 100644 --- a/flytekit/types/numpy/ndarray.py +++ b/flytekit/types/numpy/ndarray.py @@ -77,7 +77,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return np.load( file=local_path, allow_pickle=metadata.get("allow_pickle", False), - mmap_mode=metadata.get("mmap_mode"), + mmap_mode=metadata.get("mmap_mode"), # type: ignore ) def guess_python_type(self, literal_type: LiteralType) -> typing.Type[np.ndarray]: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 8a8d832b58..c380bcc481 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -100,32 +100,38 @@ def write(self, *dfs, **kwargs): class LocalIOSchemaReader(SchemaReader[T]): - def __init__(self, from_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(from_path), cols, fmt) + def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + super().__init__(from_path, cols, fmt) @abstractmethod def _read(self, *path: os.PathLike, **kwargs) -> T: pass def iter(self, **kwargs) -> typing.Generator[T, None, None]: - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - yield self._read(Path(entry.path), **kwargs) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + yield self._read(Path(typing.cast(os.DirEntry, entry).path), **kwargs) def all(self, **kwargs) -> T: files: typing.List[os.PathLike] = [] - with os.scandir(self._from_path) as it: + with os.scandir(self._from_path) as it: # type: ignore for entry in it: - if not entry.name.startswith(".") and entry.is_file(): - files.append(Path(entry.path)) + if ( + not typing.cast(os.DirEntry, entry).name.startswith(".") + and typing.cast(os.DirEntry, entry).is_file() + ): + files.append(Path(typing.cast(os.DirEntry, entry).path)) return self._read(*files, **kwargs) class LocalIOSchemaWriter(SchemaWriter[T]): - def __init__(self, to_local_path: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(str(to_local_path), cols, fmt) + def __init__(self, to_local_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + super().__init__(to_local_path, cols, fmt) @abstractmethod def _write(self, df: T, path: os.PathLike, **kwargs): @@ -176,7 +182,7 @@ def get_handler(cls, t: Type) -> SchemaHandler: @dataclass_json @dataclass class FlyteSchema(object): - remote_path: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) """ This is the main schema class that users should use. """ @@ -229,10 +235,10 @@ def format(cls) -> SchemaFormat: def __init__( self, - local_path: os.PathLike = None, - remote_path: os.PathLike = None, + local_path: typing.Optional[str] = None, + remote_path: typing.Optional[str] = None, supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, - downloader: typing.Callable[[str, os.PathLike], None] = None, + downloader: typing.Optional[typing.Callable] = None, ): if supported_mode == SchemaOpenMode.READ and remote_path is None: raise ValueError("To create a FlyteSchema in read mode, remote_path is required") @@ -254,7 +260,7 @@ def __init__( self._downloader = downloader @property - def local_path(self) -> os.PathLike: + def local_path(self) -> str: return self._local_path @property @@ -262,7 +268,7 @@ def supported_mode(self) -> SchemaOpenMode: return self._supported_mode def open( - self, dataframe_fmt: type = pandas.DataFrame, override_mode: SchemaOpenMode = None + self, dataframe_fmt: type = pandas.DataFrame, override_mode: typing.Optional[SchemaOpenMode] = None ) -> typing.Union[SchemaReader, SchemaWriter]: """ Returns a reader or writer depending on the mode of the object when created. This mode can be @@ -290,13 +296,13 @@ def open( self._downloader(self.remote_path, self.local_path) self._downloaded = True if mode == SchemaOpenMode.WRITE: - return h.writer(typing.cast(str, self.local_path), self.columns(), self.format()) - return h.reader(typing.cast(str, self.local_path), self.columns(), self.format()) + return h.writer(self.local_path, self.columns(), self.format()) + return h.reader(self.local_path, self.columns(), self.format()) # Remote IO is handled. So we will just pass the remote reference to the object if mode == SchemaOpenMode.WRITE: - return h.writer(self.remote_path, self.columns(), self.format()) - return h.reader(self.remote_path, self.columns(), self.format()) + return h.writer(typing.cast(str, self.remote_path), self.columns(), self.format()) + return h.reader(typing.cast(str, self.remote_path), self.columns(), self.format()) def as_readonly(self) -> FlyteSchema: if self._supported_mode == SchemaOpenMode.READ: @@ -304,7 +310,7 @@ def as_readonly(self) -> FlyteSchema: s = FlyteSchema.__class_getitem__(self.columns(), self.format())( local_path=self.local_path, # Dummy path is ok, as we will assume data is already downloaded and will not download again - remote_path=self.remote_path if self.remote_path else "", + remote_path=typing.cast(str, self.remote_path) if self.remote_path else "", supported_mode=SchemaOpenMode.READ, ) s._downloaded = True diff --git a/flytekit/types/schema/types_pandas.py b/flytekit/types/schema/types_pandas.py index e4c6078e94..ca6cab8030 100644 --- a/flytekit/types/schema/types_pandas.py +++ b/flytekit/types/schema/types_pandas.py @@ -17,7 +17,9 @@ class ParquetIO(object): def _read(self, chunk: os.PathLike, columns: typing.Optional[typing.List[str]], **kwargs) -> pandas.DataFrame: return pandas.read_parquet(chunk, columns=columns, engine=self.PARQUET_ENGINE, **kwargs) - def read(self, *files: os.PathLike, columns: typing.List[str] = None, **kwargs) -> pandas.DataFrame: + def read( + self, *files: os.PathLike, columns: typing.Optional[typing.List[str]] = None, **kwargs + ) -> pandas.DataFrame: frames = [self._read(chunk=f, columns=columns, **kwargs) for f in files if os.path.getsize(f) > 0] if len(frames) == 1: return frames[0] @@ -56,7 +58,7 @@ def write( class PandasSchemaReader(LocalIOSchemaReader[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() @@ -65,7 +67,7 @@ def _read(self, *path: os.PathLike, **kwargs) -> pandas.DataFrame: class PandasSchemaWriter(LocalIOSchemaWriter[pandas.DataFrame]): - def __init__(self, local_dir: os.PathLike, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): + def __init__(self, local_dir: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): super().__init__(local_dir, cols, fmt) self._parquet_engine = ParquetIO() diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 0e4649203a..90755c8cc5 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import os import types import typing from abc import ABC, abstractmethod @@ -45,7 +44,7 @@ class StructuredDataset(object): class (that is just a model, a Python class representation of the protobuf). """ - uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) + uri: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String())) @classmethod @@ -59,7 +58,7 @@ def column_names(cls) -> typing.List[str]: def __init__( self, dataframe: typing.Optional[typing.Any] = None, - uri: Optional[str, os.PathLike] = None, + uri: typing.Optional[str] = None, metadata: typing.Optional[literals.StructuredDatasetMetadata] = None, **kwargs, ): @@ -74,10 +73,10 @@ def __init__( # This is not for users to set, the transformer will set this. self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call - self._dataframe_type: Optional[Type[DF]] = None + self._dataframe_type: Optional[DF] = None # type: ignore @property - def dataframe(self) -> Optional[Type[DF]]: + def dataframe(self) -> Optional[DF]: return self._dataframe @property @@ -92,7 +91,7 @@ def open(self, dataframe_type: Type[DF]): self._dataframe_type = dataframe_type return self - def all(self) -> DF: + def all(self) -> DF: # type: ignore if self._dataframe_type is None: raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.") ctx = FlyteContextManager.current_context() @@ -255,7 +254,7 @@ def decode( ctx: FlyteContext, flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, - ) -> Union[DF, Generator[DF, None, None]]: + ) -> Union[DF, typing.Iterator[DF]]: """ This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal value into a Python instance. @@ -617,7 +616,7 @@ def encode( # least as good as the type of the interface. if sd_model.metadata is None: sd_model._metadata = StructuredDatasetMetadata(structured_literal_type) - if sd_model.metadata.structured_dataset_type is None: + if sd_model.metadata and sd_model.metadata.structured_dataset_type is None: sd_model.metadata._structured_dataset_type = structured_literal_type # Always set the format here to the format of the handler. # Note that this will always be the same as the incoming format except for when the fallback handler @@ -747,7 +746,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ # Here we only render column information by default instead of opening the structured dataset. col = typing.cast(StructuredDataset, python_val).columns() df = pd.DataFrame(col, ["column type"]) - return df.to_html() + return df.to_html() # type: ignore else: df = python_val @@ -783,10 +782,10 @@ def iter_as( sd: literals.StructuredDataset, df_type: Type[DF], updated_metadata: StructuredDatasetMetadata, - ) -> Generator[DF, None, None]: + ) -> typing.Iterator[DF]: protocol = protocol_prefix(sd.uri) decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format] - result = decoder.decode(ctx, sd, updated_metadata) + result: Union[DF, typing.Iterator[DF]] = decoder.decode(ctx, sd, updated_metadata) if not isinstance(result, types.GeneratorType): raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}") return result @@ -801,7 +800,7 @@ def _get_dataset_column_literal_type(self, t: Type) -> type_models.LiteralType: raise AssertionError(f"type {t} is currently not supported by StructuredDataset") def _convert_ordered_dict_of_columns_to_list( - self, column_map: typing.OrderedDict[str, Type] + self, column_map: typing.Optional[typing.OrderedDict[str, Type]] ) -> typing.List[StructuredDatasetType.DatasetColumn]: converted_cols: typing.List[StructuredDatasetType.DatasetColumn] = [] if column_map is None or len(column_map) == 0: @@ -812,10 +811,13 @@ def _convert_ordered_dict_of_columns_to_list( return converted_cols def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> StructuredDatasetType: - original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) + original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t) # type: ignore # Get the column information - converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map) + converted_cols: typing.List[ + StructuredDatasetType.DatasetColumn + ] = self._convert_ordered_dict_of_columns_to_list(column_map) + return StructuredDatasetType( columns=converted_cols, format=storage_format, diff --git a/tests/flytekit/unit/core/flyte_functools/decorator_source.py b/tests/flytekit/unit/core/flyte_functools/decorator_source.py index 9c92364649..5790d5d358 100644 --- a/tests/flytekit/unit/core/flyte_functools/decorator_source.py +++ b/tests/flytekit/unit/core/flyte_functools/decorator_source.py @@ -1,10 +1,11 @@ """Script used for testing local execution of functool.wraps-wrapped tasks for stacked decorators""" - +import functools +import typing from functools import wraps from typing import List -def task_setup(function: callable = None, *, integration_requests: List = None) -> None: +def task_setup(function: typing.Callable, *, integration_requests: typing.Optional[List] = None) -> typing.Callable: integration_requests = integration_requests or [] @wraps(function) diff --git a/tests/flytekit/unit/core/flyte_functools/nested_function.py b/tests/flytekit/unit/core/flyte_functools/nested_function.py index 6a3ccfd9e1..98a39e497a 100644 --- a/tests/flytekit/unit/core/flyte_functools/nested_function.py +++ b/tests/flytekit/unit/core/flyte_functools/nested_function.py @@ -32,4 +32,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py index a51a283be5..3278af1bb0 100644 --- a/tests/flytekit/unit/core/flyte_functools/simple_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/simple_decorator.py @@ -38,4 +38,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py index 07c46cd46a..dd445a6fb3 100644 --- a/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py +++ b/tests/flytekit/unit/core/flyte_functools/stacked_decorators.py @@ -48,4 +48,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py index 9f7e6599c6..6e22ca9840 100644 --- a/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py +++ b/tests/flytekit/unit/core/flyte_functools/unwrapped_decorator.py @@ -26,4 +26,4 @@ def my_workflow(x: int) -> int: if __name__ == "__main__": - print(my_workflow(x=int(os.getenv("SCRIPT_INPUT")))) + print(my_workflow(x=int(os.getenv("SCRIPT_INPUT", 0)))) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 3963c77c8d..8eb105777e 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -35,14 +35,12 @@ def my_wf(a: int, b: str) -> (int, str, str): def test_single_named_output_subwf(): - nt = NamedTuple("SubWfOutput", sub_int=int) + nt = NamedTuple("SubWfOutput", [("sub_int", int)]) @task def t1(a: int) -> nt: a = a + 2 - return nt( - a, - ) # returns a named tuple + return nt(a) @task def t2(a: int, b: int) -> nt: diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index ca234c743b..3ab0026cf3 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -299,7 +299,7 @@ def branching(x: int): def test_subworkflow_condition_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int, c=str) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int), ("c", str)]) @task def t() -> nt: @@ -318,13 +318,11 @@ def branching(x: int) -> nt: def test_subworkflow_condition_single_named_tuple(): - nt = typing.NamedTuple("SampleNamedTuple", b=int) + nt = typing.NamedTuple("SampleNamedTuple", [("b", int)]) @task def t() -> nt: - return nt( - 5, - ) + return nt(5) @workflow def wf1() -> nt: diff --git a/tests/flytekit/unit/core/test_gate.py b/tests/flytekit/unit/core/test_gate.py index c92e1c9e19..bb245ad594 100644 --- a/tests/flytekit/unit/core/test_gate.py +++ b/tests/flytekit/unit/core/test_gate.py @@ -219,7 +219,7 @@ def wf_dyn(a: int) -> typing.Tuple[int, int]: def test_subwf(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def nt1(a: int) -> nt: diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index db4b32f6a9..ead5358316 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -67,15 +67,13 @@ def t2(): assert len(wf_spec.template.interface.outputs) == 1 # docs_equivalent_start - nt = typing.NamedTuple("wf_output", from_n0t1=str) + nt = typing.NamedTuple("wf_output", [("from_n0t1", str)]) @workflow def my_workflow(in1: str) -> nt: x = t1(a=in1) t2() - return nt( - x, - ) + return nt(x) # docs_equivalent_end diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index db05de0ddb..26b43f2ef5 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -102,7 +102,7 @@ def x(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): return ("hello world", 5) def y(a: int, b: str) -> nt1: - return nt1("hello world", 5) + return nt1("hello world", 5) # type: ignore result = transform_variable_map(extract_return_annotation(typing.get_type_hints(x).get("return", None))) assert result["x_str"].type.simple == 3 diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index ffaff8daad..3addd13e42 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -292,7 +292,7 @@ def wf(a: int, c: str) -> (int, str): def test_lp_all_parameters(): - nt = typing.NamedTuple("OutputsBC", t1_int_output=int, c=str) + nt = typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]) @task def t1(a: int) -> nt: diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 2813563fb9..0858a08007 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -103,14 +103,12 @@ def test_more_normal_task(): @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") # type: ignore @task def t1_nt(a: int) -> nt: # This one returns an instance of the named tuple. - return nt(f"{a + 2}") + return nt(f"{a + 2}") # type: ignore @task def t2(a: typing.List[str]) -> str: @@ -133,9 +131,7 @@ def test_reserved_keyword(): @task def t1(a: int) -> nt: # This one returns a regular tuple - return nt( - f"{a + 2}", - ) + return nt(f"{a + 2}") # type: ignore # Test that you can't name an output "outputs" with pytest.raises(FlyteAssertion): diff --git a/tests/flytekit/unit/core/test_python_function_task.py b/tests/flytekit/unit/core/test_python_function_task.py index 02e04a302f..7bbdd23a21 100644 --- a/tests/flytekit/unit/core/test_python_function_task.py +++ b/tests/flytekit/unit/core/test_python_function_task.py @@ -145,7 +145,7 @@ def test_pod_template(): pod_template_name="A", ) def func_with_pod_template(i: str): - print(i + 3) + print(i + "a") default_image = Image(name="default", fqn="docker.io/xyz", tag="some-git-hash") default_image_config = ImageConfig(default_image=default_image) diff --git a/tests/flytekit/unit/core/test_realworld_examples.py b/tests/flytekit/unit/core/test_realworld_examples.py index 83e859c1da..779ba3334c 100644 --- a/tests/flytekit/unit/core/test_realworld_examples.py +++ b/tests/flytekit/unit/core/test_realworld_examples.py @@ -126,7 +126,7 @@ def fit(x: FlyteSchema[FEATURE_COLUMNS], y: FlyteSchema[CLASSES_COLUMNS], hyperp fname = "model.joblib.dat" with open(fname, "w") as f: f.write("Some binary data") - return nt(model=fname) + return nt(model=fname) # type: ignore @task(cache_version="1.0", cache=True, limits=Resources(mem="200Mi")) def predict(x: FlyteSchema[FEATURE_COLUMNS], model_ser: FlyteFile[MODELSER_JOBLIB]) -> FlyteSchema[CLASSES_COLUMNS]: diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index df6e093b55..7486422fd9 100644 --- a/tests/flytekit/unit/core/test_references.py +++ b/tests/flytekit/unit/core/test_references.py @@ -160,7 +160,7 @@ def inner_test(ref_mock): @task def t1(a: int) -> nt1: a = a + 2 - return nt1(a, "world-" + str(a)) + return nt1(a, "world-" + str(a)) # type: ignore @workflow def wf2(a: int): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index a96a94843b..8deb406fb2 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -406,17 +406,17 @@ def wf() -> typing.NamedTuple("OP", a=str): def test_named_outputs_nested(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello().greet, say_hello().greet) x, y = my_wf() @@ -425,19 +425,19 @@ def my_wf() -> wf_outputs: def test_named_outputs_nested_fail(): - nm = typing.NamedTuple("OP", greet=str) + nm = typing.NamedTuple("OP", [("greet", str)]) @task def say_hello() -> nm: return nm("hello world") - wf_outputs = typing.NamedTuple("OP2", greet1=str, greet2=str) + wf_outputs = typing.NamedTuple("OP2", [("greet1", str), ("greet2", str)]) with pytest.raises(AssertionError): # this should fail because say_hello returns a tuple, but we do not de-reference it @workflow def my_wf() -> wf_outputs: - # Note only Namedtuples can be created like this + # Note only Namedtuple can be created like this return wf_outputs(say_hello(), say_hello()) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index eb38a8d80b..f0fce6bc17 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1471,21 +1471,21 @@ def test_multiple_annotations(): TypeEngine.to_literal_type(t) -TestSchema = FlyteSchema[kwtypes(some_str=str)] +TestSchema = FlyteSchema[kwtypes(some_str=str)] # type: ignore @dataclass_json @dataclass class InnerResult: number: int - schema: TestSchema + schema: TestSchema # type: ignore @dataclass_json @dataclass class Result: result: InnerResult - schema: TestSchema + schema: TestSchema # type: ignore def test_schema_in_dataclass(): diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index b6d2d77ae5..f41a05ea32 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -184,13 +184,11 @@ def my_wf(a: int, b: str) -> (int, str): assert my_wf._output_bindings[0].var == "o0" assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output" - nt = typing.NamedTuple("SingleNT", t1_int_output=float) + nt = typing.NamedTuple("SingleNT", [("t1_int_output", float)]) @task def t3(a: int) -> nt: - return nt( - a + 2, - ) + return nt(a + 2) assert t3.python_interface.output_tuple_name == "SingleNT" assert t3.interface.outputs["t1_int_output"] is not None @@ -882,7 +880,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_subwf(a: int) -> (str, str): + def my_subwf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v @@ -1406,7 +1404,7 @@ def t2(a: str, b: str) -> str: return b + a @workflow - def my_wf(a: int, b: str) -> (str, typing.List[str]): + def my_wf(a: int, b: str) -> typing.Tuple[str, typing.List[str]]: @dynamic def my_subwf(a: int) -> typing.List[str]: s = [] @@ -1446,7 +1444,7 @@ def t1() -> str: return "Hello" @workflow - def wf() -> typing.NamedTuple("OP", a=str, b=str): + def wf() -> typing.NamedTuple("OP", [("a", str), ("b", str)]): # type: ignore return t1(), t1() assert wf() == ("Hello", "Hello") diff --git a/tests/flytekit/unit/core/test_typing_annotation.py b/tests/flytekit/unit/core/test_typing_annotation.py index 9c2d09c145..2937d9f978 100644 --- a/tests/flytekit/unit/core/test_typing_annotation.py +++ b/tests/flytekit/unit/core/test_typing_annotation.py @@ -18,7 +18,7 @@ env=None, image_config=ImageConfig(default_image=default_img, images=[default_img]), ) -entity_mapping = OrderedDict() +entity_mapping: OrderedDict = OrderedDict() @task diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 46389daed2..23b9d0631e 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -44,12 +44,12 @@ def test_default_metadata_values(): def test_workflow_values(): @task - def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + def t1(a: int) -> typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)]): a = a + 2 return a, "world-" + str(a) @workflow(interruptible=True, failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) - def wf(a: int) -> (str, str): + def wf(a: int) -> typing.Tuple[str, str]: x, y = t1(a=a) u, v = t1(a=x) return y, v @@ -94,7 +94,7 @@ def list_output_wf() -> typing.List[int]: def test_sub_wf_single_named_tuple(): - nt = typing.NamedTuple("SingleNamedOutput", named1=int) + nt = typing.NamedTuple("SingleNamedOutput", [("named1", int)]) @task def t1(a: int) -> nt: @@ -115,7 +115,7 @@ def wf(b: int) -> nt: def test_sub_wf_multi_named_tuple(): - nt = typing.NamedTuple("Multi", named1=int, named2=int) + nt = typing.NamedTuple("Multi", [("named1", int), ("named2", int)]) @task def t1(a: int) -> nt: @@ -153,7 +153,7 @@ def no_outputs_wf(): with pytest.raises(AssertionError): @workflow - def one_output_wf() -> int: # noqa + def one_output_wf() -> int: # type: ignore t1(a=3) @@ -309,10 +309,10 @@ def sd_to_schema_wf() -> pd.DataFrame: @workflow -def schema_to_sd_wf() -> (pd.DataFrame, pd.DataFrame): +def schema_to_sd_wf() -> typing.Tuple[pd.DataFrame, pd.DataFrame]: # schema -> StructuredDataset df = t4() - return t2(df=df), t5(sd=df) + return t2(df=df), t5(sd=df) # type: ignore def test_structured_dataset_wf():