diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml
index 28febb3876..da0be3518e 100644
--- a/.github/workflows/pythonpublish.yml
+++ b/.github/workflows/pythonpublish.yml
@@ -47,11 +47,28 @@ jobs:
       run: |
         make -C plugins build_all_plugins
         make -C plugins publish_all_plugins
-    # Added sleep because PYPI take some time in publish
-    - name: Sleep for 180 seconds
-      uses: jakejarvis/wait-action@master
-      with:
-        time: '180s'
+    - name: Sleep until PyPI is available
+      id: pypiwait
+      run: |
+        # from refs/tags/v1.2.3 get 1.2.3 and make sure it's not an empty string
+        VERSION=$(echo $GITHUB_REF | sed 's#.*/v##')
+        if [ -z "$VERSION" ]
+        then
+          echo "No tagged version found, exiting"
+          exit 1
+        fi
+        LINK="https://pypi.org/project/flytekit/${VERSION}"
+        for i in {1..60}; do
+          if curl -L -I -s -f ${LINK} >/dev/null; then
+            echo "Found PyPI"
+            exit 0
+          else
+            echo "Did not find it - retrying in 10 seconds..."
+            sleep 10
+          fi
+        done
+        exit 1
+      shell: bash
     outputs:
       version: ${{ steps.bump.outputs.version }}
diff --git a/Dockerfile.dev b/Dockerfile.dev
index f6baf63896..b7c5104bbc 100644
--- a/Dockerfile.dev
+++ b/Dockerfile.dev
@@ -12,27 +12,21 @@ MAINTAINER Flyte Team
 LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit

 WORKDIR /root
-ENV PYTHONPATH /root

 ARG VERSION
-ARG DOCKER_IMAGE

 RUN apt-get update && apt-get install build-essential vim -y

-COPY . /code/flytekit
-WORKDIR /code/flytekit
+COPY . /flytekit

 # Pod tasks should be exposed in the default image
-RUN pip install -e .
-RUN pip install -e plugins/flytekit-k8s-pod
-RUN pip install -e plugins/flytekit-deck-standard
+RUN pip install -e /flytekit
+RUN pip install -e /flytekit/plugins/flytekit-k8s-pod
+RUN pip install -e /flytekit/plugins/flytekit-deck-standard
 RUN pip install scikit-learn

-ENV PYTHONPATH "/code/flytekit:/code/flytekit/plugins/flytekit-k8s-pod:/code/flytekit/plugins/flytekit-deck-standard:"
+ENV PYTHONPATH "/flytekit:/flytekit/plugins/flytekit-k8s-pod:/flytekit/plugins/flytekit-deck-standard:"

-WORKDIR /root
 RUN useradd -u 1000 flytekit
 RUN chown flytekit: /root
 USER flytekit
-
-ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE"
diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py
index ca7a6cf20d..a9b7c313f0 100644
--- a/flytekit/bin/entrypoint.py
+++ b/flytekit/bin/entrypoint.py
@@ -10,7 +10,6 @@
 import click as _click
 from flyteidl.core import literals_pb2 as _literals_pb2

-from flytekit import PythonFunctionTask
 from flytekit.configuration import (
     SERIALIZED_CONTEXT_ENV_VAR,
     FastSerializationSettings,
@@ -23,7 +22,7 @@
 from flytekit.core.checkpointer import SyncCheckpoint
 from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.data_persistence import FileAccessProvider
-from flytekit.core.map_task import MapPythonTask
+from flytekit.core.map_task import MapTaskResolver
 from flytekit.core.promise import VoidPromise
 from flytekit.exceptions import scopes as _scoped_exceptions
 from flytekit.exceptions import scopes as _scopes
@@ -391,12 +390,8 @@ def _execute_map_task(
     with setup_execution(
         raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir
     ) as ctx:
-        resolver_obj = load_object_from_module(resolver)
-        # Use the resolver to load the actual task object
-        _task_def = resolver_obj.load_task(loader_args=resolver_args)
-        if not isinstance(_task_def, PythonFunctionTask):
-            raise Exception("Map tasks cannot be run with instance tasks.")
-        map_task = MapPythonTask(_task_def, max_concurrency)
+        mtr = MapTaskResolver()
+        map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency)

         task_index = _compute_array_job_index()
         output_prefix = os.path.join(output_prefix, str(task_index))
diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py
index 136831c0bc..c45ec3f150 100644
--- a/flytekit/clis/sdk_in_container/run.py
+++ b/flytekit/clis/sdk_in_container/run.py
@@ -81,7 +81,7 @@ def convert(
                 raise ValueError(
                     f"Currently only directories containing one file are supported, found [{len(files)}] files found in {p.resolve()}"
                 )
-            return Directory(dir_path=value, local_file=files[0].resolve())
+            return Directory(dir_path=str(p), local_file=files[0].resolve())
         raise click.BadParameter(f"parameter should be a valid directory path, {value}")
diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py
index 8fb73ebd8c..ea36689874 100644
--- a/flytekit/core/data_persistence.py
+++ b/flytekit/core/data_persistence.py
@@ -107,16 +107,16 @@ def data_config(self) -> DataConfig:
         return self._data_config

     def get_filesystem(
-        self, protocol: typing.Optional[str] = None, anonymous: bool = False
+        self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs
     ) -> typing.Optional[fsspec.AbstractFileSystem]:
         if not protocol:
             return self._default_remote
-        kwargs = {}  # type: typing.Dict[str, typing.Any]
         if protocol == "file":
-            kwargs = {"auto_mkdir": True}
+            kwargs["auto_mkdir"] = True
         elif protocol == "s3":
-            kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous)
-            return fsspec.filesystem(protocol, **kwargs)  # type: ignore
+            s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous)
+            s3kwargs.update(kwargs)
+            return fsspec.filesystem(protocol, **s3kwargs)  # type: ignore
         elif protocol == "gs":
             if anonymous:
                 kwargs["token"] = _ANON
