diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
index 0690179bb1..579efd366c 100644
--- a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
+++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
@@ -1,3 +1,4 @@
+import os
 import typing
 
 import datasets
@@ -59,12 +60,11 @@ def decode(
     ) -> datasets.Dataset:
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
-        path = f"{local_dir}/00000"
-
+        files = [item.path for item in os.scandir(local_dir)]
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
-            return datasets.Dataset.from_parquet(path, columns=columns)
-        return datasets.Dataset.from_parquet(path)
+            return datasets.Dataset.from_parquet(files, columns=columns)
+        return datasets.Dataset.from_parquet(files)
 
 
 StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler())
diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py
index 477c7a1a7c..acdbc20810 100644
--- a/plugins/flytekit-huggingface/setup.py
+++ b/plugins/flytekit-huggingface/setup.py
@@ -38,4 +38,5 @@
         "Topic :: Software Development :: Libraries",
         "Topic :: Software Development :: Libraries :: Python Modules",
     ],
+    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
 )
diff --git a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
index 170fdc3789..5b65b2511c 100644
--- a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
+++ b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
@@ -68,3 +68,15 @@ def test_datasets_renderer():
     df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
     dataset = datasets.Dataset.from_pandas(df)
     assert HuggingFaceDatasetRenderer().to_html(dataset) == str(dataset).replace("\n", "<br>")
+
+
+def test_parquet_to_datasets():
+    df = pd.DataFrame({"name": ["Alice"], "age": [10]})
+
+    @task
+    def create_sd() -> StructuredDataset:
+        return StructuredDataset(dataframe=df)
+
+    sd = create_sd()
+    dataset = sd.open(datasets.Dataset).all()
+    assert dataset.data == datasets.Dataset.from_pandas(df).data
diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
index 0dfd0c6516..0b5bf8e577 100644
--- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
+++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
@@ -64,11 +64,10 @@ def decode(
     ) -> pl.DataFrame:
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
-        path = f"{local_dir}/00000"
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
-            return pl.read_parquet(path, columns=columns)
-        return pl.read_parquet(path)
+            return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True)
+        return pl.read_parquet(local_dir, use_pyarrow=True)
 
 
 StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
index b991cd5d13..15a195e5d5 100644
--- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
+++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
@@ -66,3 +66,16 @@ def test_polars_renderer():
     assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame(
         df.describe().transpose(), columns=df.describe().columns
     ).to_html(index=False)
+
+
+def test_parquet_to_polars():
+    data = {"name": ["Alice"], "age": [5]}
+
+    @task
+    def create_sd() -> StructuredDataset:
+        df = pd.DataFrame(data=data)
+        return StructuredDataset(dataframe=df)
+
+    sd = create_sd()
+    polars_df = sd.open(pl.DataFrame).all()
+    assert pl.DataFrame(data).frame_equal(polars_df)
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 46079f40dd..386570be5c 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -1,6 +1,7 @@
 import typing
 
 import pandas as pd
+import pyspark
 from pyspark.sql.dataframe import DataFrame
 
 from flytekit import FlyteContext
@@ -38,7 +39,10 @@ def encode(
     ) -> literals.StructuredDataset:
         path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
         df = typing.cast(DataFrame, structured_dataset.dataframe)
-        df.write.mode("overwrite").parquet(path)
+        ss = pyspark.sql.SparkSession.builder.getOrCreate()
+        # Avoid generating SUCCESS files
+        ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+        df.write.mode("overwrite").parquet(path=path)
         return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
 
 
diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py
index 96081bd789..11935a30af 100644
--- a/plugins/flytekit-spark/setup.py
+++ b/plugins/flytekit-spark/setup.py
@@ -34,4 +34,5 @@
         "Topic :: Software Development :: Libraries :: Python Modules",
     ],
     scripts=["scripts/flytekit_install_spark3.sh"],
+    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
 )