Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow None protocol to mean all data persistence supported storage options in Structured Dataset #1134

Merged
merged 22 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ def is_supported_protocol(cls, protocol: str) -> bool:
"""
return protocol in cls._PLUGINS

@classmethod
def supported_protocols(cls) -> typing.List[str]:
return [k for k in cls._PLUGINS.keys()]


class DiskPersistence(DataPersistence):
"""
Expand Down
33 changes: 12 additions & 21 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@
import pyarrow.parquet as pq

from flytekit import FlyteContext
from flytekit.core.data_persistence import DataPersistencePlugins
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
ABFS,
GCS,
LOCAL,
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
Expand All @@ -27,10 +22,8 @@


class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self, protocol: str):
super().__init__(pd.DataFrame, protocol, PARQUET)
# todo: Use this somehow instead of relaying ont he ctx file_access
self._persistence = DataPersistencePlugins.find_plugin(protocol)()
def __init__(self):
super().__init__(pd.DataFrame, None, PARQUET)

def encode(
self,
Expand All @@ -50,8 +43,8 @@ def encode(


class ParquetToPandasDecodingHandler(StructuredDatasetDecoder):
def __init__(self, protocol: str):
super().__init__(pd.DataFrame, protocol, PARQUET)
def __init__(self):
super().__init__(pd.DataFrame, None, PARQUET)

def decode(
self,
Expand All @@ -69,8 +62,8 @@ def decode(


class ArrowToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self, protocol: str):
super().__init__(pa.Table, protocol, PARQUET)
def __init__(self):
super().__init__(pa.Table, None, PARQUET)

def encode(
self,
Expand All @@ -88,8 +81,8 @@ def encode(


class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
def __init__(self, protocol: str):
super().__init__(pa.Table, protocol, PARQUET)
def __init__(self):
super().__init__(pa.Table, None, PARQUET)

def decode(
self,
Expand All @@ -106,9 +99,7 @@ def decode(
return pq.read_table(local_dir)


# Don't override default protocol
for protocol in [LOCAL, S3, GCS, ABFS]:
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())
8 changes: 4 additions & 4 deletions flytekit/types/structured/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def decode(
return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata))


StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler(), default_for_type=False)
StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler(), default_for_type=False)
StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers())
StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler())
StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers())
StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler())
109 changes: 74 additions & 35 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import collections
import importlib
import os
import re
import types
import typing
from abc import ABC, abstractmethod
Expand All @@ -12,15 +11,16 @@

import _datetime
import numpy as _np
import pandas
import pandas as pd
import pyarrow
import pyarrow as pa

from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence

if importlib.util.find_spec("pyspark") is not None:
import pyspark
if importlib.util.find_spec("polars") is not None:
import polars as pl

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, TypeAlias, get_args, get_origin
Expand All @@ -36,13 +36,6 @@
T = typing.TypeVar("T") # StructuredDataset type or a dataframe type
DF = typing.TypeVar("DF") # Dataframe type

# Protocols
BIGQUERY = "bq"
S3 = "s3"
ABFS = "abfs"
GCS = "gs"
LOCAL = "/"

# For specifying the storage formats of StructuredDatasets. It's just a string, nothing fancy.
StructuredDatasetFormat: TypeAlias = str

Expand Down Expand Up @@ -156,7 +149,7 @@ def extract_cols_and_format(
if ordered_dict_cols is not None:
raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}")
ordered_dict_cols = aa
elif isinstance(aa, pyarrow.Schema):
elif isinstance(aa, pa.Schema):
if pa_schema is not None:
raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}")
pa_schema = aa
Expand All @@ -168,7 +161,7 @@ def extract_cols_and_format(


class StructuredDatasetEncoder(ABC):
def __init__(self, python_type: Type[T], protocol: str, supported_format: Optional[str] = None):
def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None):
"""
Extend this abstract class, implement the encode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand All @@ -184,15 +177,15 @@ def __init__(self, python_type: Type[T], protocol: str, supported_format: Option
for does not exist, the transformer enginer will look for the "" endcoder instead and write a warning.
"""
self._python_type = python_type
self._protocol = protocol.replace("://", "")
self._protocol = protocol.replace("://", "") if protocol else None
self._supported_format = supported_format or ""

