Skip to content

Commit

Permalink
Allow None protocol to mean all data persistence supported storage op…
Browse files Browse the repository at this point in the history
…tions in Structured Dataset (#1134)

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored and eapolinario committed Sep 15, 2022
1 parent f51b15d commit ad5b19b
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 118 deletions.
4 changes: 4 additions & 0 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@ def is_supported_protocol(cls, protocol: str) -> bool:
"""
return protocol in cls._PLUGINS

@classmethod
def supported_protocols(cls) -> typing.List[str]:
return [k for k in cls._PLUGINS.keys()]


class DiskPersistence(DataPersistence):
"""
Expand Down
33 changes: 12 additions & 21 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,11 @@
import pyarrow.parquet as pq

from flytekit import FlyteContext
from flytekit.core.data_persistence import DataPersistencePlugins
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
ABFS,
GCS,
LOCAL,
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
Expand All @@ -27,10 +22,8 @@


class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self, protocol: str):
super().__init__(pd.DataFrame, protocol, PARQUET)
# todo: Use this somehow instead of relaying ont he ctx file_access
self._persistence = DataPersistencePlugins.find_plugin(protocol)()
def __init__(self):
super().__init__(pd.DataFrame, None, PARQUET)

def encode(
self,
Expand All @@ -50,8 +43,8 @@ def encode(


class ParquetToPandasDecodingHandler(StructuredDatasetDecoder):
def __init__(self, protocol: str):
super().__init__(pd.DataFrame, protocol, PARQUET)
def __init__(self):
super().__init__(pd.DataFrame, None, PARQUET)

def decode(
self,
Expand All @@ -69,8 +62,8 @@ def decode(


class ArrowToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self, protocol: str):
super().__init__(pa.Table, protocol, PARQUET)
def __init__(self):
super().__init__(pa.Table, None, PARQUET)

def encode(
self,
Expand All @@ -88,8 +81,8 @@ def encode(


class ParquetToArrowDecodingHandler(StructuredDatasetDecoder):
def __init__(self, protocol: str):
super().__init__(pa.Table, protocol, PARQUET)
def __init__(self):
super().__init__(pa.Table, None, PARQUET)

def decode(
self,
Expand All @@ -106,9 +99,7 @@ def decode(
return pq.read_table(local_dir)


# Don't override default protocol
for protocol in [LOCAL, S3, GCS, ABFS]:
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())
11 changes: 6 additions & 5 deletions flytekit/types/structured/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
BIGQUERY,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetMetadata,
StructuredDatasetTransformerEngine,
)

BIGQUERY = "bq"


def _write_to_bq(structured_dataset: StructuredDataset):
table_id = typing.cast(str, structured_dataset.uri).split("://", 1)[1].replace(":", ".")
Expand Down Expand Up @@ -111,7 +112,7 @@ def decode(
return pa.Table.from_pandas(_read_from_bq(flyte_value, current_task_metadata))


StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler(), default_for_type=False)
StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers(), default_for_type=False)
StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler(), default_for_type=False)
StructuredDatasetTransformerEngine.register(PandasToBQEncodingHandlers())
StructuredDatasetTransformerEngine.register(BQToPandasDecodingHandler())
StructuredDatasetTransformerEngine.register(ArrowToBQEncodingHandlers())
StructuredDatasetTransformerEngine.register(BQToArrowDecodingHandler())
111 changes: 76 additions & 35 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import collections
import importlib
import os
import re
import types
import typing
from abc import ABC, abstractmethod
Expand All @@ -12,15 +11,16 @@

import _datetime
import numpy as _np
import pandas
import pandas as pd
import pyarrow
import pyarrow as pa

from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence

if importlib.util.find_spec("pyspark") is not None:
import pyspark
if importlib.util.find_spec("polars") is not None:
import polars as pl

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, TypeAlias, get_args, get_origin
Expand All @@ -36,13 +36,6 @@
T = typing.TypeVar("T") # StructuredDataset type or a dataframe type
DF = typing.TypeVar("DF") # Dataframe type

# Protocols
BIGQUERY = "bq"
S3 = "s3"
ABFS = "abfs"
GCS = "gs"
LOCAL = "/"

# For specifying the storage formats of StructuredDatasets. It's just a string, nothing fancy.
StructuredDatasetFormat: TypeAlias = str

Expand Down Expand Up @@ -156,7 +149,7 @@ def extract_cols_and_format(
if ordered_dict_cols is not None:
raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}")
ordered_dict_cols = aa
elif isinstance(aa, pyarrow.Schema):
elif isinstance(aa, pa.Schema):
if pa_schema is not None:
raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}")
pa_schema = aa
Expand All @@ -168,7 +161,7 @@ def extract_cols_and_format(


class StructuredDatasetEncoder(ABC):
def __init__(self, python_type: Type[T], protocol: str, supported_format: Optional[str] = None):
def __init__(self, python_type: Type[T], protocol: Optional[str] = None, supported_format: Optional[str] = None):
"""
Extend this abstract class, implement the encode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand All @@ -179,20 +172,22 @@ def __init__(self, python_type: Type[T], protocol: str, supported_format: Option
:param python_type: The dataframe class in question that you want to register this encoder with
:param protocol: A prefix representing the storage driver (e.g. 's3, 'gs', 'bq', etc.). You can use either
"s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
If None, this encoder will be registered with all protocols that flytekit's data persistence layer
is capable of handling.
:param supported_format: Arbitrary string representing the format. If not supplied then an empty string
will be used. An empty string implies that the encoder works with any format. If the format being asked
for does not exist, the transformer enginer will look for the "" endcoder instead and write a warning.
"""
self._python_type = python_type
self._protocol = protocol.replace("://", "")
self._protocol = protocol.replace("://", "") if protocol else None
self._supported_format = supported_format or ""

@property
def python_type(self) -> Type[T]:
return self._python_type

@property
def protocol(self) -> str:
def protocol(self) -> Optional[str]:
return self._protocol

@property
Expand Down Expand Up @@ -228,7 +223,7 @@ def encode(


class StructuredDatasetDecoder(ABC):
def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optional[str] = None):
def __init__(self, python_type: Type[DF], protocol: Optional[str] = None, supported_format: Optional[str] = None):
"""
Extend this abstract class, implement the decode function, and register your concrete class with the
StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
Expand All @@ -238,20 +233,22 @@ def __init__(self, python_type: Type[DF], protocol: str, supported_format: Optio
:param python_type: The dataframe class in question that you want to register this decoder with
:param protocol: A prefix representing the storage driver (e.g. 's3, 'gs', 'bq', etc.). You can use either
"s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
If None, this decoder will be registered with all protocols that flytekit's data persistence layer
is capable of handling.
:param supported_format: Arbitrary string representing the format. If not supplied then an empty string
will be used. An empty string implies that the decoder works with any format. If the format being asked
for does not exist, the transformer enginer will look for the "" decoder instead and write a warning.
"""
self._python_type = python_type
self._protocol = protocol.replace("://", "")
self._protocol = protocol.replace("://", "") if protocol else None
self._supported_format = supported_format or ""

@property
def python_type(self) -> Type[DF]:
return self._python_type

@property
def protocol(self) -> str:
def protocol(self) -> Optional[str]:
return self._protocol

@property
Expand Down Expand Up @@ -281,10 +278,8 @@ def decode(


def protocol_prefix(uri: str) -> str:
g = re.search(r"([\w]+)://.*", uri)
if g and g.groups():
return g.groups()[0]
return LOCAL
p = DataPersistencePlugins.get_protocol(uri)
return p


def convert_schema_type_to_structured_dataset_type(
Expand All @@ -306,6 +301,10 @@ def convert_schema_type_to_structured_dataset_type(
raise AssertionError(f"Unrecognized SchemaColumnType: {column_type}")


class DuplicateHandlerError(ValueError):
...


class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
"""
Think of this transformer as a higher-level meta transformer that is used for all the dataframe types.
Expand Down Expand Up @@ -366,8 +365,7 @@ def get_decoder(cls, df_type: Type, protocol: str, format: str):
return cls._finder(StructuredDatasetTransformerEngine.DECODERS, df_type, protocol, format)

@classmethod
def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]:
# Maybe think about default dict in the future, but is typing as nice?
def _handler_finder(cls, h: Handlers, protocol: str) -> Dict[str, Handlers]:
if isinstance(h, StructuredDatasetEncoder):
top_level = cls.ENCODERS
elif isinstance(h, StructuredDatasetDecoder):
Expand All @@ -376,9 +374,9 @@ def _handler_finder(cls, h: Handlers) -> Dict[str, Handlers]:
raise TypeError(f"We don't support this type of handler {h}")
if h.python_type not in top_level:
top_level[h.python_type] = {}
if h.protocol not in top_level[h.python_type]:
top_level[h.python_type][h.protocol] = {}
return top_level[h.python_type][h.protocol]
if protocol not in top_level[h.python_type]:
top_level[h.python_type][protocol] = {}
return top_level[h.python_type][protocol]

def __init__(self):
super().__init__("StructuredDataset Transformer", StructuredDataset)
Expand All @@ -388,22 +386,65 @@ def __init__(self):
self._hash_overridable = True

@classmethod
def register(cls, h: Handlers, default_for_type: Optional[bool] = True, override: Optional[bool] = False):
def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
"""
Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
specify a protocol (e.g. s3, gs, etc.) field, then
:param h: The StructuredDatasetEncoder or StructuredDatasetDecoder you wish to register with this transformer.
:param default_for_type: If set, when a user returns from a task an instance of the dataframe the handler
handles, e.g. ``return pd.DataFrame(...)``, not wrapped around the ``StructuredDataset`` object, we will
use this handler's protocol and format as the default, effectively saying that this handler will be called.
Note that this shouldn't be set if your handler's protocol is None, because that implies that your handler
is capable of handling all the different storage protocols that flytekit's data persistence layer is aware of.
In these cases, the protocol is determined by the raw output data prefix set in the active context.
:param override: Override any previous registrations. If default_for_type is also set, this will also override
the default.
"""
Call this with any handler to register it with this dataframe meta-transformer
if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)):
raise TypeError(f"We don't support this type of handler {h}")

The string "://" should not be present in any handler's protocol so we don't check for it.
if h.protocol is None:
if default_for_type:
raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.")
for persistence_protocol in DataPersistencePlugins.supported_protocols():
# TODO: Clean this up when we get to replacing the persistence layer.
# The behavior of the protocols given in the supported_protocols and is_supported_protocol
# is not actually the same as the one returned in get_protocol.
stripped = DataPersistencePlugins.get_protocol(persistence_protocol)
logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}")
try:
cls.register_for_protocol(h, stripped, False, override)
except DuplicateHandlerError:
logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate")

elif h.protocol == "":
raise ValueError(f"Use None instead of empty string for registering handler {h}")
else:
cls.register_for_protocol(h, h.protocol, default_for_type, override)

@classmethod
def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool):
"""
See the main register function instead.
"""
lowest_level = cls._handler_finder(h)
if protocol == "/":
# TODO: Special fix again, because get_protocol returns file, instead of file://
protocol = DataPersistencePlugins.get_protocol(DiskPersistence.PROTOCOL)
lowest_level = cls._handler_finder(h, protocol)
if h.supported_format in lowest_level and override is False:
raise ValueError(f"Already registered a handler for {(h.python_type, h.protocol, h.supported_format)}")
raise DuplicateHandlerError(
f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}"
)
lowest_level[h.supported_format] = h
logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {h.protocol}, fmt {h.supported_format}")
logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")

if default_for_type:
# TODO: Add logging, think about better ux, maybe default False and warn if doesn't exist.
logger.debug(
f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
)
cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
cls.DEFAULT_PROTOCOLS[h.python_type] = h.protocol
cls.DEFAULT_PROTOCOLS[h.python_type] = protocol

# Register with the type engine as well
# The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
Expand Down Expand Up @@ -657,7 +698,7 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ
else:
df = python_val

if isinstance(df, pandas.DataFrame):
if isinstance(df, pd.DataFrame):
return df.describe().to_html()
elif isinstance(df, pa.Table):
return df.to_string()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
import importlib

from flytekit import StructuredDatasetTransformerEngine, logger
from flytekit.configuration import internal
from flytekit.types.structured.structured_dataset import ABFS, GCS, S3

from .arrow import ArrowToParquetEncodingHandler, ParquetToArrowDecodingHandler
from .pandas import PandasToParquetEncodingHandler, ParquetToPandasDecodingHandler
from .persist import FSSpecPersistence

S3 = "s3"
ABFS = "abfs"
GCS = "gs"


def _register(protocol: str):
logger.info(f"Registering fsspec {protocol} implementations and overriding default structured encoder/decoder.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
Expand All @@ -22,7 +21,7 @@

def get_storage_options(cfg: DataConfig, uri: str) -> typing.Optional[typing.Dict]:
protocol = FSSpecPersistence.get_protocol(uri)
if protocol == S3:
if protocol == "s3":
kwargs = s3_setup_args(cfg.s3)
if kwargs:
return kwargs
Expand Down
Loading

0 comments on commit ad5b19b

Please sign in to comment.