From e242ff7bb11e755e37e7952da660be03670b34f0 Mon Sep 17 00:00:00 2001 From: Evan Date: Wed, 31 Aug 2022 16:41:33 -0400 Subject: [PATCH] update according to yees to_html refactor --- .../huggingface/sd_transformers.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py index 8871910a57..4a083d9aac 100644 --- a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py +++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py @@ -7,10 +7,7 @@ from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( - GCS, - LOCAL, PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -18,9 +15,19 @@ ) +class HuggingFaceDatasetRenderer: + """ + The datasets Dataset printable representation is saved to HTML. + """ + + def to_html(self, df: datasets.Dataset) -> str: + assert isinstance(df, datasets.Dataset) + return str(df).replace("\n", "<br>") + + class HuggingFaceDatasetToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self, protocol: str): - super().__init__(datasets.Dataset, protocol, PARQUET) + def __init__(self): + super().__init__(datasets.Dataset, None, PARQUET) def encode( self, @@ -41,8 +48,8 @@ def encode( class ParquetToHuggingFaceDatasetDecodingHandler(StructuredDatasetDecoder): - def __init__(self, protocol: str): - super().__init__(datasets.Dataset, protocol, PARQUET) + def __init__(self): + super().__init__(datasets.Dataset, None, PARQUET) def decode( self, @@ -60,12 +67,6 @@ def decode( return datasets.Dataset.from_parquet(path) -for protocol in [LOCAL, S3]: - StructuredDatasetTransformerEngine.register( - HuggingFaceDatasetToParquetEncodingHandler(protocol), default_for_type=True - ) - StructuredDatasetTransformerEngine.register( - ParquetToHuggingFaceDatasetDecodingHandler(protocol), default_for_type=True - ) -StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler(GCS), default_for_type=False) -StructuredDatasetTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler(GCS), default_for_type=False) +StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler()) +StructuredDatasetTransformerEngine.register(ParquetToHuggingFaceDatasetDecodingHandler()) +StructuredDatasetTransformerEngine.register_renderer(datasets.Dataset, HuggingFaceDatasetRenderer())