@property
def python_type(self) -> Type[T]:
return self._python_type

@property
def protocol(self) -> str:
def protocol(self) -> Optional[str]:
return self._protocol

@property
Expand Down Expand Up @@ -228,7 +221,7 @@ def encode(


class StructuredDatasetDecoder(ABC):
def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optional[str] = None):
def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None):
"""
Extend this abstract class, implement the decode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand All @@ -243,15 +236,15 @@ def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optio
for does not exist, the transformer enginer will look for the "" decoder instead and write a warning.
"""
self._python_type = python_type
self._protocol = protocol.replace("://", "")
self._protocol = protocol.replace("://", "") if protocol else None
self._supported_format = supported_format or ""

@property
def python_type(self) -> Type[DF]:
return self._python_type

@property
def protocol(self) -> str:
def protocol(self) -> Optional[str]:
return self._protocol

@property
Expand Down Expand Up @@ -281,10 +274,8 @@ def decode(


def protocol_prefix(uri: str) -> str:
g = re.search(r"([\w]+)://.*", uri)
if g and g.groups():
return g.groups()[0]
return LOCAL
p = DataPersistencePlugins.get_protocol(uri)
return p


def convert_schema_type_to_structured_dataset_type(
Expand All @@ -306,6 +297,10 @@ def convert_schema_type_to_structured_dataset_type(
raise AssertionError(f"Unrecognized SchemaColumnType: {column_type}")


class DuplicateHandlerError(ValueError):
...


class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
"""
Think of this transformer as a higher-level meta transformer that is used for all the dataframe types.
Expand Down Expand Up @@ -366,8 +361,7 @@ def get_decoder(cls, df_type: Type, protocol: str, format: str):
return cls._finder(StructuredDatasetTransformerEngine.DECODERS, df_type, protocol, format)

@classmethod
def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]:
# Maybe think about default dict in the future, but is typing as nice?
def _handler_finder(cls, h: Handlers, protocol: str) -> Dict[str, Handlers]:
if isinstance(h, StructuredDatasetEncoder):
top_level = cls.ENCODERS
elif isinstance(h, StructuredDatasetDecoder):
Expand All @@ -376,9 +370,9 @@ def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]:
raise TypeError(f"We don't support this type of handler {h}")
if h.python_type not in top_level:
top_level[h.python_type] = {}
if h.protocol not in top_level[h.python_type]:
top_level[h.python_type][h.protocol] = {}
return top_level[h.python_type][h.protocol]
if protocol not in top_level[h.python_type]:
top_level[h.python_type][protocol] = {}
return top_level[h.python_type][protocol]

def __init__(self):
super().__init__("StructuredDataset Transformer", StructuredDataset)
Expand All @@ -388,22 +382,67 @@ def __init__(self):
self._hash_overridable = True

@classmethod
def register(cls, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False):
def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
"""
Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
specify a protocol (e.g. s3, gs, etc.) field, then

:param h: The StructuredDatasetEncoder or StructuredDatasetDecoder you wish to register with this transformer.
:param default_for_type: If set, when a user returns from a task an instance of the dataframe the handler
handles, e.g. ``return pd.DataFrame(...)``, not wrapped around the ``StructuredDataset`` object, we will
use this handler's protocol and format as the default, effectively saying that this handler will be called.
Note that this shouldn't be set if your handler's protocol is None, because that implies that your handler
is capable of handling all the different storage protocols that flytekit's data persistence layer is aware of.
In these cases, the protocol is determined by the raw output data prefix set in the active context.
:param override: Override any previous registrations. If default_for_type is also set, this will also override
the default.
"""
Call this with any handler to register it with this dataframe meta-transformer
if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)):
raise TypeError(f"We don't support this type of handler {h}")

The string "://" should not be present in any handler's protocol so we don't check for it.
if h.protocol is None:
if default_for_type:
raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.")
for persistence_protocol in DataPersistencePlugins.supported_protocols():
# TODO: Clean this up when we get to replacing the persistence layer.
# The behavior of the protocols given in the supported_protocols and is_supported_protocol
# is not actually the same as the one returned in get_protocol.
stripped = DataPersistencePlugins.get_protocol(persistence_protocol)
logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}")
try:
cls.register_for_protocol(h, stripped, False, override)
except DuplicateHandlerError:
...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. nit: should we add debug message here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah. i will. will need another +1 later, will get from @eapolinario . need to add more tests.

# Add this?
# cls.register_for_protocol(h, "/", False, override)

elif h.protocol == "":
raise ValueError(f"Use None instead of empty string for registering handler {h}")
else:
cls.register_for_protocol(h, h.protocol, default_for_type, override)

@classmethod
def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool):
"""
See the main register function instead.
"""
lowest_level = cls._handler_finder(h)
if protocol == "/":
# TODO: Special fix again, because get_protocol returns file, instead of file://
protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL)
lowest_level = cls._handler_finder(h, protocol)
if h.supported_format in lowest_level and override is False:
raise ValueError(f"Already registered a handler for {(h.python_type, h.protocol, h.supported_format)}")
raise DuplicateHandlerError(
f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}"
)
lowest_level[h.supported_format] = h
logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {h.protocol}, fmt {h.supported_format}")
logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")

if default_for_type:
# TODO: Add logging, think about better ux, maybe default False and warn if doesn't exist.
logger.debug(
f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
)
cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
cls.DEFAULT_PROTOCOLS[h.python_type] = h.protocol
cls.DEFAULT_PROTOCOLS[h.python_type] = protocol

# Register with the type engine as well
# The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
Expand Down Expand Up @@ -657,7 +696,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ
else:
df = python_val

if isinstance(df, pandas.DataFrame):
if isinstance(df, pd.DataFrame):
return df.describe().to_html()
elif isinstance(df, pa.Table):
return df.to_string()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
import importlib

from flytekit import StructuredDatasetTransformerEngine, logger
from flytekit.configuration import internal
from flytekit.types.structured.structured_dataset import ABFS, GCS, S3

from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler
from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler
from .persist import FSSpecPersistence

S3 = "s3"
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
ABFS = "abfs"
GCS = "gs"


def _register(protocol: str):
logger.info(f"Registering fsspec {protocol} implementations and overriding default structured encoder/decoder.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
Expand All @@ -22,7 +21,7 @@

def get_storage_options(cfg: DataConfig, uri: str) -> typing.Optional[typing.Dict]:
protocol = FSSpecPersistence.get_protocol(uri)
if protocol == S3:
if protocol == "s3":
kwargs = s3_setup_args(cfg.s3)
if kwargs:
return kwargs
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-papermill/dev-requirements.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
flyteidl>=1.0.0
git+https://github.com/flyteorg/flytekit@master#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark
git+https://github.com/flyteorg/flytekit@sd-data-persistence#egg=flytekitplugins-spark&subdirectory=plugins/flytekit-spark
# vcs+protocol://repo_url/#egg=pkg&subdirectory=flyte
2 changes: 1 addition & 1 deletion plugins/flytekit-papermill/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ flyteidl==1.0.0.post1
# flytekit
flytekit==1.1.0b0
# via flytekitplugins-spark
flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@master#subdirectory=plugins/flytekit-spark
flytekitplugins-spark @ git+https://github.com/flyteorg/flytekit@sd-data-persistence#subdirectory=plugins/flytekit-spark
# via -r dev-requirements.in
googleapis-common-protos==1.55.0
# via
Expand Down
Loading