@@ -128,9 +128,9 @@ def get_filesystem(

         return fsspec.filesystem(protocol, **kwargs)  # type: ignore

-    def get_filesystem_for_path(self, path: str = "", anonymous: bool = False) -> fsspec.AbstractFileSystem:
+    def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem:
         protocol = get_protocol(path)
-        return self.get_filesystem(protocol, anonymous=anonymous)
+        return self.get_filesystem(protocol, anonymous=anonymous, **kwargs)

     @staticmethod
     def is_remote(path: Union[str, os.PathLike]) -> bool:
diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py
index 3c24e65db2..eae7a8e0cf 100644
--- a/flytekit/core/interface.py
+++ b/flytekit/core/interface.py
@@ -21,6 +21,28 @@
 T = typing.TypeVar("T")


+def repr_kv(k: str, v: Union[Type, Tuple[Type, Any]]) -> str:
+    if isinstance(v, tuple):
+        if v[1]:
+            return f"{k}: {v[0]}={v[1]}"
+        return f"{k}: {v[0]}"
+    return f"{k}: {v}"
+
+
+def repr_type_signature(io: Union[Dict[str, Tuple[Type, Any]], Dict[str, Type]]) -> str:
+    """
+    Converts inputs and outputs to a type signature
+    """
+    s = "("
+    i = 0
+    for k, v in io.items():
+        if i > 0:
+            s += ", "
+        s += repr_kv(k, v)
+        i = i + 1
+    return s + ")"
+
+
 class Interface(object):
     """
     A Python native interface object, like inspect.signature but simpler.
@@ -57,7 +79,9 @@ def __init__(
             variables = [k for k in outputs.keys()]

             # TODO: This class is a duplicate of the one in create_task_outputs. Over time, we should move to this one.
-            class Output(collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)):  # type: ignore
+            class Output(  # type: ignore
+                collections.namedtuple(output_tuple_name or "DefaultNamedTupleOutput", variables)  # type: ignore
+            ):  # type: ignore
                 """
                 This class can be used in two different places. For multivariate-return entities this class is used
                 to rewrap the outputs so that our with_overrides function can work.
@@ -167,6 +191,12 @@ def with_outputs(self, extra_outputs: Dict[str, Type]) -> Interface:
             new_outputs[k] = v
         return Interface(self._inputs, new_outputs)

+    def __str__(self):
+        return f"{repr_type_signature(self._inputs)} -> {repr_type_signature(self._outputs)}"
+
+    def __repr__(self):
+        return str(self)
+

 def transform_inputs_to_parameters(
     ctx: context_manager.FlyteContext, interface: Interface
@@ -220,7 +250,7 @@ def transform_interface_to_typed_interface(
     return _interface_models.TypedInterface(inputs_map, outputs_map)


-def transform_types_to_list_of_type(m: Dict[str, type]) -> Dict[str, type]:
+def transform_types_to_list_of_type(m: Dict[str, type], bound_inputs: typing.Set[str]) -> Dict[str, type]:
     """
     Converts a given variables to be collections of their type. This is useful for array jobs / map style code.
     It will create a collection of types even if any one these types is not a collection type
@@ -230,6 +260,10 @@
     all_types_are_collection = True
     for k, v in m.items():
+        if k in bound_inputs:
+            # Skip the inputs that are bound. If they are bound, it does not matter if they are collections or
+            # singletons
+            continue
         v_type = type(v)
         if v_type != typing.List and v_type != list:
             all_types_are_collection = False
@@ -240,17 +274,22 @@

     om = {}
     for k, v in m.items():
-        om[k] = typing.List[v]  # type: ignore
+        if k in bound_inputs:
+            om[k] = v
+        else:
+            om[k] = typing.List[v]  # type: ignore
     return om  # type: ignore


-def transform_interface_to_list_interface(interface: Interface) -> Interface:
+def transform_interface_to_list_interface(interface: Interface, bound_inputs: typing.Set[str]) -> Interface:
     """
     Takes a single task interface and interpolates it to an array interface - to allow performing distributed python map
     like functions
+    :param interface: Interface to be upgraded to a list interface
+    :param bound_inputs: fixed inputs that should not be upgraded to a list and will be maintained as scalars.
     """
-    map_inputs = transform_types_to_list_of_type(interface.inputs)
-    map_outputs = transform_types_to_list_of_type(interface.outputs)
+    map_inputs = transform_types_to_list_of_type(interface.inputs, bound_inputs)
+    map_outputs = transform_types_to_list_of_type(interface.outputs, set())
     return Interface(inputs=map_inputs, outputs=map_outputs)
@@ -288,7 +327,6 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface:
     For now the fancy object, maybe in the future a dumb object.

     """
-
     type_hints = get_type_hints(fn, include_extras=True)
     signature = inspect.signature(fn)
     return_annotation = type_hints.get("return", None)
diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py
index 48d0f0b335..83b2542fe3 100644
--- a/flytekit/core/map_task.py
+++ b/flytekit/core/map_task.py
@@ -2,71 +2,87 @@
 Flytekit map tasks specify how to run a single task across a list of inputs.
 Map tasks themselves are constructed with a reference task as well as run-time parameters that limit execution
 concurrency and failure tolerations.
 """
-
+import functools
+import hashlib
+import logging
 import os
 import typing
 from contextlib import contextmanager
-from itertools import count
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set

 from flytekit.configuration import SerializationSettings
 from flytekit.core import tracker
-from flytekit.core.base_task import PythonTask
+from flytekit.core.base_task import PythonTask, Task, TaskResolverMixin
 from flytekit.core.constants import SdkTaskType
 from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
 from flytekit.core.interface import transform_interface_to_list_interface
 from flytekit.core.python_function_task import PythonFunctionTask
+from flytekit.core.tracker import TrackedInstance
 from flytekit.exceptions import scopes as exception_scopes
 from flytekit.models.array_job import ArrayJob
 from flytekit.models.interface import Variable
 from flytekit.models.task import Container, K8sPod, Sql
+from flytekit.tools.module_loader import load_object_from_module


 class MapPythonTask(PythonTask):
     """
     A MapPythonTask defines a :py:class:`flytekit.PythonTask` which specifies how to run an inner
     :py:class:`flytekit.PythonFunctionTask` across a range of inputs in parallel.
-    TODO: support lambda functions
     """

-    # To support multiple map tasks declared around identical python function tasks, we keep a global count of
-    # MapPythonTask instances to uniquely differentiate map task names for each declared instance.
-    _ids = count(0)
-
     def __init__(
         self,
-        python_function_task: PythonFunctionTask,
+        python_function_task: typing.Union[PythonFunctionTask, functools.partial],
         concurrency: Optional[int] = None,
         min_success_ratio: Optional[float] = None,
+        bound_inputs: Optional[Set[str]] = None,
         **kwargs,
     ):
         """
+        Wrapper that creates a MapPythonTask
+
         :param python_function_task: This argument is implicitly passed and represents the repeatable function
         :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given
-            batch size
+        batch size
         :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete
-            successfully before terminating this task and marking it successful.
+        successfully before terminating this task and marking it successful
+        :param bound_inputs: List[str] specifies a list of variable names within the interface of python_function_task
+            that are already bound and should not be considered as list inputs, but scalar values. This is mostly
+            useful at runtime and is passed in by MapTaskResolver. This field is not required when a `partial` method
+            is specified. The bound_vars will be auto-deduced from the `partial.keywords`.
""" - if len(python_function_task.python_interface.inputs.keys()) > 1: - raise ValueError("Map tasks only accept python function tasks with 0 or 1 inputs") + self._partial = None + if isinstance(python_function_task, functools.partial): + self._partial = python_function_task + actual_task = self._partial.func + else: + actual_task = python_function_task + + if not isinstance(actual_task, PythonFunctionTask): + raise ValueError("Map tasks can only compose of Python Functon Tasks currently") - if len(python_function_task.python_interface.outputs.keys()) > 1: + if len(actual_task.python_interface.outputs.keys()) > 1: raise ValueError("Map tasks only accept python function tasks with 0 or 1 outputs") - collection_interface = transform_interface_to_list_interface(python_function_task.python_interface) - instance = next(self._ids) - _, mod, f, _ = tracker.extract_task_module(python_function_task.task_function) - name = f"{mod}.mapper_{f}_{instance}" - - self._cmd_prefix = None - self._run_task = python_function_task - self._max_concurrency = concurrency - self._min_success_ratio = min_success_ratio - self._array_task_interface = python_function_task.python_interface - if "metadata" not in kwargs and python_function_task.metadata: - kwargs["metadata"] = python_function_task.metadata - if "security_ctx" not in kwargs and python_function_task.security_context: - kwargs["security_ctx"] = python_function_task.security_context + self._bound_inputs: typing.Set[str] = set(bound_inputs) if bound_inputs else set() + if self._partial: + self._bound_inputs = set(self._partial.keywords.keys()) + + collection_interface = transform_interface_to_list_interface(actual_task.python_interface, self._bound_inputs) + self._run_task: PythonFunctionTask = actual_task + _, mod, f, _ = tracker.extract_task_module(actual_task.task_function) + h = hashlib.md5(collection_interface.__str__().encode("utf-8")).hexdigest() + name = f"{mod}.map_{f}_{h}" + + self._cmd_prefix: typing.Optional[typing.List[str]] = None + self._max_concurrency: typing.Optional[int] = concurrency + self._min_success_ratio: typing.Optional[float] = min_success_ratio + self._array_task_interface = actual_task.python_interface + if "metadata" not in kwargs and actual_task.metadata: + kwargs["metadata"] = actual_task.metadata + if "security_ctx" not in kwargs and actual_task.security_context: + kwargs["security_ctx"] = actual_task.security_context super().__init__( name=name, interface=collection_interface, @@ -76,7 +92,15 @@ def __init__( **kwargs, ) + @property + def bound_inputs(self) -> Set[str]: + return self._bound_inputs + def get_command(self, settings: SerializationSettings) -> List[str]: + """ + TODO ADD bound variables to the resolver. Maybe we need a different resolver? 
+ """ + mt = MapTaskResolver() container_args = [ "pyflyte-map-execute", "--inputs", @@ -90,9 +114,9 @@ def get_command(self, settings: SerializationSettings) -> List[str]: "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - self._run_task.task_resolver.location, + mt.name(), "--", - *self._run_task.task_resolver.loader_args(settings, self._run_task), + *mt.loader_args(settings, self), ] if self._cmd_prefix: @@ -100,7 +124,7 @@ def get_command(self, settings: SerializationSettings) -> List[str]: return container_args def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]): - self._cmd_prefix = cmd # type: ignore + self._cmd_prefix = cmd @contextmanager def prepare_target(self): @@ -135,6 +159,18 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] def run_task(self) -> PythonFunctionTask: return self._run_task + def __call__(self, *args, **kwargs): + """ + This call method modifies the kwargs and adds kwargs from partial. + This is mostly done in the local_execute and compilation only. + At runtime, the map_task is created with all the inputs filled in. to support this, we have modified + the map_task interface in the constructor. + """ + if self._partial: + """If partial exists, then mix-in all partial values""" + kwargs = {**self._partial.keywords, **kwargs} + return super().__call__(*args, **kwargs) + def execute(self, **kwargs) -> Any: ctx = FlyteContextManager.current_context() if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: @@ -191,7 +227,11 @@ def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any: task_index = self._compute_array_job_index() map_task_inputs = {} for k in self.interface.inputs.keys(): - map_task_inputs[k] = kwargs[k][task_index] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + map_task_inputs[k] = v[task_index] + else: + map_task_inputs[k] = v return exception_scopes.user_entry_point(self._run_task.execute)(**map_task_inputs) def _raw_execute(self, **kwargs) -> Any: @@ -213,7 +253,11 @@ def _raw_execute(self, **kwargs) -> Any: for i in range(len(kwargs[any_input_key])): single_instance_inputs = {} for k in self.interface.inputs.keys(): - single_instance_inputs[k] = kwargs[k][i] + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + single_instance_inputs[k] = kwargs[k][i] + else: + single_instance_inputs[k] = kwargs[k] o = exception_scopes.user_entry_point(self._run_task.execute)(**single_instance_inputs) if outputs_expected: outputs.append(o) @@ -221,7 +265,12 @@ def _raw_execute(self, **kwargs) -> Any: return outputs -def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs): +def map_task( + task_function: typing.Union[PythonFunctionTask, functools.partial], + concurrency: int = 0, + min_success_ratio: float = 1.0, + **kwargs, +): """ Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of any individual :py:class:`flytekit.PythonFunctionTask`. @@ -267,8 +316,63 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_succes successfully before terminating this task and marking it successful. 
""" - if not isinstance(task_function, PythonFunctionTask): - raise ValueError( - f"Only Flyte python task types are supported in map tasks currently, received {type(task_function)}" - ) return MapPythonTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs) + + +class MapTaskResolver(TrackedInstance, TaskResolverMixin): + """ + Special resolver that is used for MapTasks. + This exists because it is possible that MapTasks are created using nested "partial" subtasks. + When a maptask is created its interface is interpolated from the interface of the subtask - the interpolation, + simply converts every input into a list/collection input. + + For example: + interface -> (i: int, j: str) -> str => map_task interface -> (i: List[int], j: List[str]) -> List[str] + + But in cases in which `j` is bound to a fixed value by using `functools.partial` we need a way to ensure that + the interface is not simply interpolated, but only the unbound inputs are interpolated. + + .. code-block:: python + + def foo((i: int, j: str) -> str: + ... + + mt = map_task(functools.partial(foo, j=10)) + + print(mt.interface) + + output: + + (i: List[int], j: str) -> List[str] + + But, at runtime this information is lost. To reconstruct this, we use MapTaskResolver that records the "bound vars" + and then at runtime reconstructs the interface with this knowledge + """ + + def name(self) -> str: + return "MapTaskResolver" + + def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> MapPythonTask: + """ + Loader args should be of the form + vars "var1,var2,.." resolver "resolver" [resolver_args] + """ + _, bound_vars, _, resolver, *resolver_args = loader_args + logging.info(f"MapTask found task resolver {resolver} and arguments {resolver_args}") + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + bound_inputs = set(bound_vars.split(",")) + return MapPythonTask(python_function_task=_task_def, max_concurrency=max_concurrency, bound_inputs=bound_inputs) + + def loader_args(self, settings: SerializationSettings, t: MapPythonTask) -> List[str]: # type:ignore + return [ + "vars", + f'{",".join(t.bound_inputs)}', + "resolver", + t.run_task.task_resolver.location, + *t.run_task.task_resolver.loader_args(settings, t.run_task), + ] + + def get_all_tasks(self) -> List[Task]: + raise NotImplementedError("MapTask resolver cannot return every instance of the map task") diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f21e93a774..306c4116ad 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -129,7 +129,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: f"Conversion to python value expected type {expected_python_type} from literal not implemented" ) - @abstractmethod def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T]) -> str: """ Converts any python val (dataframe, int, float) to a html string, and it will be wrapped in the HTML div diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 7d576f9353..f4f23eb72f 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -2,20 +2,25 @@ import os import pathlib +import random import typing from dataclasses import dataclass, field from pathlib import Path +from typing import Any, Generator, Tuple +from uuid import UUID +import fsspec from 
dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields -from flytekit.core.context_manager import FlyteContext +from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType -from flytekit.types.file import FileExt +from flytekit.types.file import FileExt, FlyteFile T = typing.TypeVar("T") PathType = typing.Union[str, os.PathLike] @@ -148,6 +153,18 @@ def __fspath__(self): def extension(cls) -> str: return "" + @classmethod + def new_remote(cls) -> FlyteDirectory: + """ + Create a new FlyteDirectory object using the currently configured default remote in the context (i.e. + the raw_output_prefix configured in the current FileAccessProvider object in the context). + This is used if you explicitly have a folder somewhere that you want to create files under. + If you want to write a whole folder, you can let your task return a FlyteDirectory object, + and let flytekit handle the uploading. + """ + d = FlyteContext.current_context().file_access.get_random_remote_directory() + return FlyteDirectory(path=d) + def __class_getitem__(cls, item: typing.Union[typing.Type, str]) -> typing.Type[FlyteDirectory]: if item is None: return cls @@ -176,6 +193,12 @@ def downloaded(self) -> bool: def remote_directory(self) -> typing.Optional[str]: return self._remote_directory + @property + def sep(self) -> str: + if os.name == "nt" and get_protocol(self.path or self.remote_source or self.remote_directory) == "file": + return "\\" + return "/" + @property def remote_source(self) -> str: """ @@ -184,9 +207,67 @@ def remote_source(self) -> str: """ return typing.cast(str, self._remote_source) + def new_file(self, name: typing.Optional[str] = None) -> FlyteFile: + """ + This will create a new file under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. + """ + # TODO we may want to use - https://github.com/fsspec/universal_pathlib + if not name: + name = UUID(int=random.getrandbits(128)).hex + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteFile(path=new_path) + + def new_dir(self, name: typing.Optional[str] = None) -> FlyteDirectory: + """ + This will create a new folder under the current folder. + If given a name, it will use the name given, otherwise it'll pick a random string. + Collisions are not checked. + """ + if not name: + name = UUID(int=random.getrandbits(128)).hex + + new_path = self.sep.join([str(self.path).rstrip(self.sep), name]) # trim trailing sep if any and join + return FlyteDirectory(path=new_path) + def download(self) -> str: return self.__fspath__() + def crawl( + self, maxdepth: typing.Optional[int] = None, topdown: bool = True, **kwargs + ) -> Generator[Tuple[typing.Union[str, os.PathLike[Any]], typing.Dict[Any, Any]], None, None]: + """ + Crawl returns a generator of all files prefixed by any sub-folders under the given "FlyteDirectory". + if details=True is passed, then it will return a dictionary as specified by fsspec. 
+
+        Example:
+
+        >>> list(fd.crawl())
+        [("/base", "file1"), ("/base", "dir1/file1"), ("/base", "dir2/file1"), ("/base", "dir1/dir/file1")]
+
+        >>> list(x.crawl(detail=True))
+        [('/tmp/test', {'my-dir/ab.py': {'name': '/tmp/test/my-dir/ab.py', 'size': 0, 'type': 'file',
+        'created': 1677720780.2318847, 'islink': False, 'mode': 33188, 'uid': 501, 'gid': 0,
+        'mtime': 1677720780.2317934, 'ino': 1694329, 'nlink': 1}})]
+        """
+        final_path = self.path
+        if self.remote_source:
+            final_path = self.remote_source
+        elif self.remote_directory:
+            final_path = self.remote_directory
+        ctx = FlyteContextManager.current_context()
+        fs = ctx.file_access.get_filesystem_for_path(final_path)
+        base_path_len = len(fsspec.core.strip_protocol(final_path)) + 1  # +1 to account for the trailing `/`
+        for base, _, files in fs.walk(final_path, maxdepth, topdown, **kwargs):
+            current_base = base[base_path_len:]
+            if isinstance(files, dict):
+                for f, v in files.items():
+                    yield final_path, {os.path.join(current_base, f): v}
+            else:
+                for f in files:
+                    yield final_path, os.path.join(current_base, f)
+
     def __repr__(self):
         return self.path
diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py
index 6537f85cae..bb8feb3d9c 100644
--- a/flytekit/types/file/file.py
+++ b/flytekit/types/file/file.py
@@ -3,12 +3,14 @@
 import os
 import pathlib
 import typing
+from contextlib import contextmanager
 from dataclasses import dataclass, field

 from dataclasses_json import config, dataclass_json
 from marshmallow import fields
+from typing_extensions import Annotated, get_args, get_origin

-from flytekit.core.context_manager import FlyteContext
+from flytekit.core.context_manager import FlyteContext, FlyteContextManager
 from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
 from flytekit.loggers import logger
 from flytekit.models.core.types import BlobType
@@ -27,7 +29,9 @@ def noop():
 @dataclass_json
 @dataclass
 class FlyteFile(os.PathLike, typing.Generic[T]):
-    path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))  # type: ignore
+    path: typing.Union[str, os.PathLike] = field(
+        default=None, metadata=config(mm_field=fields.String())
+    )  # type: ignore
     """
     Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int
     exists for Flyte's Integer type) we need to create one so that users can express that their tasks take
@@ -148,6 +152,15 @@ def t2() -> flytekit_typing.FlyteFile["csv"]:
     def extension(cls) -> str:
         return ""

+    @classmethod
+    def new_remote_file(cls, name: typing.Optional[str] = None) -> FlyteFile:
+        """
+        Create a new FlyteFile object with a remote path.
+        """
+        ctx = FlyteContextManager.current_context()
+        remote_path = ctx.file_access.get_random_remote_path(name)
+        return cls(path=remote_path)
+
     def __class_getitem__(cls, item: typing.Union[str, typing.Type]) -> typing.Type[FlyteFile]:
         from . import FileExt

@@ -226,6 +239,57 @@ def remote_source(self) -> str:
     def download(self) -> str:
         return self.__fspath__()

+    @contextmanager
+    def open(
+        self,
+        mode: str,
+        cache_type: typing.Optional[str] = None,
+        cache_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
+    ):
+        """
+        Returns a streaming file handle
+
+        .. code-block:: python
+
+            @task
+            def copy_file(ff: FlyteFile) -> FlyteFile:
+                new_file = FlyteFile.new_remote_file(ff.name)
+                with ff.open("rb", cache_type="readahead", cache_options={}) as r:
+                    with new_file.open("wb") as w:
+                        w.write(r.read())
+                return new_file
+
+        Alternatively,
+
+        .. code-block:: python
+
+            @task
+            def copy_file(ff: FlyteFile) -> FlyteFile:
+                new_file = FlyteFile.new_remote_file(ff.name)
+                with fsspec.open(f"readahead::{ff.remote_path}", "rb", readahead={}) as r:
+                    with new_file.open("wb") as w:
+                        w.write(r.read())
+                return new_file
+
+        :param mode: str Open mode like 'rb', 'rt', 'wb', ...
+        :param cache_type: optional str Specify if caching is to be used. The cache type can be any of the ones
+            supported by fsspec (https://filesystem-spec.readthedocs.io/en/latest/api.html#readbuffering),
+            especially useful for large file reads
+        :param cache_options: optional Dict[str, Any] Refer to fsspec caching options. This is strongly coupled to the
+            cache_type
+        """
+        ctx = FlyteContextManager.current_context()
+        final_path = self.path
+        if self.remote_source:
+            final_path = self.remote_source
+        elif self.remote_path:
+            final_path = self.remote_path
+        fs = ctx.file_access.get_filesystem_for_path(final_path)
+        f = fs.open(final_path, mode, cache_type=cache_type, cache_options=cache_options)
+        try:
+            yield f
+        finally:
+            f.close()
+
     def __repr__(self):
         return self.path
@@ -272,6 +336,10 @@ def to_literal(
         if python_val is None:
             raise TypeTransformerFailedError("None value cannot be converted to a file.")

+        # Correctly handle `Annotated[FlyteFile, ...]` by extracting the underlying FlyteFile type
+        if get_origin(python_type) is Annotated:
+            python_type = get_args(python_type)[0]
+
         if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)):
             raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")
diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py
index c380bcc481..ac6b71ba38 100644
--- a/flytekit/types/schema/types.py
+++ b/flytekit/types/schema/types.py
@@ -186,7 +186,6 @@ class FlyteSchema(object):
     """
     This is the main schema class that users should use.
""" - logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") @classmethod def columns(cls) -> typing.Dict[str, typing.Type]: @@ -203,6 +202,7 @@ def format(cls) -> SchemaFormat: def __class_getitem__( cls, columns: typing.Dict[str, typing.Type], fmt: SchemaFormat = SchemaFormat.PARQUET ) -> Type[FlyteSchema]: + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") if columns is None: return FlyteSchema @@ -240,6 +240,7 @@ def __init__( supported_mode: SchemaOpenMode = SchemaOpenMode.WRITE, downloader: typing.Optional[typing.Callable] = None, ): + logger.warning("FlyteSchema is deprecated, use Structured Dataset instead.") if supported_mode == SchemaOpenMode.READ and remote_path is None: raise ValueError("To create a FlyteSchema in read mode, remote_path is required") if ( diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index ae3e8a00d9..c8f4ef3baa 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -62,22 +62,6 @@ def encode( structured_dataset_type.format = PARQUET return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - def ddencode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - df = typing.cast(pd.DataFrame, structured_dataset.dataframe) - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - df.to_parquet(local_path, coerce_timestamps="us", allow_truncated_timestamps=False) - ctx.file_access.upload_directory(local_dir, path) - structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) - class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): def __init__(self): @@ -101,20 +85,6 @@ def decode( kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) return pd.read_parquet(uri, columns=columns, storage_options=kwargs) - def dcccecode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pd.DataFrame: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pd.read_parquet(local_dir, columns=columns) - return pd.read_parquet(local_dir) - class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py index 0d6788ac92..014b88f4f3 100644 --- a/plugins/flytekit-k8s-pod/tests/test_pod.py +++ b/plugins/flytekit-k8s-pod/tests/test_pod.py @@ -355,8 +355,12 @@ def simple_pod_task(i: int): "--prev-checkpoint", "{{.prevCheckpointPrefix}}", "--resolver", - "flytekit.core.python_auto_container.default_task_resolver", + "MapTaskResolver", "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", "task-module", "tests.test_pod", "task-name", diff --git a/tests/flytekit/unit/core/test_data.py 
b/tests/flytekit/unit/core/test_data.py index 880036f636..1b33ad2923 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -1,13 +1,17 @@ import os +import random import shutil import tempfile +from uuid import UUID import fsspec import mock import pytest from flytekit.configuration import Config, S3Config +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args +from flytekit.types.directory.types import FlyteDirectory local = fsspec.filesystem("file") root = os.path.abspath(os.sep) @@ -99,6 +103,8 @@ def source_folder(): nested_dir = os.path.join(src_dir, "nested") local.mkdir(nested_dir) local.touch(os.path.join(src_dir, "original.txt")) + with open(os.path.join(src_dir, "original.txt"), "w") as fh: + fh.write("hello original") local.touch(os.path.join(nested_dir, "more.txt")) yield src_dir shutil.rmtree(parent_temp) @@ -213,3 +219,112 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): kwargs = s3_setup_args(S3Config.auto()) # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default assert kwargs == {} + + +def test_crawl_local_nt(source_folder): + """ + running this to see what it prints + """ + if os.name != "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + print(f"NT split {split}") + + # Test crawling a directory without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + print(f"NT files joined {files}") + + +def test_crawl_local_non_nt(source_folder): + """ + crawl on the source folder fixture should return for example + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'original.txt') + ('/var/folders/jx/54tww2ls58n8qtlp9k31nbd80000gp/T/tmpp14arygf/source/', 'nested/more.txt') + """ + if os.name == "nt": # don't + return + source_folder = os.path.join(source_folder, "") # ensure there's a trailing / or \ + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + split = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in split] + assert set(split) == {(source_folder, "original.txt"), (source_folder, os.path.join("nested", "more.txt"))} + expected = {os.path.join(source_folder, "original.txt"), os.path.join(source_folder, "nested", "more.txt")} + assert set(files) == expected + + # Test crawling a directory without trailing / or \ + source_folder = source_folder[:-1] + fd = FlyteDirectory(path=source_folder) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + + # Test crawling a single file + fd = FlyteDirectory(path=os.path.join(source_folder, "original.txt")) + res = fd.crawl() + files = [os.path.join(x, y) for x, y in res] + assert len(files) == 0 + + +@pytest.mark.sandbox_test +def test_crawl_s3(source_folder): + """ + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'original.txt') + ('s3://my-s3-bucket/testdata/5b31492c032893b515650f8c76008cf7', 'nested/more.txt') + """ + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", 
data_config=dc + ) + s3_random_target = provider.get_random_remote_directory() + provider.put_data(source_folder, s3_random_target, is_multipart=True) + ctx = FlyteContextManager.current_context() + expected = {f"{s3_random_target}/original.txt", f"{s3_random_target}/nested/more.txt"} + + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory(path=s3_random_target) + res = fd.crawl() + res = [(x, y) for x, y in res] + files = [os.path.join(x, y) for x, y in res] + assert set(files) == expected + assert set(res) == {(s3_random_target, "original.txt"), (s3_random_target, os.path.join("nested", "more.txt"))} + + fd_file = FlyteDirectory(path=f"{s3_random_target}/original.txt") + res = fd_file.crawl() + files = [r for r in res] + assert len(files) == 1 + + +@pytest.mark.sandbox_test +def test_walk_local_copy_to_s3(source_folder): + dc = Config.for_sandbox().data_config + explicit_empty_folder = UUID(int=random.getrandbits(128)).hex + raw_output_path = f"s3://my-s3-bucket/testdata/{explicit_empty_folder}" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output_path, data_config=dc) + + ctx = FlyteContextManager.current_context() + local_fd = FlyteDirectory(path=source_folder) + local_fd_crawl = local_fd.crawl() + local_fd_crawl = [x for x in local_fd_crawl] + with FlyteContextManager.with_context(ctx.with_file_access(provider)): + fd = FlyteDirectory.new_remote() + assert raw_output_path in fd.path + + # Write source folder files to new remote path + for root_path, suffix in local_fd_crawl: + new_file = fd.new_file(suffix) # noqa + with open(os.path.join(root_path, suffix), "rb") as r: # noqa + with new_file.open("w") as w: + print(f"Writing, t {type(w)} p {new_file.path} |{suffix}|") + w.write(str(r.read())) + + new_crawl = fd.crawl() + new_suffixes = [y for x, y in new_crawl] + assert len(new_suffixes) == 2 # should have written two files diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index e2123222e0..b7f0a1aeee 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock import pytest +from typing_extensions import Annotated import flytekit.configuration -from flytekit.configuration import Image, ImageConfig -from flytekit.core import context_manager -from flytekit.core.context_manager import ExecutionState +from flytekit.configuration import Config, Image, ImageConfig +from flytekit.core.context_manager import ExecutionState, FlyteContextManager from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine @@ -81,11 +82,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() - # print(f"Random: {random_dir}") + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with 
FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -108,10 +108,10 @@ def t1() -> FlyteFile: def my_wf() -> FlyteFile: return t1() - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): top_level_files = os.listdir(random_dir) assert len(top_level_files) == 1 # the flytekit_local folder @@ -137,12 +137,12 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory() # Creating a new FileAccessProvider will add two folderst to the random dir print(f"Random {random_dir}") fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote")) - ctx = context_manager.FlyteContext.current_context() - with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)): + ctx = FlyteContextManager.current_context() + with FlyteContextManager.with_context(ctx.with_file_access(fs)): working_dir = os.listdir(random_dir) assert len(working_dir) == 1 # the local_flytekit folder @@ -189,11 +189,11 @@ def my_wf() -> FlyteFile: return t1() # This creates a random directory that we know is empty. 
-    random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory()
+    random_dir = FlyteContextManager.current_context().file_access.get_random_local_directory()
     # Creating a new FileAccessProvider will add two folderst to the random dir
     fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "mock_remote"))
