Skip to content

Commit

Permalink
update according to yees to_html refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
esadler-hbo committed Aug 31, 2022
1 parent d5e231b commit e242ff7
Showing 1 changed file with 17 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,27 @@
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
GCS,
LOCAL,
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetTransformerEngine,
)


class HuggingFaceDatasetRenderer:
"""
The datasets Dataset printable representation is saved to HTML.
"""

def to_html(self, df: datasets.Dataset) -> str:
assert isinstance(df, datasets.Dataset)
return str(df).replace("\n", "<br>")


class HuggingFaceDatasetToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self, protocol: str):
super().__init__(datasets.Dataset, protocol, PARQUET)
def __init__(self):
super().__init__(datasets.Dataset, None, PARQUET)

def encode(
self,
Expand All @@ -41,8 +48,8 @@ def encode(


class ParquetToHuggingFaceDatasetDecodingHandler(StructuredDatasetDecoder):
def __init__(self, protocol: str):
super().__init__(datasets.Dataset, protocol, PARQUET)
def __init__(self):
super().__init__(datasets.Dataset, None, PARQUET)

def decode(
self,
Expand All @@ -60,12 +67,6 @@ def decode(
return datasets.Dataset.from_parquet(path)


for protocol in [LOCAL, S3]:
StructuredDatasetTransformerEngine.register(
HuggingFaceDatasetToParquetEncodingHandler(protocol), default_for_type=True
)
StructuredDatasetTransformerEngine.register(
ParquetToHuggingFaceDatasetDecodingHandler(protocol), default_for_type=True
)
StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler(GCS), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler(GCS), default_for_type=False)
StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(datasets.Dataset, HuggingFaceDatasetRenderer())

0 comments on commit e242ff7

Please sign in to comment.