diff --git a/Dockerfile.dev b/Dockerfile.dev new file mode 100644 index 0000000000..f6baf63896 --- /dev/null +++ b/Dockerfile.dev @@ -0,0 +1,38 @@ +# This Dockerfile is here to help with end-to-end testing +# From flytekit +# $ docker build -f Dockerfile.dev --build-arg PYTHON_VERSION=3.10 -t localhost:30000/flytekittest:someversion . +# $ docker push localhost:30000/flytekittest:someversion +# From your test user code +# $ pyflyte run --image localhost:30000/flytekittest:someversion + +ARG PYTHON_VERSION +FROM python:${PYTHON_VERSION}-slim-buster + +MAINTAINER Flyte Team +LABEL org.opencontainers.image.source https://github.com/flyteorg/flytekit + +WORKDIR /root +ENV PYTHONPATH /root + +ARG VERSION +ARG DOCKER_IMAGE + +RUN apt-get update && apt-get install build-essential vim -y + +COPY . /code/flytekit +WORKDIR /code/flytekit + +# Pod tasks should be exposed in the default image +RUN pip install -e . +RUN pip install -e plugins/flytekit-k8s-pod +RUN pip install -e plugins/flytekit-deck-standard +RUN pip install scikit-learn + +ENV PYTHONPATH "/code/flytekit:/code/flytekit/plugins/flytekit-k8s-pod:/code/flytekit/plugins/flytekit-deck-standard:" + +WORKDIR /root +RUN useradd -u 1000 flytekit +RUN chown flytekit: /root +USER flytekit + +ENV FLYTE_INTERNAL_IMAGE "$DOCKER_IMAGE" diff --git a/doc-requirements.txt b/doc-requirements.txt index 98a84f41c9..19f20af9fc 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -216,7 +216,7 @@ frozenlist==1.3.3 # via # aiosignal # ray -fsspec==2023.1.0 +fsspec==2023.3.0 # via # -r doc-requirements.in # dask diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 0a992719d9..c2fc11816c 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -205,7 +205,6 @@ from flytekit.core.condition import conditional from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.gate import approve, sleep, wait_for_input from flytekit.core.hash import HashMethod @@ -223,7 +222,6 @@ from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck from flytekit.extras import pytorch, sklearn, tensorflow -from flytekit.extras.persistence import GCSPersistence, HttpPersistence, S3Persistence from flytekit.loggers import logger from flytekit.models.common import Annotations, AuthRole, Labels from flytekit.models.core.execution import WorkflowExecutionPhase diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index afed857a26..6485f3a9d5 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -352,7 +352,7 @@ class PlatformConfig(object): This object contains the settings to talk to a Flyte backend (the DNS location of your Admin server basically). 
:param endpoint: DNS for Flyte backend - :param insecure: Whether to use SSL + :param insecure: Whether or not to use SSL :param insecure_skip_verify: Whether to skip SSL certificate verification :param console_endpoint: endpoint for console if different from Flyte backend :param command: This command is executed to return a token using an external process diff --git a/flytekit/core/checkpointer.py b/flytekit/core/checkpointer.py index c1eb933ec6..4b4cfd16f3 100644 --- a/flytekit/core/checkpointer.py +++ b/flytekit/core/checkpointer.py @@ -126,7 +126,7 @@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): fa.upload_directory(str(cp), self._checkpoint_dest) else: fname = cp.stem + cp.suffix - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, fname) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), fname]) fa.upload(str(cp), rpath) return @@ -138,7 +138,7 @@ def save(self, cp: typing.Union[Path, str, io.BufferedReader]): with dest_cp.open("wb") as f: f.write(cp.read()) - rpath = fa._default_remote.construct_path(False, False, self._checkpoint_dest, self.TMP_DST_PATH) + rpath = fa._default_remote.sep.join([str(self._checkpoint_dest), self.TMP_DST_PATH]) fa.upload(str(dest_cp), rpath) def read(self) -> typing.Optional[bytes]: diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index fc8915e338..63914c13b2 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -84,7 +84,7 @@ class Builder(object): decks: List[Deck] raw_output_prefix: Optional[str] = None execution_id: typing.Optional[_identifier.WorkflowExecutionIdentifier] = None - working_dir: typing.Optional[utils.AutoDeletingTempDir] = None + working_dir: typing.Optional[str] = None checkpoint: typing.Optional[Checkpoint] = None execution_date: typing.Optional[datetime] = None logging: Optional[_logging.Logger] = None @@ -202,12 +202,10 @@ def raw_output_prefix(self) -> str: return self._raw_output_prefix @property - def working_directory(self) -> utils.AutoDeletingTempDir: + def working_directory(self) -> str: """ A handle to a special working directory for easily producing temporary files. 
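For illustration, a minimal sketch of how a task might use this directory now that it is a plain string path (the task and file names below are made up; flytekit.current_context() is the usual way to reach these execution parameters):

import os
import flytekit

@flytekit.task
def write_scratch_file() -> str:
    # working_directory is a plain string path here, so os.path.join works on it directly.
    scratch_dir = flytekit.current_context().working_directory
    out = os.path.join(scratch_dir, "scratch.txt")
    with open(out, "w") as f:
        f.write("hello")
    return out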
- TODO: Usage examples - TODO: This does not always return a AutoDeletingTempDir """ return self._working_directory diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index d48ce45ce1..8fb73ebd8c 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -21,296 +21,46 @@ UnsupportedPersistenceOp """ - import os import pathlib -import re -import shutil -import sys import tempfile import typing -from abc import abstractmethod -from shutil import copyfile -from typing import Dict, Union +from typing import Union, cast from uuid import UUID +import fsspec +from fsspec.utils import get_protocol + +from flytekit import configuration from flytekit.configuration import DataConfig from flytekit.core.utils import PerformanceTimer -from flytekit.exceptions.user import FlyteAssertion, FlyteValueException +from flytekit.exceptions.user import FlyteAssertion from flytekit.interfaces.random import random from flytekit.loggers import logger -CURRENT_PYTHON = sys.version_info[:2] -THREE_SEVEN = (3, 7) - - -class UnsupportedPersistenceOp(Exception): - """ - This exception is raised for all methods when a method is not supported by the data persistence layer - """ - - def __init__(self, message: str): - super(UnsupportedPersistenceOp, self).__init__(message) - - -class DataPersistence(object): - """ - Base abstract type for all DataPersistence operations. This can be extended using the flytekitplugins architecture - """ - - def __init__(self, name: str = "", default_prefix: typing.Optional[str] = None, **kwargs): - self._name = name - self._default_prefix = default_prefix - - @property - def name(self) -> str: - return self._name - - @property - def default_prefix(self) -> typing.Optional[str]: - return self._default_prefix - - def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]: - """ - Returns true if the given path exists, else false - """ - raise UnsupportedPersistenceOp(f"Listing a directory is not supported by the persistence plugin {self.name}") - - @abstractmethod - def exists(self, path: str) -> bool: - """ - Returns true if the given path exists, else false - """ - pass - - @abstractmethod - def get(self, from_path: str, to_path: str, recursive: bool = False): - """ - Retrieves data from from_path and writes to the given to_path (to_path is locally accessible) - """ - pass - - @abstractmethod - def put(self, from_path: str, to_path: str, recursive: bool = False): - """ - Stores data from from_path and writes to the given to_path (from_path is locally accessible) - """ - pass - - @abstractmethod - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: - """ - if add_protocol is true then is prefixed else - Constructs a path in the format *args - delim is dependent on the storage medium. - each of the args is joined with the delim - """ - pass - +# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 +# for key and secret +_FSSPEC_S3_KEY_ID = "key" +_FSSPEC_S3_SECRET = "secret" +_ANON = "anon" -class DataPersistencePlugins(object): - """ - DataPersistencePlugins is the core plugin registry that stores all DataPersistence plugins. To add a new plugin use - - .. code-block:: python - DataPersistencePlugins.register_plugin("s3:/", DataPersistence(), force=True|False) - - These plugins should always be registered. Follow the plugin registration guidelines to auto-discover your plugins. 
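With the registry above removed, protocol registration is handled by fsspec itself; a minimal sketch of wiring in a custom backend so that fsspec.filesystem() (and hence FileAccessProvider) can resolve it — the MyFileSystem class and the "myfs" protocol are made-up placeholders:

import fsspec
from fsspec.spec import AbstractFileSystem

class MyFileSystem(AbstractFileSystem):
    # Hypothetical backend; a real implementation would override _open, exists, ls, etc.
    protocol = "myfs"

# clobber=True plays the role force=True played in the old plugin registry:
# it allows overriding an implementation already registered for the protocol.
fsspec.register_implementation("myfs", MyFileSystem, clobber=True)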
- """ +def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): + kwargs = {} + if s3_cfg.access_key_id: + kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id - _PLUGINS: Dict[str, typing.Type[DataPersistence]] = {} + if s3_cfg.secret_access_key: + kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key - @classmethod - def register_plugin(cls, protocol: str, plugin: typing.Type[DataPersistence], force: bool = False): - """ - Registers the supplied plugin for the specified protocol if one does not already exist. - If one exists and force is default or False, then a TypeError is raised. - If one does not exist then it is registered - If one exists, but force == True then the existing plugin is overridden - """ - if protocol in cls._PLUGINS: - p = cls._PLUGINS[protocol] - if p == plugin: - return - if not force: - raise TypeError( - f"Cannot register plugin {plugin.name} for protocol {protocol} as plugin {p.name} is already" - f" registered for the same protocol. You can force register the new plugin by passing force=True" - ) + # S3fs takes this as a special arg + if s3_cfg.endpoint is not None: + kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} - cls._PLUGINS[protocol] = plugin + if anonymous: + kwargs[_ANON] = True - @staticmethod - def get_protocol(url: str) -> str: - # copy from fsspec https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391 - parts = re.split(r"(\:\:|\://)", url, 1) - if len(parts) > 1: - return parts[0] - logger.info("Setting protocol to file") - return "file" - - @classmethod - def find_plugin(cls, path: str) -> typing.Type[DataPersistence]: - """ - Returns a plugin for the given protocol, else raise a TypeError - """ - for k, p in cls._PLUGINS.items(): - if cls.get_protocol(path) == k.replace("://", "") or path.startswith(k): - return p - raise TypeError(f"No plugin found for matching protocol of path {path}") - - @classmethod - def print_all_plugins(cls): - """ - Prints all the plugins and their associated protocoles - """ - for k, p in cls._PLUGINS.items(): - print(f"Plugin {p.name} registered for protocol {k}") - - @classmethod - def is_supported_protocol(cls, protocol: str) -> bool: - """ - Returns true if the given protocol is has a registered plugin for it - """ - return protocol in cls._PLUGINS - - @classmethod - def supported_protocols(cls) -> typing.List[str]: - return [k for k in cls._PLUGINS.keys()] - - -class DiskPersistence(DataPersistence): - """ - The simplest form of persistence that is available with default flytekit - Disk-based persistence. - This will store all data locally and retrieve the data from local. This is helpful for local execution and simulating - runs. 
- """ - - PROTOCOL = "file://" - - def __init__(self, default_prefix: typing.Optional[str] = None, **kwargs): - super().__init__(name="local", default_prefix=default_prefix, **kwargs) - - @staticmethod - def _make_local_path(path): - if not os.path.exists(path): - try: - pathlib.Path(path).mkdir(parents=True, exist_ok=True) - except OSError: # Guard against race condition - if not os.path.isdir(path): - raise - - @staticmethod - def strip_file_header(path: str) -> str: - """ - Drops file:// if it exists from the file - """ - if path.startswith("file://"): - return path.replace("file://", "", 1) - return path - - def listdir(self, path: str, recursive: bool = False) -> typing.Generator[str, None, None]: - if not recursive: - files = os.listdir(self.strip_file_header(path)) - for f in files: - yield f - return - - for root, subdirs, files in os.walk(self.strip_file_header(path)): - for f in files: - yield os.path.join(root, f) - return - - def exists(self, path: str): - return os.path.exists(self.strip_file_header(path)) - - def copy_tree(self, from_path: str, to_path: str): - # TODO: Remove this code after support for 3.7 is dropped and inline this function back - # 3.7 doesn't have dirs_exist_ok - if CURRENT_PYTHON == THREE_SEVEN: - tp = pathlib.Path(self.strip_file_header(to_path)) - if tp.exists(): - if not tp.is_dir(): - raise FlyteValueException(tp, f"Target {tp} exists but is not a dir") - files = os.listdir(tp) - if len(files) != 0: - logger.debug(f"Deleting existing target dir {tp} with files {files}") - shutil.rmtree(tp) - shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path)) - else: - # copytree will overwrite existing files in the to_path - shutil.copytree(self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True) - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if from_path != to_path: - if recursive: - self.copy_tree(from_path, to_path) - else: - copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - if from_path != to_path: - if recursive: - self.copy_tree(from_path, to_path) - else: - # Emulate s3's flat storage by automatically creating directory path - self._make_local_path(os.path.dirname(self.strip_file_header(to_path))) - # Write the object to a local file in the temp local folder - copyfile(self.strip_file_header(from_path), self.strip_file_header(to_path)) - - def construct_path(self, _: bool, add_prefix: bool, *args: str) -> str: - # Ignore add_protocol for now. Only complicates things - if add_prefix: - prefix = self.default_prefix if self.default_prefix else "" - return os.path.join(prefix, *args) - return os.path.join(*args) - - -def stringify_path(filepath): - """ - Copied from `filesystem_spec `__ - - Attempt to convert a path-like object to a string. - Parameters - ---------- - filepath: object to be converted - Returns - ------- - filepath_str: maybe a string version of the object - Notes - ----- - Objects supporting the fspath protocol (Python 3.6+) are coerced - according to its __fspath__ method. - For backwards compatibility with older Python version, pathlib.Path - objects are specially coerced. - Any other object is passed through unchanged, which includes bytes, - strings, buffers, or anything else that's not even path-like. 
- """ - if isinstance(filepath, str): - return filepath - elif hasattr(filepath, "__fspath__"): - return filepath.__fspath__() - elif isinstance(filepath, pathlib.Path): - return str(filepath) - elif hasattr(filepath, "path"): - return filepath.path - else: - return filepath - - -def split_protocol(urlpath): - """ - Copied from `filesystem_spec `__ - Return protocol, path pair - """ - urlpath = stringify_path(urlpath) - if "://" in urlpath: - protocol, path = urlpath.split("://", 1) - if len(protocol) > 1: - # excludes Windows paths - return protocol, path - return None, urlpath + return kwargs class FileAccessProvider(object): @@ -335,13 +85,18 @@ def __init__( local_sandbox_dir_appended = os.path.join(local_sandbox_dir, "local_flytekit") self._local_sandbox_dir = pathlib.Path(local_sandbox_dir_appended) self._local_sandbox_dir.mkdir(parents=True, exist_ok=True) - self._local = DiskPersistence(default_prefix=local_sandbox_dir_appended) + self._local = fsspec.filesystem(None) - self._default_remote = DataPersistencePlugins.find_plugin(raw_output_prefix)( - default_prefix=raw_output_prefix, data_config=data_config - ) - self._raw_output_prefix = raw_output_prefix self._data_config = data_config if data_config else DataConfig.auto() + self._default_protocol = get_protocol(raw_output_prefix) + self._default_remote = cast(fsspec.AbstractFileSystem, self.get_filesystem(self._default_protocol)) + if os.name == "nt" and raw_output_prefix.startswith("file://"): + raise FlyteAssertion("Cannot use the file:// prefix on Windows.") + self._raw_output_prefix = ( + raw_output_prefix + if raw_output_prefix.endswith(self.sep(self._default_remote)) + else raw_output_prefix + self.sep(self._default_remote) + ) @property def raw_output_prefix(self) -> str: @@ -351,38 +106,120 @@ def raw_output_prefix(self) -> str: def data_config(self) -> DataConfig: return self._data_config + def get_filesystem( + self, protocol: typing.Optional[str] = None, anonymous: bool = False + ) -> typing.Optional[fsspec.AbstractFileSystem]: + if not protocol: + return self._default_remote + kwargs = {} # type: typing.Dict[str, typing.Any] + if protocol == "file": + kwargs = {"auto_mkdir": True} + elif protocol == "s3": + kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) + return fsspec.filesystem(protocol, **kwargs) # type: ignore + elif protocol == "gs": + if anonymous: + kwargs["token"] = _ANON + return fsspec.filesystem(protocol, **kwargs) # type: ignore + + # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. + if anonymous: + return None + + return fsspec.filesystem(protocol, **kwargs) # type: ignore + + def get_filesystem_for_path(self, path: str = "", anonymous: bool = False) -> fsspec.AbstractFileSystem: + protocol = get_protocol(path) + return self.get_filesystem(protocol, anonymous=anonymous) + @staticmethod def is_remote(path: Union[str, os.PathLike]) -> bool: """ - Deprecated. Lets find a replacement + Deprecated. Let's find a replacement """ - protocol, _ = split_protocol(path) + protocol = get_protocol(path) if protocol is None: return False return protocol != "file" @property def local_sandbox_dir(self) -> os.PathLike: + """ + This is a context based temp dir. 
+ """ return self._local_sandbox_dir @property - def local_access(self) -> DiskPersistence: + def local_access(self) -> fsspec.AbstractFileSystem: return self._local - def construct_random_path( - self, persist: DataPersistence, file_path_or_file_name: typing.Optional[str] = None - ) -> str: + @staticmethod + def strip_file_header(path: str, trim_trailing_sep: bool = False) -> str: """ - Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name + Drops file:// if it exists from the file """ - key = UUID(int=random.getrandbits(128)).hex - if file_path_or_file_name: - _, tail = os.path.split(file_path_or_file_name) - if tail: - return persist.construct_path(False, True, key, tail) - else: - logger.warning(f"No filename detected in {file_path_or_file_name}, generating random path") - return persist.construct_path(False, True, key) + if path.startswith("file://"): + return path.replace("file://", "", 1) + return path + + @staticmethod + def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: + f = os.path.join(f, "") + t = os.path.join(t, "") + return f, t + + def sep(self, file_system: typing.Optional[fsspec.AbstractFileSystem]) -> str: + if file_system is None or file_system.protocol == "file": + return os.sep + return file_system.sep + + def exists(self, path: str) -> bool: + try: + file_system = self.get_filesystem_for_path(path) + return file_system.exists(path) + except OSError as oe: + logger.debug(f"Error in exists checking {path} {oe}") + anon_fs = self.get_filesystem(get_protocol(path), anonymous=True) + if anon_fs is not None: + logger.debug(f"Attempting anonymous exists with {anon_fs}") + return anon_fs.exists(path) + raise oe + + def get(self, from_path: str, to_path: str, recursive: bool = False): + file_system = self.get_filesystem_for_path(from_path) + if recursive: + from_path, to_path = self.recursive_paths(from_path, to_path) + try: + if os.name == "nt" and file_system.protocol == "file" and recursive: + import shutil + + return shutil.copytree( + self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True + ) + return file_system.get(from_path, to_path, recursive=recursive) + except OSError as oe: + logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") + file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) + if file_system is not None: + logger.debug(f"Attempting anonymous get with {file_system}") + return file_system.get(from_path, to_path, recursive=recursive) + raise oe + + def put(self, from_path: str, to_path: str, recursive: bool = False): + file_system = self.get_filesystem_for_path(to_path) + from_path = self.strip_file_header(from_path) + if recursive: + # Only check this for the local filesystem + if file_system.protocol == "file" and not file_system.isdir(from_path): + raise FlyteAssertion(f"Source path {from_path} is not a directory") + if os.name == "nt" and file_system.protocol == "file": + import shutil + + return shutil.copytree( + self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True + ) + from_path, to_path = self.recursive_paths(from_path, to_path) + return file_system.put(from_path, to_path, recursive=recursive) def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = None) -> str: """ @@ -391,7 +228,20 @@ def get_random_remote_path(self, file_path_or_file_name: typing.Optional[str] = Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf 
file name """ - return self.construct_random_path(self._default_remote, file_path_or_file_name) + default_protocol = self._default_remote.protocol + if type(default_protocol) == list: + default_protocol = default_protocol[0] + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + sep = self.sep(self._default_remote) + tail = sep + tail if tail else tail + if default_protocol == "file": + # Special case the local case, users will not expect to see a file:// prefix + return self.strip_file_header(self.raw_output_prefix) + key + tail + + return self._default_remote.unstrip_protocol(self.raw_output_prefix + key + tail) def get_random_remote_directory(self): return self.get_random_remote_path(None) @@ -400,19 +250,19 @@ def get_random_local_path(self, file_path_or_file_name: typing.Optional[str] = N """ Use file_path_or_file_name, when you want a random directory, but want to preserve the leaf file name """ - return self.construct_random_path(self._local, file_path_or_file_name) + key = UUID(int=random.getrandbits(128)).hex + tail = "" + if file_path_or_file_name: + _, tail = os.path.split(file_path_or_file_name) + if tail: + return os.path.join(self._local_sandbox_dir, key, tail) + return os.path.join(self._local_sandbox_dir, key) def get_random_local_directory(self) -> str: _dir = self.get_random_local_path(None) pathlib.Path(_dir).mkdir(parents=True, exist_ok=True) return _dir - def exists(self, path: str) -> bool: - """ - checks if the given path exists - """ - return DataPersistencePlugins.find_plugin(path)().exists(path) - def download_directory(self, remote_path: str, local_path: str): """ Downloads directory from given remote to local path @@ -439,39 +289,34 @@ def upload_directory(self, local_path: str, remote_path: str): """ return self.put_data(local_path, remote_path, is_multipart=True) - def get_data(self, remote_path: str, local_path: str, is_multipart=False): + def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False): """ - :param Text remote_path: - :param Text local_path: - :param bool is_multipart: + :param remote_path: + :param local_path: + :param is_multipart: """ try: with PerformanceTimer(f"Copying ({remote_path} -> {local_path})"): pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - data_persistence_plugin = DataPersistencePlugins.find_plugin(remote_path) - data_persistence_plugin(data_config=self.data_config).get( - remote_path, local_path, recursive=is_multipart - ) + self.get(remote_path, to_path=local_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" f"Original exception: {str(ex)}" ) - def put_data(self, local_path: str, remote_path: str, is_multipart=False): + def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart: bool = False): """ The implication here is that we're always going to put data to the remote location, so we .remote to ensure we don't use the true local proxy if the remote path is a file:// - :param Text local_path: - :param Text remote_path: - :param bool is_multipart: + :param local_path: + :param remote_path: + :param is_multipart: """ try: with PerformanceTimer(f"Writing ({local_path} -> {remote_path})"): - DataPersistencePlugins.find_plugin(remote_path)(data_config=self.data_config).put( - local_path, remote_path, recursive=is_multipart - ) + self.put(cast(str, local_path), 
remote_path, recursive=is_multipart) except Exception as ex: raise FlyteAssertion( f"Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" @@ -479,9 +324,6 @@ def put_data(self, local_path: str, remote_path: str, is_multipart=False): ) from ex -DataPersistencePlugins.register_plugin("file://", DiskPersistence) -DataPersistencePlugins.register_plugin("/", DiskPersistence) - flyte_tmp_dir = tempfile.mkdtemp(prefix="flyte-") default_local_file_access_provider = FileAccessProvider( local_sandbox_dir=os.path.join(flyte_tmp_dir, "sandbox"), diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index cec59e7318..45ee4efa51 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -10,6 +10,11 @@ OUTPUT_DIR_JUPYTER_PREFIX = "jupyter" DECK_FILE_NAME = "deck.html" +try: + from IPython.core.display import HTML +except ImportError: + ... + class Deck: """ @@ -100,8 +105,6 @@ def _get_deck( deck_map = {deck.name: deck.html for deck in new_user_params.decks} raw_html = template.render(metadata=deck_map) if not ignore_jupyter and _ipython_check(): - from IPython.core.display import HTML - return HTML(raw_html) return raw_html diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py index f6635a4a57..7223d13523 100644 --- a/flytekit/extend/__init__.py +++ b/flytekit/extend/__init__.py @@ -29,8 +29,6 @@ PythonCustomizedContainerTask ExecutableTemplateShimTask ShimTaskExecutor - DataPersistence - DataPersistencePlugins """ from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -39,7 +37,7 @@ from flytekit.core.base_task import IgnoreOutputs, PythonTask, TaskResolverMixin from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.context_manager import ExecutionState, SecretsManager -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.interface import Interface from flytekit.core.promise import Promise from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask diff --git a/flytekit/extras/persistence/__init__.py b/flytekit/extras/persistence/__init__.py deleted file mode 100644 index a677632fd8..0000000000 --- a/flytekit/extras/persistence/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -======================= -DataPersistence Extras -======================= - -.. currentmodule:: flytekit.extras.persistence - -This module provides some default implementations of :py:class:`flytekit.DataPersistence`. These implementations -use command-line clients to download and upload data. The actual binaries need to be installed for these extras to work. -The binaries are not bundled with flytekit to keep it lightweight. - -Persistence Extras -=================== - -.. 
autosummary:: - :template: custom.rst - :toctree: generated/ - - GCSPersistence - HttpPersistence - S3Persistence -""" - -from flytekit.extras.persistence.gcs_gsutil import GCSPersistence -from flytekit.extras.persistence.http import HttpPersistence -from flytekit.extras.persistence.s3_awscli import S3Persistence diff --git a/flytekit/extras/persistence/gcs_gsutil.py b/flytekit/extras/persistence/gcs_gsutil.py deleted file mode 100644 index 0ddb600024..0000000000 --- a/flytekit/extras/persistence/gcs_gsutil.py +++ /dev/null @@ -1,115 +0,0 @@ -import os -import posixpath -import typing -from shutil import which as shell_which - -from flytekit.configuration import DataConfig, GCSConfig -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions.user import FlyteUserException -from flytekit.tools import subprocess - - -def _update_cmd_config_and_execute(cmd): - env = os.environ.copy() - return subprocess.check_call(cmd, env=env) - - -def _amend_path(path): - return posixpath.join(path, "*") if not path.endswith("*") else path - - -class GCSPersistence(DataPersistence): - """ - This DataPersistence plugin uses a preinstalled GSUtil binary in the container to download and upload data. - - The binary can be installed in multiple ways including simply, - - .. prompt:: - - pip install gsutil - - """ - - _GS_UTIL_CLI = "gsutil" - PROTOCOL = "gs://" - - def __init__(self, default_prefix: typing.Optional[str] = None, data_config: typing.Optional[DataConfig] = None): - super(GCSPersistence, self).__init__(name="gcs-gsutil", default_prefix=default_prefix) - self.gcs_cfg = data_config.gcs if data_config else GCSConfig.auto() - - @staticmethod - def _check_binary(): - """ - Make sure that the `gsutil` cli is present - """ - if not shell_which(GCSPersistence._GS_UTIL_CLI): - raise FlyteUserException("gsutil (gcloud cli) not found! Please install using `pip install gsutil`.") - - def _maybe_with_gsutil_parallelism(self, *gsutil_args): - """ - Check if we should run `gsutil` with the `-m` flag that enables - parallelism via multiple threads/processes. Additional tweaking of - this behavior can be achieved via the .boto configuration file. See: - https://cloud.google.com/storage/docs/boto-gsutil - """ - cmd = [GCSPersistence._GS_UTIL_CLI] - if self.gcs_cfg.gsutil_parallelism: - cmd.append("-m") - cmd.extend(gsutil_args) - - return cmd - - def exists(self, remote_path): - """ - :param Text remote_path: remote gs:// path - :rtype bool: whether the gs file exists or not - """ - GCSPersistence._check_binary() - - if not remote_path.startswith("gs://"): - raise ValueError("Not an GS Key. Please use FQN (GS ARN) of the format gs://...") - - cmd = [GCSPersistence._GS_UTIL_CLI, "-q", "stat", remote_path] - try: - _update_cmd_config_and_execute(cmd) - return True - except Exception: - return False - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if not from_path.startswith("gs://"): - raise ValueError("Not an GS Key. 
Please use FQN (GS ARN) of the format gs://...") - - GCSPersistence._check_binary() - if recursive: - cmd = self._maybe_with_gsutil_parallelism("cp", "-r", _amend_path(from_path), to_path) - else: - cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) - - return _update_cmd_config_and_execute(cmd) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - GCSPersistence._check_binary() - - if recursive: - cmd = self._maybe_with_gsutil_parallelism( - "cp", - "-r", - _amend_path(from_path), - to_path if to_path.endswith("/") else to_path + "/", - ) - else: - cmd = self._maybe_with_gsutil_parallelism("cp", from_path, to_path) - return _update_cmd_config_and_execute(cmd) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - paths = list(paths) # make type check happy - if add_prefix: - paths.insert(0, self.default_prefix) - path = "/".join(paths) - if add_protocol: - return f"{self.PROTOCOL}{path}" - return path - - -DataPersistencePlugins.register_plugin(GCSPersistence.PROTOCOL, GCSPersistence) diff --git a/flytekit/extras/persistence/http.py b/flytekit/extras/persistence/http.py deleted file mode 100644 index ce6079300d..0000000000 --- a/flytekit/extras/persistence/http.py +++ /dev/null @@ -1,84 +0,0 @@ -import base64 -import os -import pathlib - -import requests - -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions import user -from flytekit.loggers import logger -from flytekit.tools import script_mode - - -class HttpPersistence(DataPersistence): - """ - DataPersistence implementation for the HTTP protocol. only supports downloading from an http source. Uploads are - not supported currently. - """ - - PROTOCOL_HTTP = "http" - PROTOCOL_HTTPS = "https" - _HTTP_OK = 200 - _HTTP_FORBIDDEN = 403 - _HTTP_NOT_FOUND = 404 - ALLOWED_CODES = { - _HTTP_OK, - _HTTP_NOT_FOUND, - _HTTP_FORBIDDEN, - } - - def __init__(self, *args, **kwargs): - super(HttpPersistence, self).__init__(name="http/https", *args, **kwargs) - - def exists(self, path: str): - rsp = requests.head(path) - if rsp.status_code not in self.ALLOWED_CODES: - raise user.FlyteValueException( - rsp.status_code, - f"Data at {path} could not be checked for existence. Expected one of: {self.ALLOWED_CODES}", - ) - return rsp.status_code == self._HTTP_OK - - def get(self, from_path: str, to_path: str, recursive: bool = False): - if recursive: - raise user.FlyteAssertion("Reading data recursively from HTTP endpoint is not currently supported.") - rsp = requests.get(from_path) - if rsp.status_code != self._HTTP_OK: - raise user.FlyteValueException( - rsp.status_code, - "Request for data @ {} failed. 
Expected status code {}".format(from_path, type(self)._HTTP_OK), - ) - head, _ = os.path.split(to_path) - if head and head.startswith("/"): - logger.debug(f"HttpPersistence creating {head} so that parent dirs exist") - pathlib.Path(head).mkdir(parents=True, exist_ok=True) - with open(to_path, "wb") as writer: - writer.write(rsp.content) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - if recursive: - raise user.FlyteAssertion("Recursive writing data to HTTP endpoint is not currently supported.") - - md5, _ = script_mode.hash_file(from_path) - encoded_md5 = base64.b64encode(md5) - with open(from_path, "+rb") as local_file: - content = local_file.read() - content_length = len(content) - rsp = requests.put( - to_path, data=content, headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5} - ) - - if rsp.status_code != self._HTTP_OK: - raise user.FlyteValueException( - rsp.status_code, - f"Request to send data {to_path} failed.", - ) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - raise user.FlyteAssertion( - "There are multiple ways of creating http links / paths, this is not supported by the persistence layer" - ) - - -DataPersistencePlugins.register_plugin("http://", HttpPersistence) -DataPersistencePlugins.register_plugin("https://", HttpPersistence) diff --git a/flytekit/extras/persistence/s3_awscli.py b/flytekit/extras/persistence/s3_awscli.py deleted file mode 100644 index 0b00227ca0..0000000000 --- a/flytekit/extras/persistence/s3_awscli.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -import os as _os -import re as _re -import string as _string -import time -import typing -from shutil import which as shell_which -from typing import Dict, List, Optional - -from flytekit.configuration import DataConfig, S3Config -from flytekit.core.data_persistence import DataPersistence, DataPersistencePlugins -from flytekit.exceptions.user import FlyteUserException -from flytekit.loggers import logger -from flytekit.tools import subprocess - -S3_ANONYMOUS_FLAG = "--no-sign-request" -S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" -S3_SECRET_ACCESS_KEY_ENV_NAME = "AWS_SECRET_ACCESS_KEY" - - -def _update_cmd_config_and_execute(s3_cfg: S3Config, cmd: List[str]): - env = _os.environ.copy() - - if s3_cfg.enable_debug: - cmd.insert(1, "--debug") - - if s3_cfg.endpoint is not None: - cmd.insert(1, s3_cfg.endpoint) - cmd.insert(1, "--endpoint-url") - - if S3_ACCESS_KEY_ID_ENV_NAME not in os.environ: - if s3_cfg.access_key_id: - env[S3_ACCESS_KEY_ID_ENV_NAME] = s3_cfg.access_key_id - - if S3_SECRET_ACCESS_KEY_ENV_NAME not in os.environ: - if s3_cfg.secret_access_key: - env[S3_SECRET_ACCESS_KEY_ENV_NAME] = s3_cfg.secret_access_key - - retry = 0 - while True: - try: - try: - return subprocess.check_call(cmd, env=env) - except Exception as e: - if retry > 0: - logger.info(f"AWS command failed with error {e}, command: {cmd}, retry {retry}") - - logger.debug(f"Appending anonymous flag and retrying command {cmd}") - anonymous_cmd = cmd[:] # strings only, so this is deep enough - anonymous_cmd.insert(1, S3_ANONYMOUS_FLAG) - return subprocess.check_call(anonymous_cmd, env=env) - - except Exception as e: - logger.error(f"Exception when trying to execute {cmd}, reason: {str(e)}") - retry += 1 - if retry > s3_cfg.retries: - raise - secs = s3_cfg.backoff - logger.info(f"Sleeping before retrying again, after {secs.total_seconds()} seconds") - time.sleep(secs.total_seconds()) - logger.info("Retrying again") - - -def _extra_args(extra_args: 
Dict[str, str]) -> List[str]: - cmd = [] - if "ContentType" in extra_args: - cmd += ["--content-type", extra_args["ContentType"]] - if "ContentEncoding" in extra_args: - cmd += ["--content-encoding", extra_args["ContentEncoding"]] - if "ACL" in extra_args: - cmd += ["--acl", extra_args["ACL"]] - return cmd - - -class S3Persistence(DataPersistence): - """ - DataPersistence plugin for AWS S3 (and Minio). Use aws cli to manage the transfer. The binary needs to be installed - separately - - .. prompt:: - - pip install awscli - - """ - - PROTOCOL = "s3://" - _AWS_CLI = "aws" - _SHARD_CHARACTERS = [str(x) for x in range(10)] + list(_string.ascii_lowercase) - - def __init__(self, default_prefix: Optional[str] = None, data_config: typing.Optional[DataConfig] = None): - super().__init__(name="awscli-s3", default_prefix=default_prefix) - self.s3_cfg = data_config.s3 if data_config else S3Config.auto() - - @staticmethod - def _check_binary(): - """ - Make sure that the AWS cli is present - """ - if not shell_which(S3Persistence._AWS_CLI): - raise FlyteUserException("AWS CLI not found! Please install it with `pip install awscli`.") - - @staticmethod - def _split_s3_path_to_bucket_and_key(path: str) -> typing.Tuple[str, str]: - """ - splits a valid s3 uri into bucket and key - """ - path = path[len("s3://") :] - first_slash = path.index("/") - return path[:first_slash], path[first_slash + 1 :] - - def exists(self, remote_path): - """ - Given a remote path of the format s3://, checks if the remote file exists - """ - S3Persistence._check_binary() - - if not remote_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - bucket, file_path = self._split_s3_path_to_bucket_and_key(remote_path) - cmd = [ - S3Persistence._AWS_CLI, - "s3api", - "head-object", - "--bucket", - bucket, - "--key", - file_path, - ] - try: - _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - return True - except Exception as ex: - # The s3api command returns an error if the object does not exist. The error message contains - # the http status code: "An error occurred (404) when calling the HeadObject operation: Not Found" - # This is a best effort for returning if the object does not exist by searching - # for existence of (404) in the error message. This should not be needed when we get off the cli and use lib - if _re.search("(404)", str(ex)): - return False - else: - raise ex - - def get(self, from_path: str, to_path: str, recursive: bool = False): - S3Persistence._check_binary() - - if not from_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") - - if recursive: - cmd = [S3Persistence._AWS_CLI, "s3", "cp", "--recursive", from_path, to_path] - else: - cmd = [S3Persistence._AWS_CLI, "s3", "cp", from_path, to_path] - return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - - def put(self, from_path: str, to_path: str, recursive: bool = False): - extra_args = { - "ACL": "bucket-owner-full-control", - } - - if not to_path.startswith("s3://"): - raise ValueError("Not an S3 ARN. 
Please use FQN (S3 ARN) of the format s3://...") - - S3Persistence._check_binary() - cmd = [S3Persistence._AWS_CLI, "s3", "cp"] - if recursive: - cmd += ["--recursive"] - cmd.extend(_extra_args(extra_args)) - cmd += [from_path, to_path] - return _update_cmd_config_and_execute(cmd=cmd, s3_cfg=self.s3_cfg) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths: str) -> str: - paths = list(paths) # make type check happy - if add_prefix: - paths.insert(0, self.default_prefix) - path = "/".join(paths) - if add_protocol: - return f"{self.PROTOCOL}{path}" - return path - - -DataPersistencePlugins.register_plugin(S3Persistence.PROTOCOL, S3Persistence) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 03cc9a66e9..37baacef70 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -6,17 +6,20 @@ from __future__ import annotations import base64 -import functools import hashlib +import importlib import os import pathlib +import tempfile import time import typing import uuid +from base64 import b64encode from collections import OrderedDict from dataclasses import asdict, dataclass from datetime import datetime, timedelta +import requests from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest from flyteidl.core import literals_pb2 as literals_pb2 @@ -31,10 +34,15 @@ from flytekit.core.launch_plan import LaunchPlan from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceSpec +from flytekit.core.tracker import get_full_module_path from flytekit.core.type_engine import LiteralsResolver, TypeEngine from flytekit.core.workflow import WorkflowBase from flytekit.exceptions import user as user_exceptions -from flytekit.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException +from flytekit.exceptions.user import ( + FlyteEntityAlreadyExistsException, + FlyteEntityNotExistException, + FlyteValueException, +) from flytekit.loggers import remote_logger from flytekit.models import common as common_models from flytekit.models import filters as filter_models @@ -62,7 +70,7 @@ from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity from flytekit.tools.fast_registration import fast_package -from flytekit.tools.script_mode import fast_register_single_script, hash_file +from flytekit.tools.script_mode import compress_single_script, hash_file from flytekit.tools.translator import ( FlyteControlPlaneEntity, FlyteLocalEntity, @@ -728,7 +736,23 @@ def _upload_file( content_md5=md5_bytes, filename=to_upload.name, ) - self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url) + + encoded_md5 = b64encode(md5_bytes) + with open(str(to_upload), "+rb") as local_file: + content = local_file.read() + content_length = len(content) + rsp = requests.put( + upload_location.signed_url, + data=content, + headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5}, + ) + + if rsp.status_code != requests.codes["OK"]: + raise FlyteValueException( + rsp.status_code, + f"Request to send data {upload_location.signed_url} failed.", + ) + remote_logger.debug( f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}" ) @@ -795,16 +819,14 @@ def register_script( if image_config is None: image_config = ImageConfig.auto_default_image() - upload_location, md5_bytes = fast_register_single_script( - source_path, - module_name, - functools.partial( - 
self.client.get_upload_signed_url, - project=project or self.default_project, - domain=domain or self.default_domain, - filename="scriptmode.tar.gz", - ), - ) + with tempfile.TemporaryDirectory() as tmp_dir: + archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz")) + mod = importlib.import_module(module_name) + compress_single_script(source_path, str(archive_fname), get_full_module_path(mod, mod.__name__)) + md5_bytes, upload_native_url = self._upload_file( + archive_fname, project or self.default_project, domain or self.default_domain + ) + serialization_settings = SerializationSettings( project=project, domain=domain, @@ -813,7 +835,7 @@ def register_script( fast_serialization_settings=FastSerializationSettings( enabled=True, destination_dir=destination_dir, - distribution_location=upload_location.native_url, + distribution_location=upload_native_url, ), ) diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 29b617824c..1f3e31a382 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -1,6 +1,5 @@ import gzip import hashlib -import importlib import os import shutil import tarfile @@ -8,11 +7,6 @@ import typing from pathlib import Path -from flyteidl.service import dataproxy_pb2 as _data_proxy_pb2 - -from flytekit.core import context_manager -from flytekit.core.tracker import get_full_module_path - def compress_single_script(source_path: str, destination: str, full_module_name: str): """ @@ -96,24 +90,6 @@ def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo: return tar_info -def fast_register_single_script( - source_path: str, module_name: str, create_upload_location_fn: typing.Callable -) -> (_data_proxy_pb2.CreateUploadLocationResponse, bytes): - - # Open a temp directory and dump the contents of the digest. 
- with tempfile.TemporaryDirectory() as tmp_dir: - archive_fname = os.path.join(tmp_dir, "script_mode.tar.gz") - mod = importlib.import_module(module_name) - compress_single_script(source_path, archive_fname, get_full_module_path(mod, mod.__name__)) - - flyte_ctx = context_manager.FlyteContextManager.current_context() - md5, _ = hash_file(archive_fname) - upload_location = create_upload_location_fn(content_md5=md5) - flyte_ctx.file_access.put_data(archive_fname, upload_location.signed_url) - - return upload_location, md5 - - def hash_file(file_path: typing.Union[os.PathLike, str]) -> (bytes, str): """ Hash a file and produce a digest to be used as a version diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 39f8d11e24..ae3e8a00d9 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -1,12 +1,18 @@ import os import typing +from pathlib import Path from typing import TypeVar import pandas as pd import pyarrow as pa import pyarrow.parquet as pq +from botocore.exceptions import NoCredentialsError +from fsspec.core import split_protocol, strip_protocol +from fsspec.utils import get_protocol -from flytekit import FlyteContext +from flytekit import FlyteContext, logger +from flytekit.configuration import DataConfig +from flytekit.core.data_persistence import s3_setup_args from flytekit.deck import TopFrameRenderer from flytekit.deck.renderer import ArrowRenderer from flytekit.models import literals @@ -23,6 +29,15 @@ T = TypeVar("T") +def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]: + protocol = get_protocol(uri) + if protocol == "s3": + kwargs = s3_setup_args(cfg.s3, anon) + if kwargs: + return kwargs + return None + + class PandasToParquetEncodingHandler(StructuredDatasetEncoder): def __init__(self): super().__init__(pd.DataFrame, None, PARQUET) @@ -33,6 +48,26 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: + uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + if not ctx.file_access.is_remote(uri): + Path(uri).mkdir(parents=True, exist_ok=True) + path = os.path.join(uri, f"{0:05}") + df = typing.cast(pd.DataFrame, structured_dataset.dataframe) + df.to_parquet( + path, + coerce_timestamps="us", + allow_truncated_timestamps=False, + storage_options=get_storage_options(ctx.file_access.data_config, path), + ) + structured_dataset_type.format = PARQUET + return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) + + def ddencode( + self, + ctx: FlyteContext, + structured_dataset: StructuredDataset, + structured_dataset_type: StructuredDatasetType, + ) -> literals.StructuredDataset: path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() df = typing.cast(pd.DataFrame, structured_dataset.dataframe) @@ -53,6 +88,24 @@ def decode( ctx: FlyteContext, flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, + ) -> pd.DataFrame: + uri = flyte_value.uri + columns = None + kwargs = get_storage_options(ctx.file_access.data_config, uri) + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] + try: + return pd.read_parquet(uri, columns=columns, storage_options=kwargs) + except 
NoCredentialsError: + logger.debug("S3 source detected, attempting anonymous S3 access") + kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) + return pd.read_parquet(uri, columns=columns, storage_options=kwargs) + + def dcccecode( + self, + ctx: FlyteContext, + flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> pd.DataFrame: path = flyte_value.uri local_dir = ctx.file_access.get_random_local_directory() @@ -73,13 +126,13 @@ def encode( structured_dataset: StructuredDataset, structured_dataset_type: StructuredDatasetType, ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_path() - df = structured_dataset.dataframe - local_dir = ctx.file_access.get_random_local_directory() - local_path = os.path.join(local_dir, f"{0:05}") - pq.write_table(df, local_path) - ctx.file_access.upload_directory(local_dir, path) - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) + uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() + if not ctx.file_access.is_remote(uri): + Path(uri).mkdir(parents=True, exist_ok=True) + path = os.path.join(uri, f"{0:05}") + filesystem = ctx.file_access.get_filesystem_for_path(path) + pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem) + return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) class ParquetToArrowDecodingHandler(StructuredDatasetDecoder): @@ -92,13 +145,23 @@ def decode( flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: - path = flyte_value.uri - local_dir = ctx.file_access.get_random_local_directory() - ctx.file_access.get_data(path, local_dir, is_multipart=True) + uri = flyte_value.uri + if not ctx.file_access.is_remote(uri): + Path(uri).parent.mkdir(parents=True, exist_ok=True) + _, path = split_protocol(uri) + + columns = None if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - return pq.read_table(local_dir, columns=columns) - return pq.read_table(local_dir) + try: + fs = ctx.file_access.get_filesystem_for_path(uri) + return pq.read_table(path, filesystem=fs, columns=columns) + except NoCredentialsError as e: + logger.debug("S3 source detected, attempting anonymous S3 access") + fs = ctx.file_access.get_filesystem_for_path(uri, anonymous=True) + if fs is not None: + return pq.read_table(path, filesystem=fs, columns=columns) + raise e StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 90755c8cc5..9b4951e084 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -12,11 +12,11 @@ import pandas as pd import pyarrow as pa from dataclasses_json import config, dataclass_json +from fsspec.utils import get_protocol from marshmallow import fields from typing_extensions import Annotated, TypeAlias, get_args, get_origin from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence from flytekit.core.type_engine import TypeEngine, 
TypeTransformer from flytekit.deck.renderer import Renderable from flytekit.loggers import logger @@ -34,6 +34,7 @@ # Storage formats PARQUET: StructuredDatasetFormat = "parquet" GENERIC_FORMAT: StructuredDatasetFormat = "" +GENERIC_PROTOCOL: str = "generic protocol" @dataclass_json @@ -74,6 +75,7 @@ def __init__( self._literal_sd: Optional[literals.StructuredDataset] = None # Not meant for users to set, will be set by an open() call self._dataframe_type: Optional[DF] = None # type: ignore + self._already_uploaded = False @property def dataframe(self) -> Optional[DF]: @@ -270,11 +272,6 @@ def decode( raise NotImplementedError -def protocol_prefix(uri: str) -> str: - p = DataPersistencePlugins.get_protocol(uri) - return p - - def convert_schema_type_to_structured_dataset_type( column_type: int, ) -> int: @@ -336,42 +333,54 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]): @classmethod def _finder(cls, handler_map, df_type: Type, protocol: str, format: str): - # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler - # if missing, see if there's a generic format handler. Error if missing. - # If the incoming format requested is the generic format (""), then see if it's present, - # if not, look to see if there is a default format for the df_type and a handler for that format. - # if still missing, look to see if there's only _one_ handler for that type, if so then use that. - if format != GENERIC_FORMAT: - try: - return handler_map[df_type][protocol][format] - except KeyError: - try: - return handler_map[df_type][protocol][GENERIC_FORMAT] - except KeyError: - ... - else: - try: - return handler_map[df_type][protocol][GENERIC_FORMAT] - except KeyError: - if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]: - hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]] - logger.debug( - f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}" - f" using the generic handler {hh} instead." - ) - return hh - if len(handler_map[df_type][protocol]) == 1: - hh = list(handler_map[df_type][protocol].values())[0] - logger.debug( - f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}" - ) - return hh + # If there's an exact match, then we should use it. + try: + return handler_map[df_type][protocol][format] + except KeyError: + ... + + fsspec_handler = None + protocol_specific_handler = None + single_handler = None + default_format = cls.DEFAULT_FORMATS.get(df_type, None) + + try: + fss_handlers = handler_map[df_type]["fsspec"] + if format in fss_handlers: + fsspec_handler = fss_handlers[format] + elif GENERIC_FORMAT in fss_handlers: + fsspec_handler = fss_handlers[GENERIC_FORMAT] + else: + if default_format and default_format in fss_handlers and format == GENERIC_FORMAT: + fsspec_handler = fss_handlers[default_format] else: - logger.warning( - f"Did not automatically pick a handler for {df_type}," - f" more than one detected {handler_map[df_type][protocol].keys()}" - ) - raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|") + if len(fss_handlers) == 1 and format == GENERIC_FORMAT: + single_handler = list(fss_handlers.values())[0] + else: + ... + except KeyError: + ... 
+ + try: + protocol_handlers = handler_map[df_type][protocol] + if GENERIC_FORMAT in protocol_handlers: + protocol_specific_handler = protocol_handlers[GENERIC_FORMAT] + else: + if default_format and default_format in protocol_handlers: + protocol_specific_handler = protocol_handlers[default_format] + else: + if len(protocol_handlers) == 1: + single_handler = list(protocol_handlers.values())[0] + else: + ... + + except KeyError: + ... + + if protocol_specific_handler or fsspec_handler or single_handler: + return protocol_specific_handler or fsspec_handler or single_handler + else: + raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|") @classmethod def get_encoder(cls, df_type: Type, protocol: str, format: str): @@ -436,18 +445,12 @@ def register( if h.protocol is None: if default_for_type: raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.") - for persistence_protocol in DataPersistencePlugins.supported_protocols(): - # TODO: Clean this up when we get to replacing the persistence layer. - # The behavior of the protocols given in the supported_protocols and is_supported_protocol - # is not actually the same as the one returned in get_protocol. - stripped = DataPersistencePlugins.get_protocol(persistence_protocol) - logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}") - try: - cls.register_for_protocol( - h, stripped, False, override, default_format_for_type, default_storage_for_type - ) - except DuplicateHandlerError: - logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate") + try: + cls.register_for_protocol( + h, "fsspec", False, override, default_format_for_type, default_storage_for_type + ) + except DuplicateHandlerError: + logger.debug(f"Skipping generic fsspec protocol for handler {h} because duplicate") elif h.protocol == "": raise ValueError(f"Use None instead of empty string for registering handler {h}") @@ -470,8 +473,7 @@ def register_for_protocol( See the main register function instead. """ if protocol == "/": - # TODO: Special fix again, because get_protocol returns file, instead of file:// - protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL) + protocol = "file" lowest_level = cls._handler_finder(h, protocol) if h.supported_format in lowest_level and override is False: raise DuplicateHandlerError( @@ -542,6 +544,8 @@ def to_literal( # def t1(dataset: Annotated[StructuredDataset, my_cols]) -> Annotated[StructuredDataset, my_cols]: # return dataset if python_val._literal_sd is not None: + if python_val._already_uploaded: + return Literal(scalar=Scalar(structured_dataset=python_val._literal_sd)) if python_val.dataframe is not None: raise ValueError( f"Shouldn't have specified both literal {python_val._literal_sd} and dataframe {python_val.dataframe}" @@ -593,7 +597,7 @@ def _protocol_from_type_or_prefix(self, ctx: FlyteContext, df_type: Type, uri: O if df_type in self.DEFAULT_PROTOCOLS: return self.DEFAULT_PROTOCOLS[df_type] else: - protocol = protocol_prefix(uri or ctx.file_access.raw_output_prefix) + protocol = get_protocol(uri or ctx.file_access.raw_output_prefix) logger.debug( f"No default protocol for type {df_type} found, using {protocol} from output prefix {ctx.file_access.raw_output_prefix}" ) @@ -622,7 +626,10 @@ def encode( # Note that this will always be the same as the incoming format except for when the fallback handler # with a format of "" is used. 
sd_model.metadata._structured_dataset_type.format = handler.supported_format - return Literal(scalar=Scalar(structured_dataset=sd_model)) + lit = Literal(scalar=Scalar(structured_dataset=sd_model)) + sd._literal_sd = sd_model + sd._already_uploaded = True + return lit def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset @@ -769,7 +776,7 @@ def open_as( :param updated_metadata: New metadata type, since it might be different from the metadata in the literal. :return: dataframe. It could be pandas dataframe or arrow table, etc. """ - protocol = protocol_prefix(sd.uri) + protocol = get_protocol(sd.uri) decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format) result = decoder.decode(ctx, sd, updated_metadata) if isinstance(result, types.GeneratorType): @@ -783,7 +790,7 @@ def iter_as( df_type: Type[DF], updated_metadata: StructuredDatasetMetadata, ) -> typing.Iterator[DF]: - protocol = protocol_prefix(sd.uri) + protocol = get_protocol(sd.uri) decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format] result: Union[DF, typing.Iterator[DF]] = decoder.decode(ctx, sd, updated_metadata) if not isinstance(result, types.GeneratorType): diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py index 68ee456ed6..e69de29bb2 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/__init__.py @@ -1,53 +0,0 @@ -""" -.. currentmodule:: flytekitplugins.fsspec - -This package contains things that are useful when extending Flytekit. - -.. autosummary:: - :template: custom.rst - :toctree: generated/ - - ArrowToParquetEncodingHandler - FSSpecPersistence - PandasToParquetEncodingHandler - ParquetToArrowDecodingHandler - ParquetToPandasDecodingHandler -""" - -__all__ = [ - "ArrowToParquetEncodingHandler", - "FSSpecPersistence", - "PandasToParquetEncodingHandler", - "ParquetToArrowDecodingHandler", - "ParquetToPandasDecodingHandler", -] - -import importlib - -from flytekit import StructuredDatasetTransformerEngine, logger - -from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler -from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler -from .persist import FSSpecPersistence - -S3 = "s3" -ABFS = "abfs" -GCS = "gs" - - -def _register(protocol: str): - logger.info(f"Registering fsspec {protocol} implementations and overriding default structured encoder/decoder.") - StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), True, True) - StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), True, True) - - -if importlib.util.find_spec("adlfs"): - _register(ABFS) - -if importlib.util.find_spec("s3fs"): - _register(S3) - -if importlib.util.find_spec("gcsfs"): - _register(GCS) diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py deleted file mode 100644 index ec8d5f975e..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -import typing -from pathlib import Path - -import pyarrow as pa -import 
pyarrow.parquet as pq -from botocore.exceptions import NoCredentialsError -from flytekitplugins.fsspec.persist import FSSpecPersistence -from fsspec.core import split_protocol, strip_protocol - -from flytekit import FlyteContext, logger -from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.structured_dataset import ( - PARQUET, - StructuredDataset, - StructuredDatasetDecoder, - StructuredDatasetEncoder, -) - - -class ArrowToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - if not ctx.file_access.is_remote(uri): - Path(uri).mkdir(parents=True, exist_ok=True) - path = os.path.join(uri, f"{0:05}") - fp = FSSpecPersistence(data_config=ctx.file_access.data_config) - filesystem = fp.get_filesystem(path) - pq.write_table(structured_dataset.dataframe, strip_protocol(path), filesystem=filesystem) - return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class ParquetToArrowDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pa.Table, protocol, PARQUET) - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pa.Table: - uri = flyte_value.uri - if not ctx.file_access.is_remote(uri): - Path(uri).parent.mkdir(parents=True, exist_ok=True) - _, path = split_protocol(uri) - - columns = None - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - try: - fp = FSSpecPersistence(data_config=ctx.file_access.data_config) - fs = fp.get_filesystem(uri) - return pq.read_table(path, filesystem=fs, columns=columns) - except NoCredentialsError as e: - logger.debug("S3 source detected, attempting anonymous S3 access") - fs = FSSpecPersistence.get_anonymous_filesystem(uri) - if fs is not None: - return pq.read_table(path, filesystem=fs, columns=columns) - raise e diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py deleted file mode 100644 index e4986ed9f6..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py +++ /dev/null @@ -1,76 +0,0 @@ -import os -import typing -from pathlib import Path - -import pandas as pd -from botocore.exceptions import NoCredentialsError -from flytekitplugins.fsspec.persist import FSSpecPersistence, s3_setup_args - -from flytekit import FlyteContext, logger -from flytekit.configuration import DataConfig -from flytekit.models import literals -from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.structured_dataset import ( - PARQUET, - StructuredDataset, - StructuredDatasetDecoder, - StructuredDatasetEncoder, -) - - -def get_storage_options(cfg: DataConfig, uri: str) -> typing.Optional[typing.Dict]: - protocol = FSSpecPersistence.get_protocol(uri) - if protocol == "s3": - kwargs 
= s3_setup_args(cfg.s3) - if kwargs: - return kwargs - return None - - -class PandasToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - uri = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory() - if not ctx.file_access.is_remote(uri): - Path(uri).mkdir(parents=True, exist_ok=True) - path = os.path.join(uri, f"{0:05}") - df = typing.cast(pd.DataFrame, structured_dataset.dataframe) - df.to_parquet( - path, - coerce_timestamps="us", - allow_truncated_timestamps=False, - storage_options=get_storage_options(ctx.file_access.data_config, path), - ) - structured_dataset_type.format = PARQUET - return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class ParquetToPandasDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(pd.DataFrame, protocol, PARQUET) - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> pd.DataFrame: - uri = flyte_value.uri - columns = None - kwargs = get_storage_options(ctx.file_access.data_config, uri) - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] - try: - return pd.read_parquet(uri, columns=columns, storage_options=kwargs) - except NoCredentialsError: - logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs["anon"] = True - return pd.read_parquet(uri, columns=columns, storage_options=kwargs) diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py deleted file mode 100644 index b890b3cc6c..0000000000 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/persist.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -import typing - -import fsspec -from fsspec.registry import known_implementations - -from flytekit.configuration import DataConfig, S3Config -from flytekit.extend import DataPersistence, DataPersistencePlugins -from flytekit.loggers import logger - -# Refer to https://github.com/fsspec/s3fs/blob/50bafe4d8766c3b2a4e1fc09669cf02fb2d71454/s3fs/core.py#L198 -# for key and secret -_FSSPEC_S3_KEY_ID = "key" -_FSSPEC_S3_SECRET = "secret" - - -def s3_setup_args(s3_cfg: S3Config): - kwargs = {} - if s3_cfg.access_key_id: - kwargs[_FSSPEC_S3_KEY_ID] = s3_cfg.access_key_id - - if s3_cfg.secret_access_key: - kwargs[_FSSPEC_S3_SECRET] = s3_cfg.secret_access_key - - # S3fs takes this as a special arg - if s3_cfg.endpoint is not None: - kwargs["client_kwargs"] = {"endpoint_url": s3_cfg.endpoint} - - return kwargs - - -class FSSpecPersistence(DataPersistence): - """ - This DataPersistence plugin uses fsspec to perform the IO. - NOTE: The put is not as performant as it can be for multiple files because of - - https://github.com/intake/filesystem_spec/issues/724. 
Once this bug is fixed, we can remove the `HACK` in the put - method - """ - - def __init__(self, default_prefix=None, data_config: typing.Optional[DataConfig] = None): - super(FSSpecPersistence, self).__init__(name="fsspec-persistence", default_prefix=default_prefix) - self.default_protocol = self.get_protocol(default_prefix) - self._data_cfg = data_config if data_config else DataConfig.auto() - - @staticmethod - def get_protocol(path: typing.Optional[str] = None): - if path: - return DataPersistencePlugins.get_protocol(path) - logger.info("Setting protocol to file") - return "file" - - def get_filesystem(self, path: str) -> fsspec.AbstractFileSystem: - protocol = FSSpecPersistence.get_protocol(path) - kwargs = {} - if protocol == "file": - kwargs = {"auto_mkdir": True} - elif protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) - return fsspec.filesystem(protocol, **kwargs) # type: ignore - - def get_anonymous_filesystem(self, path: str) -> typing.Optional[fsspec.AbstractFileSystem]: - protocol = FSSpecPersistence.get_protocol(path) - if protocol == "s3": - kwargs = s3_setup_args(self._data_cfg.s3) - anonymous_fs = fsspec.filesystem(protocol, anon=True, **kwargs) # type: ignore - return anonymous_fs - return None - - @staticmethod - def recursive_paths(f: str, t: str) -> typing.Tuple[str, str]: - if not f.endswith("*"): - f = os.path.join(f, "*") - if not t.endswith("/"): - t += "/" - return f, t - - def exists(self, path: str) -> bool: - try: - fs = self.get_filesystem(path) - return fs.exists(path) - except OSError as oe: - logger.debug(f"Error in exists checking {path} {oe}") - fs = self.get_anonymous_filesystem(path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 exists check") - return fs.exists(path) - raise oe - - def get(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(from_path) - if recursive: - from_path, to_path = self.recursive_paths(from_path, to_path) - try: - return fs.get(from_path, to_path, recursive=recursive) - except OSError as oe: - logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") - fs = self.get_anonymous_filesystem(from_path) - if fs is not None: - logger.debug("S3 source detected, attempting anonymous S3 access") - return fs.get(from_path, to_path, recursive=recursive) - raise oe - - def put(self, from_path: str, to_path: str, recursive: bool = False): - fs = self.get_filesystem(to_path) - if recursive: - from_path, to_path = self.recursive_paths(from_path, to_path) - # BEGIN HACK! - # Once https://github.com/intake/filesystem_spec/issues/724 is fixed, delete the special recursive handling - from fsspec.implementations.local import LocalFileSystem - from fsspec.utils import other_paths - - lfs = LocalFileSystem() - try: - lpaths = lfs.expand_path(from_path, recursive=recursive) - except FileNotFoundError: - # In some cases, there is no file in the original directory, so we just skip copying the file to the remote path - logger.debug(f"there is no file in the {from_path}") - return - rpaths = other_paths(lpaths, to_path) - for l, r in zip(lpaths, rpaths): - fs.put_file(l, r) - return - # END OF HACK!! 
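# A minimal sketch, assuming the fsspec>=2023.3.0 floor this PR adds to setup.py:
# AbstractFileSystem.put copies directory trees on its own, which is why the per-file
# "HACK" in the removed put() above is not carried over to the new fsspec-backed code
# paths. This mirrors the test_local_raw_fsspec test added later in this diff.
import os
import tempfile

import fsspec

local = fsspec.filesystem("file")
src = os.path.join(tempfile.mkdtemp(), "source", "")
local.mkdir(os.path.join(src, "nested"))
local.touch(os.path.join(src, "original.txt"))
local.touch(os.path.join(src, "nested", "more.txt"))

dest = os.path.join(tempfile.mkdtemp(), "doesnotexist")
local.put(src, dest, recursive=True)   # recursive copy, no manual expand/put_file loop
assert len(local.find(dest)) == 2      # both files arrive under the destination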
- return fs.put(from_path, to_path, recursive=recursive) - - def construct_path(self, add_protocol: bool, add_prefix: bool, *paths) -> str: - path_list = list(paths) # make type check happy - if add_prefix: - path_list.insert(0, self.default_prefix) # type: ignore - path = "/".join(path_list) - if add_protocol: - return f"{self.default_protocol}://{path}" - return typing.cast(str, path) - - -def _register(): - logger.info("Registering fsspec known implementations and overriding all default implementations for persistence.") - DataPersistencePlugins.register_plugin("/", FSSpecPersistence, force=True) - for k, v in known_implementations.items(): - DataPersistencePlugins.register_plugin(f"{k}://", FSSpecPersistence, force=True) - - -# Registering all plugins -_register() diff --git a/plugins/flytekit-data-fsspec/setup.py b/plugins/flytekit-data-fsspec/setup.py index a7920d1eeb..0ceae3ac1b 100644 --- a/plugins/flytekit-data-fsspec/setup.py +++ b/plugins/flytekit-data-fsspec/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-data-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "fsspec<=2023.1", "botocore>=1.7.48", "pandas>=1.2.0"] +plugin_requires = [] __version__ = "0.0.0+develop" @@ -13,7 +13,7 @@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="This package data-plugins for flytekit, that are powered by fsspec", + description="This is a deprecated plugin as of flytekit 1.5", url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-data-fsspec", long_description=open("README.md").read(), long_description_content_type="text/markdown", @@ -22,9 +22,9 @@ install_requires=plugin_requires, extras_require={ # https://github.com/fsspec/filesystem_spec/blob/master/setup.py#L36 - "abfs": ["adlfs>=2022.2.0"], - "aws": ["s3fs>=2021.7.0"], - "gcp": ["gcsfs>=2021.7.0"], + "abfs": [], + "aws": [], + "gcp": [], }, license="apache2", python_requires=">=3.8", @@ -41,5 +41,4 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py b/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py deleted file mode 100644 index 434a763a93..0000000000 --- a/plugins/flytekit-data-fsspec/tests/test_basic_dfs.py +++ /dev/null @@ -1,44 +0,0 @@ -import pandas as pd -import pyarrow as pa -from flytekitplugins.fsspec.pandas import get_storage_options - -from flytekit import kwtypes, task -from flytekit.configuration import DataConfig, S3Config - -try: - from typing import Annotated -except ImportError: - from typing_extensions import Annotated - - -def test_get_storage_options(): - endpoint = "https://s3.amazonaws.com" - - options = get_storage_options(DataConfig(s3=S3Config(endpoint=endpoint)), "s3://bucket/somewhere") - assert options == {"client_kwargs": {"endpoint_url": endpoint}} - - options = get_storage_options(DataConfig(), "/tmp/file") - assert options is None - - -cols = kwtypes(Name=str, Age=int) -subset_cols = kwtypes(Name=str) - - -@task -def t1( - df1: Annotated[pd.DataFrame, cols], df2: Annotated[pa.Table, cols] -) -> (Annotated[pd.DataFrame, subset_cols], Annotated[pa.Table, subset_cols]): - return df1, df2 - - -def test_structured_dataset_wf(): - pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - pa_df = pa.Table.from_pandas(pd_df) - - subset_pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"]}) - subset_pa_df = 
pa.Table.from_pandas(subset_pd_df) - - df1, df2 = t1(df1=pd_df, df2=pa_df) - assert df1.equals(subset_pd_df) - assert df2.equals(subset_pa_df) diff --git a/plugins/flytekit-data-fsspec/tests/test_persist.py b/plugins/flytekit-data-fsspec/tests/test_persist.py deleted file mode 100644 index 8e87c9c5eb..0000000000 --- a/plugins/flytekit-data-fsspec/tests/test_persist.py +++ /dev/null @@ -1,183 +0,0 @@ -import os -import pathlib -import tempfile - -import mock -from flytekitplugins.fsspec.persist import FSSpecPersistence, s3_setup_args -from fsspec.implementations.local import LocalFileSystem - -from flytekit.configuration import S3Config - - -def test_s3_setup_args(): - kwargs = s3_setup_args(S3Config()) - assert kwargs == {} - - kwargs = s3_setup_args(S3Config(endpoint="http://localhost:30084")) - assert kwargs == {"client_kwargs": {"endpoint_url": "http://localhost:30084"}} - - kwargs = s3_setup_args(S3Config(access_key_id="access")) - assert kwargs == {"key": "access"} - - -@mock.patch.dict(os.environ, {}, clear=True) -def test_s3_setup_args_env_empty(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {} - - -@mock.patch.dict( - os.environ, - { - "AWS_ACCESS_KEY_ID": "ignore-user", - "AWS_SECRET_ACCESS_KEY": "ignore-secret", - "FLYTE_AWS_ACCESS_KEY_ID": "flyte", - "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", - }, - clear=True, -) -def test_s3_setup_args_env_both(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": "flyte-secret"} - - -@mock.patch.dict( - os.environ, - { - "FLYTE_AWS_ACCESS_KEY_ID": "flyte", - "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", - }, - clear=True, -) -def test_s3_setup_args_env_flyte(): - kwargs = s3_setup_args(S3Config.auto()) - assert kwargs == {"key": "flyte", "secret": "flyte-secret"} - - -@mock.patch.dict( - os.environ, - { - "AWS_ACCESS_KEY_ID": "ignore-user", - "AWS_SECRET_ACCESS_KEY": "ignore-secret", - }, - clear=True, -) -def test_s3_setup_args_env_aws(): - kwargs = s3_setup_args(S3Config.auto()) - # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default - assert kwargs == {} - - -def test_get_protocol(): - assert FSSpecPersistence.get_protocol("s3://abc") == "s3" - assert FSSpecPersistence.get_protocol("/abc") == "file" - assert FSSpecPersistence.get_protocol("file://abc") == "file" - assert FSSpecPersistence.get_protocol("gs://abc") == "gs" - assert FSSpecPersistence.get_protocol("sftp://abc") == "sftp" - assert FSSpecPersistence.get_protocol("abfs://abc") == "abfs" - - -def test_get_anonymous_filesystem(): - fp = FSSpecPersistence() - fs = fp.get_anonymous_filesystem("/abc") - assert fs is None - fs = fp.get_anonymous_filesystem("s3://abc") - assert fs is not None - assert fs.protocol == ["s3", "s3a"] - - -def test_get_filesystem(): - fp = FSSpecPersistence() - fs = fp.get_filesystem("/abc") - assert fs is not None - assert isinstance(fs, LocalFileSystem) - - -def test_recursive_paths(): - f, t = FSSpecPersistence.recursive_paths("/tmp", "/tmp") - assert (f, t) == ("/tmp/*", "/tmp/") - f, t = FSSpecPersistence.recursive_paths("/tmp/", "/tmp/") - assert (f, t) == ("/tmp/*", "/tmp/") - f, t = FSSpecPersistence.recursive_paths("/tmp/*", "/tmp") - assert (f, t) == ("/tmp/*", "/tmp/") - - -def test_exists(): - fs = FSSpecPersistence() - assert not fs.exists("/tmp/non-existent") - - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - assert fs.exists(f) - - -def test_get(): - fs = FSSpecPersistence() - 
with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - t = os.path.join(tdir, "t.txt") - - fs.get(f, t) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_get_recursive(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - p = pathlib.Path(tdir) - d = p.joinpath("d") - d.mkdir() - f = d.joinpath(d, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - o = p.joinpath("o") - - t = o.joinpath(o, "f.txt") - fs.get(str(d), str(o), recursive=True) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_put(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - f = os.path.join(tdir, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - t = os.path.join(tdir, "t.txt") - - fs.put(f, t) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_put_recursive(): - fs = FSSpecPersistence() - with tempfile.TemporaryDirectory() as tdir: - p = pathlib.Path(tdir) - d = p.joinpath("d") - d.mkdir() - f = d.joinpath(d, "f.txt") - with open(f, "w") as fp: - fp.write("hello") - - o = p.joinpath("o") - - t = o.joinpath(o, "f.txt") - fs.put(str(d), str(o), recursive=True) - with open(t, "r") as fp: - assert fp.read() == "hello" - - -def test_construct_path(): - fs = FSSpecPersistence() - assert fs.construct_path(True, False, "abc") == "file://abc" diff --git a/plugins/flytekit-data-fsspec/tests/test_placeholder.py b/plugins/flytekit-data-fsspec/tests/test_placeholder.py new file mode 100644 index 0000000000..eb6dc82a34 --- /dev/null +++ b/plugins/flytekit-data-fsspec/tests/test_placeholder.py @@ -0,0 +1,3 @@ +# This test is here to give pytest something to run, otherwise it returns a non-zero return code. 
+def test_dummy(): + assert 1 + 1 == 2 diff --git a/plugins/flytekit-spark/tests/test_pyspark_transformers.py b/plugins/flytekit-spark/tests/test_pyspark_transformers.py index cb527e16ef..212af454dd 100644 --- a/plugins/flytekit-spark/tests/test_pyspark_transformers.py +++ b/plugins/flytekit-spark/tests/test_pyspark_transformers.py @@ -6,13 +6,24 @@ import flytekit from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine +from flytekit.types.structured.structured_dataset import StructuredDatasetTransformerEngine def test_type_resolution(): assert type(TypeEngine.get_transformer(PipelineModel)) == PySparkPipelineModelTransformer +def test_basic_get(): + + ctx = FlyteContextManager.current_context() + e = StructuredDatasetTransformerEngine() + prot = e._protocol_from_type_or_prefix(ctx, pyspark.sql.DataFrame, uri="/tmp/blah") + en = e.get_encoder(pyspark.sql.DataFrame, prot, "") + assert en is not None + + def test_pipeline_model_compatibility(): @task(task_config=Spark()) def my_dataset() -> pyspark.sql.DataFrame: diff --git a/setup.py b/setup.py index 3e7b886e71..11a24ccbe4 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,10 @@ "grpcio>=1.50.0,<2.0", "grpcio-status>=1.50.0,<2.0", "importlib-metadata", + "fsspec>=2023.3.0", + "adlfs", + "s3fs", + "gcsfs", "pyopenssl", "joblib", "python-json-logger>=2.0.0", diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 45d50a2fc5..1a24cccb61 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -2,6 +2,7 @@ import typing from collections import OrderedDict +import fsspec import mock import pytest from flyteidl.core.errors_pb2 import ErrorDocument @@ -10,15 +11,12 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.base_task import IgnoreOutputs -from flytekit.core.data_persistence import DiskPersistence from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.promise import VoidPromise from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.scopes import system_entry_point -from flytekit.extras.persistence.gcs_gsutil import GCSPersistence -from flytekit.extras.persistence.s3_awscli import S3Persistence from flytekit.models import literals as _literal_models from flytekit.models.core import errors as error_models from flytekit.models.core import execution as execution_models @@ -311,7 +309,22 @@ def test_dispatch_execute_system_error(mock_write_to_file, mock_upload_dir, mock assert ed.error.origin == execution_models.ExecutionError.ErrorKind.SYSTEM -def test_persist_ss(): +def test_setup_disk_prefix(): + with setup_execution("qwerty") as ctx: + assert isinstance(ctx.file_access._default_remote, fsspec.AbstractFileSystem) + assert ctx.file_access._default_remote.protocol == "file" + + +def test_setup_cloud_prefix(): + with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: + assert ctx.file_access._default_remote.protocol[0] == "s3" + + with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: + assert "gs" in ctx.file_access._default_remote.protocol + + +@mock.patch("google.auth.compute_engine._metadata") # to prevent network calls +def test_persist_ss(mock_gcs): 
default_img = Image(name="default", fqn="test", tag="tag") ss = SerializationSettings( project="proj1", @@ -327,19 +340,6 @@ def test_persist_ss(): assert ctx.serialization_settings.domain == "dom" -def test_setup_disk_prefix(): - with setup_execution("qwerty") as ctx: - assert isinstance(ctx.file_access._default_remote, DiskPersistence) - - -def test_setup_cloud_prefix(): - with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, S3Persistence) - - with setup_execution("gs://", checkpoint_path=None, prev_checkpoint=None) as ctx: - assert isinstance(ctx.file_access._default_remote, GCSPersistence) - - def test_normalize_inputs(): assert normalize_inputs("{{.rawOutputDataPrefix}}", "{{.checkpointOutputPrefix}}", "{{.prevCheckpointPrefix}}") == ( None, diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index 2add1b9e7d..b5fa46fe54 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import pytest @@ -36,11 +37,14 @@ def test_sync_checkpoint_save_file(tmpdir): def test_sync_checkpoint_save_filepath(tmpdir): - td_path = Path(tmpdir) - cp = SyncCheckpoint(checkpoint_dest=tmpdir) - dst_path = td_path.joinpath("test") + src_path = Path(os.path.join(tmpdir, "src")) + src_path.mkdir(parents=True, exist_ok=True) + chkpnt_path = Path(os.path.join(tmpdir, "dest")) + chkpnt_path.mkdir() + cp = SyncCheckpoint(checkpoint_dest=str(chkpnt_path)) + dst_path = chkpnt_path.joinpath("test") assert not dst_path.exists() - inp = td_path.joinpath("test") + inp = src_path.joinpath("test") with inp.open("wb") as f: f.write(b"blah") cp.save(inp) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py new file mode 100644 index 0000000000..880036f636 --- /dev/null +++ b/tests/flytekit/unit/core/test_data.py @@ -0,0 +1,215 @@ +import os +import shutil +import tempfile + +import fsspec +import mock +import pytest + +from flytekit.configuration import Config, S3Config +from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args + +local = fsspec.filesystem("file") +root = os.path.abspath(os.sep) + + +@mock.patch("google.auth.compute_engine._metadata") # to prevent network calls +@mock.patch("flytekit.core.data_persistence.UUID") +def test_path_getting(mock_uuid_class, mock_gcs): + mock_uuid_class.return_value.hex = "abcdef123" + + # Testing with raw output prefix pointing to a local path + loc_sandbox = os.path.join(root, "tmp", "unittest") + loc_data = os.path.join(root, "tmp", "unittestdata") + local_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix=loc_data) + assert local_raw_fp.get_random_remote_path() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + assert local_raw_fp.get_random_remote_path("/fsa/blah.csv") == os.path.join( + root, "tmp", "unittestdata", "abcdef123", "blah.csv" + ) + assert local_raw_fp.get_random_remote_directory() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + + # Test local path and directory + assert local_raw_fp.get_random_local_path() == os.path.join(root, "tmp", "unittest", "local_flytekit", "abcdef123") + assert local_raw_fp.get_random_local_path("xjiosa/blah.txt") == os.path.join( + root, "tmp", "unittest", "local_flytekit", "abcdef123", "blah.txt" + ) + assert local_raw_fp.get_random_local_directory() == os.path.join( 
+ root, "tmp", "unittest", "local_flytekit", "abcdef123" + ) + + # Recursive paths + assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" + ) + assert "file:///abc/happy/", "s3://my-s3-bucket/bucket1/" == local_raw_fp.recursive_paths( + "file:///abc/happy", "s3://my-s3-bucket/bucket1" + ) + + # Test with remote pointed to s3. + s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + # trailing slash should make no difference + s3_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="s3://my-s3-bucket/") + assert s3_fa.get_random_remote_path() == "s3://my-s3-bucket/abcdef123" + assert s3_fa.get_random_remote_directory() == "s3://my-s3-bucket/abcdef123" + + # Testing with raw output prefix pointing to file:// + # Skip tests for windows + if os.name != "nt": + file_raw_fp = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="file:///tmp/unittestdata") + assert file_raw_fp.get_random_remote_path() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + assert file_raw_fp.get_random_remote_path("/fsa/blah.csv") == os.path.join( + root, "tmp", "unittestdata", "abcdef123", "blah.csv" + ) + assert file_raw_fp.get_random_remote_directory() == os.path.join(root, "tmp", "unittestdata", "abcdef123") + + g_fa = FileAccessProvider(local_sandbox_dir=loc_sandbox, raw_output_prefix="gs://my-s3-bucket/") + assert g_fa.get_random_remote_path() == "gs://my-s3-bucket/abcdef123" + + +@mock.patch("flytekit.core.data_persistence.UUID") +def test_default_file_access_instance(mock_uuid_class): + mock_uuid_class.return_value.hex = "abcdef123" + + assert default_local_file_access_provider.get_random_local_path().endswith( + os.path.join("sandbox", "local_flytekit", "abcdef123") + ) + assert default_local_file_access_provider.get_random_local_path("bob.txt").endswith( + os.path.join("abcdef123", "bob.txt") + ) + + assert default_local_file_access_provider.get_random_local_directory().endswith( + os.path.join("sandbox", "local_flytekit", "abcdef123") + ) + + x = default_local_file_access_provider.get_random_remote_path() + assert x.endswith(os.path.join("raw", "abcdef123")) + x = default_local_file_access_provider.get_random_remote_path("eve.txt") + assert x.endswith(os.path.join("raw", "abcdef123", "eve.txt")) + x = default_local_file_access_provider.get_random_remote_directory() + assert x.endswith(os.path.join("raw", "abcdef123")) + + +@pytest.fixture +def source_folder(): + # Set up source directory for testing + parent_temp = tempfile.mkdtemp() + src_dir = os.path.join(parent_temp, "source", "") + nested_dir = os.path.join(src_dir, "nested") + local.mkdir(nested_dir) + local.touch(os.path.join(src_dir, "original.txt")) + local.touch(os.path.join(nested_dir, "more.txt")) + yield src_dir + shutil.rmtree(parent_temp) + + +def test_local_raw_fsspec(source_folder): + # Test copying using raw fsspec local filesystem, should not create a nested folder + with tempfile.TemporaryDirectory() as dest_tmpdir: + local.put(source_folder, dest_tmpdir, recursive=True) + + new_temp_dir_2 = tempfile.mkdtemp() + new_temp_dir_2 = os.path.join(new_temp_dir_2, "doesnotexist") + local.put(source_folder, new_temp_dir_2, recursive=True) + files = local.find(new_temp_dir_2) + assert len(files) == 2 + + +def 
test_local_provider(source_folder): + # Test that behavior putting from a local dir to a local remote dir is the same whether or not the local + # dest folder exists. + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as dest_tmpdir: + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=dest_tmpdir, data_config=dc) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + files = provider._default_remote.find(doesnotexist) + assert len(files) == 2 + + exists = provider.get_random_remote_directory() + provider._default_remote.mkdir(exists) + provider.put_data(source_folder, exists, is_multipart=True) + files = provider._default_remote.find(exists) + assert len(files) == 2 + + +@pytest.mark.sandbox_test +def test_s3_provider(source_folder): + # Running mkdir on s3 filesystem doesn't do anything so leaving out for now + dc = Config.for_sandbox().data_config + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix="s3://my-s3-bucket/testdata/", data_config=dc + ) + doesnotexist = provider.get_random_remote_directory() + provider.put_data(source_folder, doesnotexist, is_multipart=True) + fs = provider.get_filesystem_for_path(doesnotexist) + files = fs.find(doesnotexist) + assert len(files) == 2 + + +def test_local_provider_get_empty(): + dc = Config.for_sandbox().data_config + with tempfile.TemporaryDirectory() as empty_source: + with tempfile.TemporaryDirectory() as dest_folder: + provider = FileAccessProvider( + local_sandbox_dir="/tmp/unittest", raw_output_prefix=empty_source, data_config=dc + ) + provider.get_data(empty_source, dest_folder, is_multipart=True) + loc = provider.get_filesystem_for_path(dest_folder) + src_files = loc.find(empty_source) + assert len(src_files) == 0 + dest_files = loc.find(dest_folder) + assert len(dest_files) == 0 + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_empty(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + mock_os.get.return_value = None + s3c = S3Config.auto() + kwargs = s3_setup_args(s3c) + assert kwargs == {} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_both(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "AWS_ACCESS_KEY_ID": "ignore-user", + "AWS_SECRET_ACCESS_KEY": "ignore-secret", + "FLYTE_AWS_ACCESS_KEY_ID": "flyte", + "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + assert kwargs == {"key": "flyte", "secret": "flyte-secret"} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_flyte(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_AWS_ACCESS_KEY_ID": "flyte", + "FLYTE_AWS_SECRET_ACCESS_KEY": "flyte-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + kwargs = s3_setup_args(S3Config.auto()) + assert kwargs == {"key": "flyte", "secret": "flyte-secret"} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "AWS_ACCESS_KEY_ID": "ignore-user", + "AWS_SECRET_ACCESS_KEY": "ignore-secret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + 
kwargs = s3_setup_args(S3Config.auto()) + # not explicitly in kwargs, since fsspec/boto3 will use these env vars by default + assert kwargs == {} diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index af39e9e852..27b407c1ce 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,11 +1,11 @@ -from flytekit.core.data_persistence import DataPersistencePlugins, FileAccessProvider +from flytekit.core.data_persistence import FileAccessProvider def test_get_random_remote_path(): fp = FileAccessProvider("/tmp", "s3://my-bucket") path = fp.get_random_remote_path() assert path.startswith("s3://my-bucket") - assert fp.raw_output_prefix == "s3://my-bucket" + assert fp.raw_output_prefix == "s3://my-bucket/" def test_is_remote(): @@ -14,10 +14,3 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True - - -def test_lister(): - x = DataPersistencePlugins.supported_protocols() - main_protocols = {"file", "/", "gs", "http", "https", "s3"} - all_protocols = set([y.replace("://", "") for y in x]) - assert main_protocols.issubset(all_protocols) diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index 0cb4f524f9..bd20c39c53 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -49,7 +49,6 @@ def test_engine(): def test_transformer_to_literal_local(): - random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) ctx = context_manager.FlyteContext.current_context() @@ -86,6 +85,15 @@ def test_transformer_to_literal_local(): with pytest.raises(TypeError, match="No automatic conversion from "): TypeEngine.to_literal(ctx, 3, FlyteDirectory, lt) + +def test_transformer_to_literal_localss(): + random_dir = context_manager.FlyteContext.current_context().file_access.get_random_local_directory() + fs = FileAccessProvider(local_sandbox_dir=random_dir, raw_output_prefix=os.path.join(random_dir, "raw")) + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context(ctx.with_file_access(fs)) as ctx: + + tf = FlyteDirToMultipartBlobTransformer() + lt = tf.get_literal_type(FlyteDirectory) # Can't use if it's not a directory with pytest.raises(FlyteAssertion): p = "/tmp/flyte/xyz" diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index bfb41d0fef..eaba8b6343 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -1,9 +1,11 @@ +import os import tempfile import typing import pandas as pd import pyarrow as pa import pytest +from fsspec.utils import get_protocol from typing_extensions import Annotated import flytekit.configuration @@ -25,7 +27,6 @@ StructuredDatasetTransformerEngine, convert_schema_type_to_structured_dataset_type, extract_cols_and_format, - protocol_prefix, ) my_cols = kwtypes(w=typing.Dict[str, typing.Dict[str, int]], x=typing.List[typing.List[int]], y=int, z=str) @@ -44,8 +45,8 @@ def test_protocol(): - assert protocol_prefix("s3://my-s3-bucket/file") == "s3" - assert protocol_prefix("/file") == "file" + assert 
get_protocol("s3://my-s3-bucket/file") == "s3" + assert get_protocol("/file") == "file" def generate_pandas() -> pd.DataFrame: @@ -74,7 +75,6 @@ def t1(a: pd.DataFrame) -> pd.DataFrame: def test_setting_of_unset_formats(): - custom = Annotated[StructuredDataset, "parquet"] example = custom(dataframe=df, uri="/path") # It's okay that the annotation is not used here yet. @@ -89,7 +89,9 @@ def t2(path: str) -> StructuredDataset: def wf(path: str) -> StructuredDataset: return t2(path=path) - res = wf(path="/tmp/somewhere") + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "somewhere") + res = wf(path=fname) # Now that it's passed through an encoder however, it should be set. assert res.file_format == "parquet" @@ -281,7 +283,10 @@ def encode( # Check that registering with a / triggers the file protocol instead. StructuredDatasetTransformerEngine.register(TempEncoder("/")) - assert StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("file") is not None + res = StructuredDatasetTransformerEngine.get_encoder(MyDF, "file", "/") + # Test that the one we got was registered under fsspec + assert res is StructuredDatasetTransformerEngine.ENCODERS[MyDF].get("fsspec")["/"] + assert res is not None def test_sd(): diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index c7aa5563f9..cef124ffd0 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -50,5 +50,5 @@ def test_arrow(): assert encoder.protocol is None assert decoder.protocol is None assert encoder.python_type is decoder.python_type - d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["s3"]["parquet"] + d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["fsspec"]["parquet"] assert d is not None diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 373a536769..1913deb6bf 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -490,11 +490,13 @@ def t1(path: str) -> DatasetStruct: def wf(path: str) -> DatasetStruct: return t1(path=path) - res = wf(path="/tmp/somewhere") - assert "parquet" == res.a.file_format - assert "parquet" == res.b.a.file_format - assert_frame_equal(df, res.a.open(pd.DataFrame).all()) - assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "df_file") + res = wf(path=fname) + assert "parquet" == res.a.file_format + assert "parquet" == res.b.a.file_format + assert_frame_equal(df, res.a.open(pd.DataFrame).all()) + assert_frame_equal(df, res.b.a.open(pd.DataFrame).all()) def test_wf1_with_map(): diff --git a/tests/flytekit/unit/core/tracker/test_arrow_data.py b/tests/flytekit/unit/core/tracker/test_arrow_data.py new file mode 100644 index 0000000000..747e7f1651 --- /dev/null +++ b/tests/flytekit/unit/core/tracker/test_arrow_data.py @@ -0,0 +1,29 @@ +import typing + +import pandas as pd +import pyarrow as pa +from typing_extensions import Annotated + +from flytekit import kwtypes, task + +cols = kwtypes(Name=str, Age=int) +subset_cols = kwtypes(Name=str) + + +@task +def t1( + df1: Annotated[pd.DataFrame, cols], df2: Annotated[pa.Table, cols] +) -> typing.Tuple[Annotated[pd.DataFrame, subset_cols], Annotated[pa.Table, subset_cols]]: + return df1, df2 + + +def test_structured_dataset_wf(): + pd_df = 
pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + pa_df = pa.Table.from_pandas(pd_df) + + subset_pd_df = pd.DataFrame({"Name": ["Tom", "Joseph"]}) + subset_pa_df = pa.Table.from_pandas(subset_pd_df) + + df1, df2 = t1(df1=pd_df, df2=pa_df) + assert df1.equals(subset_pd_df) + assert df2.equals(subset_pa_df) diff --git a/tests/flytekit/unit/extras/persistence/__init__.py b/tests/flytekit/unit/extras/persistence/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py b/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py deleted file mode 100644 index d2c50cc4a9..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_gcs_gsutil.py +++ /dev/null @@ -1,35 +0,0 @@ -import mock - -from flytekit import GCSPersistence - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1") - mock_exec.assert_called_with(["gsutil", "cp", "/test", "gs://my-bucket/k1"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_put_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.put("/test", "gs://my-bucket/k1", True) - mock_exec.assert_called_with(["gsutil", "cp", "-r", "/test/*", "gs://my-bucket/k1/"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test") - mock_exec.assert_called_with(["gsutil", "cp", "gs://my-bucket/k1", "/test"]) - - -@mock.patch("flytekit.extras.persistence.gcs_gsutil._update_cmd_config_and_execute") -@mock.patch("flytekit.extras.persistence.gcs_gsutil.GCSPersistence._check_binary") -def test_get_recursive(mock_check, mock_exec): - proxy = GCSPersistence() - proxy.get("gs://my-bucket/k1", "/test", True) - mock_exec.assert_called_with(["gsutil", "cp", "-r", "gs://my-bucket/k1/*", "/test"]) diff --git a/tests/flytekit/unit/extras/persistence/test_http.py b/tests/flytekit/unit/extras/persistence/test_http.py deleted file mode 100644 index 893b43f364..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_http.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from flytekit import HttpPersistence - - -def test_put(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.put("", "", recursive=True) - - -def test_construct_path(): - proxy = HttpPersistence() - with pytest.raises(AssertionError): - proxy.construct_path(True, False, "", "") - - -def test_exists(): - proxy = HttpPersistence() - assert proxy.exists("https://flyte.org") diff --git a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py b/tests/flytekit/unit/extras/persistence/test_s3_awscli.py deleted file mode 100644 index a6f29f36d6..0000000000 --- a/tests/flytekit/unit/extras/persistence/test_s3_awscli.py +++ /dev/null @@ -1,80 +0,0 @@ -from datetime import timedelta - -import mock - -from flytekit import S3Persistence -from flytekit.configuration import DataConfig, S3Config -from flytekit.extras.persistence import s3_awscli - - -def test_property(): - aws = S3Persistence("s3://raw-output") - assert aws.default_prefix == 
"s3://raw-output" - - -def test_construct_path(): - aws = S3Persistence() - p = aws.construct_path(True, False, "xyz") - assert p == "s3://xyz" - - -@mock.patch("flytekit.extras.persistence.s3_awscli.S3Persistence._check_binary") -@mock.patch("flytekit.extras.persistence.s3_awscli.subprocess") -def test_retries(mock_subprocess, mock_check): - mock_subprocess.check_call.side_effect = Exception("test exception (404)") - mock_check.return_value = True - - proxy = S3Persistence(data_config=DataConfig(s3=S3Config(backoff=timedelta(seconds=0)))) - assert proxy.exists("s3://test/fdsa/fdsa") is False - assert mock_subprocess.check_call.call_count == 8 - - -def test_extra_args(): - assert s3_awscli._extra_args({}) == [] - assert s3_awscli._extra_args({"ContentType": "ct"}) == ["--content-type", "ct"] - assert s3_awscli._extra_args({"ContentEncoding": "ec"}) == ["--content-encoding", "ec"] - assert s3_awscli._extra_args({"ACL": "acl"}) == ["--acl", "acl"] - assert s3_awscli._extra_args({"ContentType": "ct", "ContentEncoding": "ec", "ACL": "acl"}) == [ - "--content-type", - "ct", - "--content-encoding", - "ec", - "--acl", - "acl", - ] - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put(mock_exec): - proxy = S3Persistence() - proxy.put("/test", "s3://my-bucket/k1") - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_put_recursive(mock_exec): - proxy = S3Persistence() - proxy.put("/test", "s3://my-bucket/k1", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "--acl", "bucket-owner-full-control", "/test", "s3://my-bucket/k1"], - s3_cfg=S3Config.auto(), - ) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test") - mock_exec.assert_called_with(cmd=["aws", "s3", "cp", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto()) - - -@mock.patch("flytekit.extras.persistence.s3_awscli._update_cmd_config_and_execute") -def test_get_recursive(mock_exec): - proxy = S3Persistence() - proxy.get("s3://my-bucket/k1", "/test", True) - mock_exec.assert_called_with( - cmd=["aws", "s3", "cp", "--recursive", "s3://my-bucket/k1", "/test"], s3_cfg=S3Config.auto() - ) diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 4b8f82fb7e..5e20eaeee3 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -177,14 +177,6 @@ def test_more_stuff(mock_client): with tempfile.TemporaryDirectory() as tmp_dir: r._upload_file(pathlib.Path(tmp_dir)) - # Test that this copies the file. - with tempfile.TemporaryDirectory() as tmp_dir: - mm = MagicMock() - mm.signed_url = os.path.join(tmp_dir, "tmp_file") - mock_client.return_value.get_upload_signed_url.return_value = mm - - r._upload_file(pathlib.Path(__file__)) - serialization_settings = flytekit.configuration.SerializationSettings( project="project", domain="domain",