-    ctx = context_manager.FlyteContext.current_context()
-    with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)):
+    ctx = FlyteContextManager.current_context()
+    with FlyteContextManager.with_context(ctx.with_file_access(fs)):
         working_dir = os.listdir(random_dir)
         assert len(working_dir) == 1  # the local_flytekit dir
@@ -243,8 +243,8 @@ def dyn(in1: FlyteFile):

     fd = FlyteFile("s3://anything")

-    with context_manager.FlyteContextManager.with_context(
-        context_manager.FlyteContextManager.current_context().with_serialization_settings(
+    with FlyteContextManager.with_context(
+        FlyteContextManager.current_context().with_serialization_settings(
             flytekit.configuration.SerializationSettings(
                 project="test_proj",
                 domain="test_domain",
@@ -254,8 +254,8 @@ def dyn(in1: FlyteFile):
             )
         )
     ):
-        ctx = context_manager.FlyteContextManager.current_context()
-        with context_manager.FlyteContextManager.with_context(
+        ctx = FlyteContextManager.current_context()
+        with FlyteContextManager.with_context(
             ctx.with_execution_state(ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
         ) as ctx:
             lit = TypeEngine.to_literal(
@@ -433,3 +433,59 @@ def wf(path: str) -> os.PathLike:
         return t2(ff=n1)

     assert flyte_tmp_dir in wf(path="s3://somewhere").path
+
+
+def test_flyte_file_annotated_hashmethod(local_dummy_file):
+    def calc_hash(ff: FlyteFile) -> str:
+        return str(ff.path)
+
+    @task
+    def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]:
+        return FlyteFile(path)
+
+    @workflow
+    def wf(path: str) -> None:
+        t1(path=path)
+
+    wf(path=local_dummy_file)
+
+
+@pytest.mark.sandbox_test
+def test_file_open_things():
+    @task
+    def write_this_file_to_s3() -> FlyteFile:
+        ctx = FlyteContextManager.current_context()
+        dest = ctx.file_access.get_random_remote_path()
+        ctx.file_access.put(__file__, dest)
+        return FlyteFile(path=dest)
+
+    @task
+    def copy_file(ff: FlyteFile) -> FlyteFile:
+        new_file = FlyteFile.new_remote_file(ff.remote_path)
+        with ff.open("r") as r:
+            with new_file.open("w") as w:
+                w.write(r.read())
+        return new_file
+
+    @task
+    def print_file(ff: FlyteFile):
+        with open(ff, "r") as fh:
+            print(len(fh.readlines()))
+
+    dc = Config.for_sandbox().data_config
+    with tempfile.TemporaryDirectory() as new_sandbox:
+        provider = FileAccessProvider(
+            local_sandbox_dir=new_sandbox, raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc
+        )
+        ctx = FlyteContextManager.current_context()
+        local = ctx.file_access.get_filesystem("file")  # get a local file system
+        with FlyteContextManager.with_context(ctx.with_file_access(provider)):
+            f = write_this_file_to_s3()
+            copy_file(ff=f)
+            files = local.find(new_sandbox)
+            # copy_file was done via streaming so no files should have been written
+            assert len(files) == 0
+            print_file(ff=f)
+            # print_file uses traditional download semantics so now a file should have been created
+            files = local.find(new_sandbox)
+            assert len(files) == 1
diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py
index 95927873d0..d032aca2d1 100644
--- a/tests/flytekit/unit/core/test_map_task.py
+++ b/tests/flytekit/unit/core/test_map_task.py
@@ -1,3 +1,4 @@
+import functools
 import typing
 from collections import OrderedDict
@@ -6,7 +7,7 @@
 import flytekit.configuration
 from flytekit import LaunchPlan, map_task
 from flytekit.configuration import Image, ImageConfig
-from flytekit.core.map_task import MapPythonTask
+from flytekit.core.map_task import MapPythonTask, MapTaskResolver
 from flytekit.core.task import TaskMetadata, task
 from flytekit.core.workflow import workflow
 from flytekit.tools.translator import get_serializable
@@ -36,6 +37,11 @@ def t2(a: int) -> str:
     return str(b)

+@task(cache=True, cache_version="1")
+def t3(a: int, b: str, c: float) -> str:
+    pass
+
+
 # This test is for documentation.
 def test_map_docs():
     # test_map_task_start
@@ -87,8 +93,12 @@ def test_serialization(serialization_settings):
         "--prev-checkpoint",
         "{{.prevCheckpointPrefix}}",
         "--resolver",
-        "flytekit.core.python_auto_container.default_task_resolver",
+        "MapTaskResolver",
         "--",
+        "vars",
+        "",
+        "resolver",
+        "flytekit.core.python_auto_container.default_task_resolver",
         "task-module",
         "tests.flytekit.unit.core.test_map_task",
         "task-name",
@@ -177,15 +187,42 @@ def test_inputs_outputs_length():
     def many_inputs(a: int, b: str, c: float) -> str:
         return f"{a} - {b} - {c}"

-    with pytest.raises(ValueError):
-        _ = map_task(many_inputs)
+    m = map_task(many_inputs)
+    assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]}
+    assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_24c08b3a2f9c2e389ad9fc6a03482cf9"
+    r_m = MapPythonTask(many_inputs)
+    assert str(r_m.python_interface) == str(m.python_interface)
+
+    p1 = functools.partial(many_inputs, c=1.0)
+    m = map_task(p1)
+    assert m.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": float}
+    assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_697aa7389996041183cf6cfd102be4f7"
+    r_m = MapPythonTask(many_inputs, bound_inputs=set("c"))
+    assert str(r_m.python_interface) == str(m.python_interface)
+
+    p2 = functools.partial(p1, b="hello")
+    m = map_task(p2)
+    assert m.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float}
+    assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_cc18607da7494024a402a5fa4b3ea5c6"
+    r_m = MapPythonTask(many_inputs, bound_inputs={"c", "b"})
+    assert str(r_m.python_interface) == str(m.python_interface)
+
+    p3 = functools.partial(p2, a=1)
+    m = map_task(p3)
+    assert m.python_interface.inputs == {"a": int, "b": str, "c": float}
+    assert m.name == "tests.flytekit.unit.core.test_map_task.map_many_inputs_52fe80b04781ea77ef6f025f4b49abef"
+    r_m = MapPythonTask(many_inputs, bound_inputs={"a", "c", "b"})
+    assert str(r_m.python_interface) == str(m.python_interface)
+
+    with pytest.raises(TypeError):
+        m(a=[1, 2, 3])

     @task
     def many_outputs(a: int) -> (int, str):
         return a, f"{a}"

     with pytest.raises(ValueError):
-        _ = map_task(many_inputs)
+        _ = map_task(many_outputs)

 def test_map_task_metadata():
@@ -194,3 +231,34 @@ def test_map_task_metadata():
     assert mapped_1.metadata is map_meta
     mapped_2 = map_task(t2)
     assert mapped_2.metadata is t2.metadata
+
+
+def test_map_task_resolver(serialization_settings):
+    list_outputs = {"o0": typing.List[str]}
+    mt = map_task(t3)
+    assert mt.python_interface.inputs == {"a": typing.List[int], "b": typing.List[str], "c": typing.List[float]}
+    assert mt.python_interface.outputs == list_outputs
+    mtr = MapTaskResolver()
+    assert mtr.name() == "MapTaskResolver"
+    args = mtr.loader_args(serialization_settings, mt)
+    t = mtr.load_task(loader_args=args)
+    assert t.python_interface.inputs == mt.python_interface.inputs
+    assert t.python_interface.outputs == mt.python_interface.outputs
+
+    mt = map_task(functools.partial(t3, b="hello", c=1.0))
+    assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": float}
+    assert mt.python_interface.outputs == list_outputs
+    mtr = MapTaskResolver()
+    args = mtr.loader_args(serialization_settings, mt)
+    t = mtr.load_task(loader_args=args)
+    assert t.python_interface.inputs == mt.python_interface.inputs
+    assert t.python_interface.outputs == mt.python_interface.outputs
+
+    mt = map_task(functools.partial(t3, b="hello"))
+    assert mt.python_interface.inputs == {"a": typing.List[int], "b": str, "c": typing.List[float]}
+    assert mt.python_interface.outputs == list_outputs
+    mtr = MapTaskResolver()
+    args = mtr.loader_args(serialization_settings, mt)
+    t = mtr.load_task(loader_args=args)
+    assert t.python_interface.inputs == mt.python_interface.inputs
+    assert t.python_interface.outputs == mt.python_interface.outputs
diff --git a/tests/flytekit/unit/core/test_partials.py b/tests/flytekit/unit/core/test_partials.py
new file mode 100644
index 0000000000..0a78c825f8
--- /dev/null
+++ b/tests/flytekit/unit/core/test_partials.py
@@ -0,0 +1,181 @@
+import typing
+from collections import OrderedDict
+from functools import partial
+
+import pandas as pd
+import pytest
+
+import flytekit.configuration
+from flytekit.configuration import Image, ImageConfig
+from flytekit.core.dynamic_workflow_task import dynamic
+from flytekit.core.map_task import MapTaskResolver, map_task
+from flytekit.core.task import TaskMetadata, task
+from flytekit.core.workflow import workflow
+from flytekit.tools.translator import gather_dependent_entities, get_serializable
+
+default_img = Image(name="default", fqn="test", tag="tag")
+serialization_settings = flytekit.configuration.SerializationSettings(
+    project="project",
+    domain="domain",
+    version="version",
+    env=None,
+    image_config=ImageConfig(default_image=default_img, images=[default_img]),
+)
+
+
+df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+
+
+def test_basics_1():
+    @task
+    def t1(a: int, b: str, c: float) -> int:
+        return a + len(b) + int(c)
+
+    outside_p = partial(t1, b="hello", c=3.14)
+
+    @workflow
+    def my_wf_1(a: int) -> typing.Tuple[int, int]:
+        inner_partial = partial(t1, b="world", c=2.7)
+        out = outside_p(a=a)
+        inside = inner_partial(a=a)
+        return out, inside
+
+    with pytest.raises(Exception):
+        get_serializable(OrderedDict(), serialization_settings, outside_p)
+
+    # check the od todo
+    od = OrderedDict()
+    wf_1_spec = get_serializable(od, serialization_settings, my_wf_1)
+    tts, wspecs, lps = gather_dependent_entities(od)
+    tts = [t for t in tts.values()]
+    assert len(tts) == 1
+    assert len(wf_1_spec.template.nodes) == 2
+    assert wf_1_spec.template.nodes[0].task_node.reference_id.name == tts[0].id.name
+    assert wf_1_spec.template.nodes[1].task_node.reference_id.name == tts[0].id.name
+    assert wf_1_spec.template.nodes[0].inputs[0].binding.promise.var == "a"
+    assert wf_1_spec.template.nodes[0].inputs[1].binding.scalar is not None
+    assert wf_1_spec.template.nodes[0].inputs[2].binding.scalar is not None
+
+    @task
+    def get_str() -> str:
+        return "got str"
+
+    bind_c = partial(t1, c=2.7)
+
+    @workflow
+    def my_wf_2(a: int) -> int:
+        s = get_str()
+        inner_partial = partial(bind_c, b=s)
+        inside = inner_partial(a=a)
+        return inside
+
+    wf_2_spec = get_serializable(OrderedDict(), serialization_settings, my_wf_2)
+    assert len(wf_2_spec.template.nodes) == 2
+
+
+def test_map_task_types():
+    @task(cache=True, cache_version="1")
+    def t3(a: int, b: str, c: float) -> str:
+        return str(a) + b + str(c)
+
+    t3_bind_b1 = partial(t3, b="hello")
+    t3_bind_b2 = partial(t3, b="world")
+    t3_bind_c1 = partial(t3_bind_b1, c=3.14)
+    t3_bind_c2 = partial(t3_bind_b2, c=2.78)
+
+    mt1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1"))
+    mt2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1"))
+
+    @task
+    def print_lists(i: typing.List[str], j: typing.List[str]):
+        print(f"First: {i}")
+        print(f"Second: {j}")
+
+    @workflow
+    def wf_out(a: typing.List[int]):
+        i = mt1(a=a)
+        j = mt2(a=[3, 4, 5])
+        print_lists(i=i, j=j)
+
+    wf_out(a=[1, 2])
+
+    @workflow
+    def wf_in(a: typing.List[int]):
+        mt_in1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1"))
+        mt_in2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1"))
+        i = mt_in1(a=a)
+        j = mt_in2(a=[3, 4, 5])
+        print_lists(i=i, j=j)
+
+    wf_in(a=[1, 2])
+
+    od = OrderedDict()
+    wf_spec = get_serializable(od, serialization_settings, wf_in)
+    tts, _, _ = gather_dependent_entities(od)
+    assert len(tts) == 2  # one map task + the print task
+    assert (
+        wf_spec.template.nodes[0].task_node.reference_id.name == wf_spec.template.nodes[1].task_node.reference_id.name
+    )
+    assert wf_spec.template.nodes[0].inputs[0].binding.promise is not None  # comes from wf input
+    assert wf_spec.template.nodes[1].inputs[0].binding.collection is not None  # bound to static list
+    assert wf_spec.template.nodes[1].inputs[1].binding.scalar is not None  # these are bound
+    assert wf_spec.template.nodes[1].inputs[2].binding.scalar is not None
+
+
+def test_everything():
+    @task
+    def get_static_list() -> typing.List[float]:
+        return [3.14, 2.718]
+
+    @task
+    def get_list_of_pd(s: int) -> typing.List[pd.DataFrame]:
+        df1 = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+        df2 = pd.DataFrame({"Name": ["Rachel", "Eve", "Mary"], "Age": [22, 23, 24]})
+        if s == 2:
+            return [df1, df2]
+        else:
+            return [df1, df2, df1]
+
+    @task
+    def t3(a: int, b: str, c: typing.List[float], d: typing.List[float], a2: pd.DataFrame) -> str:
+        return str(a) + f"pdsize{len(a2)}" + b + str(c) + "&&" + str(d)
+
+    t3_bind_b1 = partial(t3, b="hello")
+    t3_bind_b2 = partial(t3, b="world")
+    t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0])
+
+    mt1 = map_task(t3_bind_c1)
+
+    mr = MapTaskResolver()
+    aa = mr.loader_args(serialization_settings, mt1)
+    # Check bound vars
+    aa = aa[1].split(",")
+    aa.sort()
+    assert aa == ["b", "c", "d"]
+
+    @task
+    def print_lists(i: typing.List[str], j: typing.List[str]) -> str:
+        print(f"First: {i}")
+        print(f"Second: {j}")
+        return f"{i}-{j}"
+
+    @dynamic
+    def dt1(a: typing.List[int], a2: typing.List[pd.DataFrame], sl: typing.List[float]) -> str:
+        i = mt1(a=a, a2=a2)
+        t3_bind_c2 = partial(t3_bind_b2, c=[1.0, 2.0, 3.0], d=sl)
+        mt_in2 = map_task(t3_bind_c2)
+        dfs = get_list_of_pd(s=3)
+        j = mt_in2(a=[3, 4, 5], a2=dfs)
+        return print_lists(i=i, j=j)
+
+    @workflow
+    def wf_dt(a: typing.List[int]) -> str:
+        sl = get_static_list()
+        dfs = get_list_of_pd(s=2)
+        return dt1(a=a, a2=dfs, sl=sl)
+
+    print(wf_dt(a=[1, 2]))
+    assert (
+        wf_dt(a=[1, 2])
+        == "['1pdsize2hello[6.674, 1.618, 6.626]&&[1.0]', '2pdsize3hello[6.674, 1.618, 6.626]&&[1.0]']-['3pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '4pdsize3world[1.0, 2.0, 3.0]&&[3.14, 2.718]', '5pdsize2world[1.0, 2.0, 3.0]&&[3.14, 2.718]']"
+    )
diff --git a/tests/flytekit/unit/core/tracker/d.py b/tests/flytekit/unit/core/tracker/d.py
index 9385b0f08d..c84e36fe59 100644
--- a/tests/flytekit/unit/core/tracker/d.py
+++ b/tests/flytekit/unit/core/tracker/d.py
@@ -9,3 +9,7 @@ def tasks():
 @task
 def foo():
     pass
