Set default format of structured dataset to empty (#1159)
* Set default format of structured dataset to empty

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* last error (#1364)

Signed-off-by: Yee Hing Tong <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
pingsutw and wild-endeavor authored Dec 9, 2022
1 parent 56014d7 commit 6d78c56
Showing 6 changed files with 172 additions and 67 deletions.
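Net effect of the change: a structured dataset's literal type now defaults to the empty GENERIC_FORMAT ("") instead of "parquet"; a concrete format is carried only when a handler registered one as the type's default or when the user pins one via Annotated. A minimal sketch of the user-visible behavior (task names are illustrative, not from this commit):

import pandas as pd
from typing_extensions import Annotated

from flytekit import task


@task
def make_df() -> pd.DataFrame:
    # The interface now advertises format "" (GENERIC_FORMAT); the value itself
    # is still written as parquet by the default pandas handler registered below.
    return pd.DataFrame({"a": [1, 2]})


@task
def make_pinned_df() -> Annotated[pd.DataFrame, "parquet"]:
    # An explicit Annotated format still pins the literal type to "parquet".
    return pd.DataFrame({"a": [1, 2]})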
7 changes: 5 additions & 2 deletions flytekit/core/type_engine.py
@@ -388,7 +388,9 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A
             if issubclass(python_type, FlyteFile) or issubclass(python_type, FlyteDirectory):
                 return python_type(path=lv.scalar.blob.uri)
             elif issubclass(python_type, StructuredDataset):
-                return python_type(uri=lv.scalar.structured_dataset.uri)
+                sd = python_type(uri=lv.scalar.structured_dataset.uri)
+                sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+                return sd
             else:
                 return python_val
         else:
@@ -534,7 +536,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
                 f"serialized correctly"
             )
 
-        dc = cast(DataClassJsonMixin, expected_python_type).from_json(_json_format.MessageToJson(lv.scalar.generic))
+        json_str = _json_format.MessageToJson(lv.scalar.generic)
+        dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str)
         return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type))
 
     # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run``
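A hedged illustration of why _deserialize_flyte_type now copies the format back onto the instance (the dataclass and task below are hypothetical): without this, a StructuredDataset nested inside a dataclass would always come back with an empty file_format, even when the literal's metadata recorded one.

from dataclasses import dataclass

from dataclasses_json import dataclass_json

from flytekit import task
from flytekit.types.structured.structured_dataset import StructuredDataset


@dataclass_json
@dataclass
class TrainingData:
    examples: StructuredDataset  # round-trips through a Flyte literal


@task
def consume(data: TrainingData) -> str:
    # After this commit, the format stored in the literal's
    # StructuredDatasetType metadata is restored here.
    return data.examples.file_format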
8 changes: 4 additions & 4 deletions flytekit/types/structured/basic_dfs.py
@@ -101,10 +101,10 @@ def decode(
         return pq.read_table(local_dir)
 
 
-StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
-StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())
+StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(), default_format_for_type=True)
+StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(), default_format_for_type=True)
 
 StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
 StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer())
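Third-party handlers can opt into the new flags the same way; a minimal sketch (MyDataFrame, MyEncoder, and the "csv" format are hypothetical, not part of this commit):

from flytekit import FlyteContext
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
    StructuredDataset,
    StructuredDatasetEncoder,
    StructuredDatasetTransformerEngine,
)


class MyDataFrame:
    """Hypothetical dataframe type."""


class MyEncoder(StructuredDatasetEncoder):
    def __init__(self):
        super().__init__(python_type=MyDataFrame, protocol="s3", supported_format="csv")

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        ...  # write the dataframe out and return a literals.StructuredDataset


# Makes "csv" the default format for MyDataFrame without also claiming the
# default storage protocol (that would be default_storage_for_type=True).
StructuredDatasetTransformerEngine.register(MyEncoder(), default_format_for_type=True)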
136 changes: 91 additions & 45 deletions flytekit/types/structured/structured_dataset.py
@@ -34,6 +34,7 @@
 
 # Storage formats
 PARQUET: StructuredDatasetFormat = "parquet"
+GENERIC_FORMAT: StructuredDatasetFormat = ""
 
 
 @dataclass_json
@@ -45,9 +46,7 @@ class (that is just a model, a Python class representation of the protobuf).
     """
 
     uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
-    file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String()))
-
-    DEFAULT_FILE_FORMAT = PARQUET
+    file_format: typing.Optional[str] = field(default=GENERIC_FORMAT, metadata=config(mm_field=fields.String()))
 
     @classmethod
     def columns(cls) -> typing.Dict[str, typing.Type]:
@@ -68,6 +67,8 @@ def __init__(
         # Make these fields public, so that the dataclass transformer can set a value for it
         # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
         self.uri = uri
+        # When dataclass_json runs from_json, we need to set it here, otherwise the format will be empty string
+        self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT
         # This is a special attribute that indicates if the data was either downloaded or uploaded
         self._metadata = metadata
         # This is not for users to set, the transformer will set this.
@@ -128,14 +129,14 @@ def extract_cols_and_format(
         optional str for the format,
         optional pyarrow Schema
     """
-    fmt = None
+    fmt = ""
    ordered_dict_cols = None
    pa_schema = None
    if get_origin(t) is Annotated:
        base_type, *annotate_args = get_args(t)
        for aa in annotate_args:
            if isinstance(aa, StructuredDatasetFormat):
-                if fmt is not None:
+                if fmt != "":
                    raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
                fmt = aa
            elif isinstance(aa, collections.OrderedDict):
@@ -334,21 +335,44 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
     Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
     Renderers: Dict[Type, Renderable] = {}
 
-    @staticmethod
-    def _finder(handler_map, df_type: Type, protocol: str, format: str):
-        try:
-            return handler_map[df_type][protocol][format]
-        except KeyError:
-            try:
-                hh = handler_map[df_type][protocol][""]
-                logger.info(
-                    f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
-                    f" format {format}, using default instead."
-                )
-                return hh
-            except KeyError:
-                ...
-        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}")
+    @classmethod
+    def _finder(cls, handler_map, df_type: Type, protocol: str, format: str):
+        # If the incoming format requested is a specific format (e.g. "avro"), then look for that specific handler
+        #   if missing, see if there's a generic format handler. Error if missing.
+        # If the incoming format requested is the generic format (""), then see if it's present,
+        #   if not, look to see if there is a default format for the df_type and a handler for that format.
+        #   if still missing, look to see if there's only _one_ handler for that type, if so then use that.
+        if format != GENERIC_FORMAT:
+            try:
+                return handler_map[df_type][protocol][format]
+            except KeyError:
+                try:
+                    return handler_map[df_type][protocol][GENERIC_FORMAT]
+                except KeyError:
+                    ...
+        else:
+            try:
+                return handler_map[df_type][protocol][GENERIC_FORMAT]
+            except KeyError:
+                ...
+            if df_type in cls.DEFAULT_FORMATS and cls.DEFAULT_FORMATS[df_type] in handler_map[df_type][protocol]:
+                hh = handler_map[df_type][protocol][cls.DEFAULT_FORMATS[df_type]]
+                logger.debug(
+                    f"Didn't find format specific handler {type(handler_map)} for protocol {protocol}"
+                    f" using the generic handler {hh} instead."
+                )
+                return hh
+            if len(handler_map[df_type][protocol]) == 1:
+                hh = list(handler_map[df_type][protocol].values())[0]
+                logger.debug(
+                    f"Using {hh} with format {hh.supported_format} as it's the only one available for {df_type}"
+                )
+                return hh
+            else:
+                logger.warning(
+                    f"Did not automatically pick a handler for {df_type},"
+                    f" more than one detected {handler_map[df_type][protocol].keys()}"
+                )
+        raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt |{format}|")
 
     @classmethod
     def get_encoder(cls, df_type: Type, protocol: str, format: str):
@@ -381,7 +405,14 @@ def register_renderer(cls, python_type: Type, renderer: Renderable):
         cls.Renderers[python_type] = renderer
 
     @classmethod
-    def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
+    def register(
+        cls,
+        h: Handlers,
+        default_for_type: bool = False,
+        override: bool = False,
+        default_format_for_type: bool = False,
+        default_storage_for_type: bool = False,
+    ):
         """
         Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
         specify a protocol (e.g. s3, gs, etc.) field, then
@@ -395,6 +426,10 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
         In these cases, the protocol is determined by the raw output data prefix set in the active context.
         :param override: Override any previous registrations. If default_for_type is also set, this will also override
             the default.
+        :param default_format_for_type: Unlike the default_for_type arg that will set this handler's format and storage
+            as the default, this will only set the format. Error if already set, unless override is specified.
+        :param default_storage_for_type: Same as above but only for the storage format. Error if already set,
+            unless override is specified.
         """
         if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)):
             raise TypeError(f"We don't support this type of handler {h}")
@@ -409,17 +444,29 @@ def register(cls, h: Handlers, default_for_type: Optional[bool] = False, overrid
                 stripped = DataPersistencePlugins.get_protocol(persistence_protocol)
                 logger.debug(f"Automatically registering {persistence_protocol} as {stripped} with {h}")
                 try:
-                    cls.register_for_protocol(h, stripped, False, override)
+                    cls.register_for_protocol(
+                        h, stripped, False, override, default_format_for_type, default_storage_for_type
+                    )
                 except DuplicateHandlerError:
                     logger.debug(f"Skipping {persistence_protocol}/{stripped} for {h} because duplicate")
 
         elif h.protocol == "":
             raise ValueError(f"Use None instead of empty string for registering handler {h}")
         else:
-            cls.register_for_protocol(h, h.protocol, default_for_type, override)
+            cls.register_for_protocol(
+                h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type
+            )
 
     @classmethod
-    def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: bool, override: bool):
+    def register_for_protocol(
+        cls,
+        h: Handlers,
+        protocol: str,
+        default_for_type: bool,
+        override: bool,
+        default_format_for_type: bool,
+        default_storage_for_type: bool,
+    ):
         """
         See the main register function instead.
         """
@@ -434,12 +481,24 @@ def register_for_protocol(cls, h: Handlers, protocol: str, default_for_type: boo
         lowest_level[h.supported_format] = h
         logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")
 
-        if default_for_type:
-            logger.debug(
-                f"Using storage {protocol} and format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
-            )
-            cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
-            cls.DEFAULT_PROTOCOLS[h.python_type] = protocol
+        if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
+            if h.python_type in cls.DEFAULT_FORMATS and not override:
+                logger.warning(
+                    f"Not using handler {h} with format {h.supported_format} as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
+                )
+            else:
+                logger.debug(
+                    f"Setting format {h.supported_format} for dataframes of type {h.python_type} from handler {h}"
+                )
+                cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
+        if default_storage_for_type or default_for_type:
+            if h.protocol in cls.DEFAULT_PROTOCOLS and not override:
+                logger.warning(
+                    f"Not using handler {h} with storage protocol {h.protocol} as default for {h.python_type}, {cls.DEFAULT_PROTOCOLS[h.python_type]} already specified."
+                )
+            else:
+                logger.debug(f"Using storage {protocol} for dataframes of type {h.python_type} from handler {h}")
+                cls.DEFAULT_PROTOCOLS[h.python_type] = protocol
 
         # Register with the type engine as well
         # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
@@ -461,7 +520,7 @@ def to_literal(
         # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema
         python_type, *attrs = extract_cols_and_format(python_type)
         # In case it's a FlyteSchema
-        sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))
+        sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT))
 
         if expected and expected.structured_dataset_type:
             sdt = StructuredDatasetType(
@@ -514,16 +573,12 @@ def to_literal(
                     python_val,
                     df_type,
                     protocol,
-                    sdt.format or typing.cast(StructuredDataset, python_val).DEFAULT_FILE_FORMAT,
+                    sdt.format,
                     sdt,
                 )
 
         # Otherwise assume it's a dataframe instance. Wrap it with some defaults
-        if python_type in self.DEFAULT_FORMATS:
-            fmt = self.DEFAULT_FORMATS[python_type]
-        else:
-            logger.debug(f"No default format for type {python_type}, using system default.")
-            fmt = StructuredDataset.DEFAULT_FILE_FORMAT
+        fmt = self.DEFAULT_FORMATS.get(python_type, "")
         protocol = self._protocol_from_type_or_prefix(ctx, python_type)
         meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None)
 
@@ -760,18 +815,9 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]
 
         # Get the column information
         converted_cols = self._convert_ordered_dict_of_columns_to_list(column_map)
-
-        # Get the format
-        default_format = (
-            original_python_type.DEFAULT_FILE_FORMAT
-            if issubclass(original_python_type, StructuredDataset)
-            else self.DEFAULT_FORMATS.get(original_python_type, PARQUET)
-        )
-        fmt = storage_format or default_format
-
         return StructuredDatasetType(
             columns=converted_cols,
-            format=fmt,
+            format=storage_format,
             external_schema_type="arrow" if pa_schema else None,
             external_schema_bytes=typing.cast(pa.lib.Schema, pa_schema).to_string().encode() if pa_schema else None,
         )
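A short walk-through of the new _finder resolution order (handler names assumed, mirroring the pandas registration above):

# Assumed state after the basic_dfs.py registrations:
#   handler_map[pd.DataFrame]["s3"] == {"parquet": PandasToParquetEncodingHandler}
#   DEFAULT_FORMATS[pd.DataFrame] == "parquet"   (via default_format_for_type=True)
#
# _finder(handler_map, pd.DataFrame, "s3", "parquet")
#   -> exact format match: PandasToParquetEncodingHandler
# _finder(handler_map, pd.DataFrame, "s3", "avro")
#   -> no "avro" handler; falls back to a generic ("") handler, erroring if absent
# _finder(handler_map, pd.DataFrame, "s3", GENERIC_FORMAT)
#   -> no "" handler registered, but DEFAULT_FORMATS names "parquet"
#   -> PandasToParquetEncodingHandler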
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_flyte_pickle.py
@@ -95,5 +95,5 @@ def t1(data: Annotated[Union[np.ndarray, pd.DataFrame, Sequence], "some annotati
     task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
     variants = task_spec.template.interface.inputs["data"].type.union_type.variants
     assert variants[0].blob.format == "NumpyArray"
-    assert variants[1].structured_dataset_type.format == "parquet"
+    assert variants[1].structured_dataset_type.format == ""
     assert variants[2].blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
5 changes: 1 addition & 4 deletions tests/flytekit/unit/core/test_imperative.py
@@ -16,7 +16,6 @@
 from flytekit.tools.translator import get_serializable
 from flytekit.types.file import FlyteFile
 from flytekit.types.schema import FlyteSchema
-from flytekit.types.structured.structured_dataset import StructuredDatasetType
 
 default_img = Image(name="default", fqn="test", tag="tag")
 serialization_settings = flytekit.configuration.SerializationSettings(
@@ -373,6 +372,4 @@ def ref_t2(
 
     assert len(wf_spec.template.interface.outputs) == 1
     assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type is not None
-    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type == StructuredDatasetType(
-        format="parquet"
-    )
+    assert wf_spec.template.interface.outputs["output_from_t3"].type.structured_dataset_type.format == ""