+
+
+def inner_function(a: str) -> str:
+    return "hello"
diff --git a/tests/flytekit/unit/core/tracker/test_tracking.py b/tests/flytekit/unit/core/tracker/test_tracking.py
index 33ae18acd5..b33725436d 100644
--- a/tests/flytekit/unit/core/tracker/test_tracking.py
+++ b/tests/flytekit/unit/core/tracker/test_tracking.py
@@ -79,3 +79,10 @@ def test_extract_task_module(test_input, expected):
     except Exception:
         FeatureFlags.FLYTE_PYTHON_PACKAGE_ROOT = old
         raise
+
+
+local_task = task(d.inner_function)
+
+
+def test_local_task_wrap():
+    assert local_task.instantiated_in == "tests.flytekit.unit.core.tracker.test_tracking"
diff --git a/tests/flytekit/unit/extras/sqlite3/chinook.zip b/tests/flytekit/unit/extras/sqlite3/chinook.zip
new file mode 100644
index 0000000000..6dd568fa61
Binary files /dev/null and b/tests/flytekit/unit/extras/sqlite3/chinook.zip differ
diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py
index 40fc94a3d2..f8014f244b 100644
--- a/tests/flytekit/unit/extras/sqlite3/test_task.py
+++ b/tests/flytekit/unit/extras/sqlite3/test_task.py
@@ -1,3 +1,5 @@
+import os
+
 import pandas
 import pytest
@@ -10,8 +12,7 @@
 from flytekit.types.schema import FlyteSchema

 ctx = context_manager.FlyteContextManager.current_context()
-EXAMPLE_DB = ctx.file_access.get_random_local_path("chinook.zip")
-ctx.file_access.get_data("https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip", EXAMPLE_DB)
+EXAMPLE_DB = os.path.join(os.path.dirname(os.path.realpath(__file__)), "chinook.zip")

 # This task belongs to test_task_static but is intentionally here to help test tracking
 tk = SQLite3Task(
diff --git a/tests/flytekit/unit/types/directory/__init__.py b/tests/flytekit/unit/types/directory/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/flytekit/unit/types/directory/test_types.py b/tests/flytekit/unit/types/directory/test_types.py
new file mode 100644
index 0000000000..199b788733
--- /dev/null
+++ b/tests/flytekit/unit/types/directory/test_types.py
@@ -0,0 +1,31 @@
+import mock
+
+from flytekit import FlyteContext
+from flytekit.types.directory import FlyteDirectory
+from flytekit.types.file import FlyteFile
+
+
+def test_new_file_dir():
+    fd = FlyteDirectory(path="s3://my-bucket")
+    assert fd.sep == "/"
+    inner_dir = fd.new_dir("test")
+    assert inner_dir.path == "s3://my-bucket/test"
+    fd = FlyteDirectory(path="s3://my-bucket/")
+    inner_dir = fd.new_dir("test")
+    assert inner_dir.path == "s3://my-bucket/test"
+    f = inner_dir.new_file("test")
+    assert isinstance(f, FlyteFile)
+    assert f.path == "s3://my-bucket/test/test"
+
+
+def test_new_remote_dir():
+    fd = FlyteDirectory.new_remote()
+    assert FlyteContext.current_context().file_access.raw_output_prefix in fd.path
+
+
+@mock.patch("flytekit.types.directory.types.os.name", "nt")
+def test_sep_nt():
+    fd = FlyteDirectory(path="file://mypath")
+    assert fd.sep == "\\"
+    fd = FlyteDirectory(path="s3://mypath")
+    assert fd.sep == "/"
diff --git a/tests/flytekit/unit/types/file/__init__.py b/tests/flytekit/unit/types/file/__init__.py
new file mode 100644
index 0000000000..e69de29bb2