From b6f7799d6c213aec8d5268b279558b919edc5835 Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Fri, 6 Jan 2023 12:10:57 -0800
Subject: [PATCH 1/7] Read polars dataframe in a folder

Signed-off-by: Kevin Su
---
 .../flytekitplugins/polars/sd_transformers.py      |  8 +++++---
 .../flytekit-polars/tests/test_polars_plugin_sd.py | 13 +++++++++++++
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
index 0dfd0c6516..d6e208096e 100644
--- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
+++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
@@ -64,11 +64,13 @@ def decode(
     ) -> pl.DataFrame:
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
-        path = f"{local_dir}/00000"
+        # Polars doesn't know the path is directory without "/*"
+        # https://github.com/pola-rs/polars/blob/5f3e332fb2a653064f083b02949c527e0ec0afda/py-polars/polars/internals/dataframe/frame.py#L649
+        local_dir = f"{local_dir}/*"
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
-            return pl.read_parquet(path, columns=columns)
-        return pl.read_parquet(path)
+            return pl.read_parquet(local_dir, columns=columns)
+        return pl.read_parquet(local_dir)
 
 
 StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
index b991cd5d13..f4308bf229 100644
--- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
+++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
@@ -66,3 +66,16 @@ def test_polars_renderer():
     assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame(
         df.describe().transpose(), columns=df.describe().columns
     ).to_html(index=False)
+
+
+def test_parquet_to_polars():
+    data = {"name": ["Alice"], "age": [5]}
+
+    @task
+    def create_sd() -> StructuredDataset:
+        df = pd.DataFrame(data=data)
+        return StructuredDataset(dataframe=df)
+
+    sd = create_sd()
+    polars_df = sd.open(pl.DataFrame).all()
+    assert polars_df == pl.DataFrame(data)

From 6c3fa3c039fb793115a70170a446df6a60a939b2 Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Fri, 6 Jan 2023 13:21:39 -0800
Subject: [PATCH 2/7] Read polars dataframe in a folder

Signed-off-by: Kevin Su
---
 .../flytekitplugins/huggingface/sd_transformers.py          | 5 ++---
 .../flytekitplugins/polars/sd_transformers.py               | 6 ++----
 .../flytekit-spark/flytekitplugins/spark/sd_transformers.py | 3 ++-
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
index 0690179bb1..919119dd33 100644
--- a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
+++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
@@ -59,12 +59,11 @@ def decode(
     ) -> datasets.Dataset:
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
-        path = f"{local_dir}/00000"
 
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
-            return datasets.Dataset.from_parquet(path, columns=columns)
-        return datasets.Dataset.from_parquet(path)
+            return datasets.Dataset.from_parquet(local_dir, columns=columns)
+        return datasets.Dataset.from_parquet(local_dir)
 
 
 StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler())
diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
index d6e208096e..7b71f627f9 100644
--- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
+++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
@@ -64,13 +64,11 @@ def decode(
     ) -> pl.DataFrame:
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
-        # Polars doesn't know the path is directory without "/*"
-        # https://github.com/pola-rs/polars/blob/5f3e332fb2a653064f083b02949c527e0ec0afda/py-polars/polars/internals/dataframe/frame.py#L649
-        local_dir = f"{local_dir}/*"
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
             return pl.read_parquet(local_dir, columns=columns)
-        return pl.read_parquet(local_dir)
+        print("local_dir local_dir", local_dir)
+        return pl.read_parquet(local_dir, use_pyarrow=True)
 
 
 StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 46079f40dd..76e370025a 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -38,7 +38,8 @@ def encode(
     ) -> literals.StructuredDataset:
         path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
         df = typing.cast(DataFrame, structured_dataset.dataframe)
-        df.write.mode("overwrite").parquet(path)
+        sc = ctx.user_space_params.spark_session.sparkContext
+        df.write.mode("overwrite").parquet(path=path)
         return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))

From 76552d1c2de6a4e8c18d2d1f0746ab13278056ad Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Fri, 6 Jan 2023 14:08:05 -0800
Subject: [PATCH 3/7] Load huggingface and spark plugin implicitly

Signed-off-by: Kevin Su
---
 .../flytekitplugins/huggingface/sd_transformers.py |  7 ++++---
 plugins/flytekit-huggingface/setup.py              |  1 +
 .../tests/test_huggingface_plugin_sd.py            | 12 ++++++++++++
 .../flytekitplugins/polars/sd_transformers.py      |  1 -
 .../flytekitplugins/spark/sd_transformers.py       |  2 ++
 plugins/flytekit-spark/setup.py                    |  1 +
 6 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
index 919119dd33..579efd366c 100644
--- a/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
+++ b/plugins/flytekit-huggingface/flytekitplugins/huggingface/sd_transformers.py
@@ -1,3 +1,4 @@
+import os
 import typing
 
 import datasets
@@ -59,11 +60,11 @@ def decode(
     ) -> datasets.Dataset:
         local_dir = ctx.file_access.get_random_local_directory()
         ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
-
+        files = [item.path for item in os.scandir(local_dir)]
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
-            return datasets.Dataset.from_parquet(local_dir, columns=columns)
-        return datasets.Dataset.from_parquet(local_dir)
+            return datasets.Dataset.from_parquet(files, columns=columns)
+        return datasets.Dataset.from_parquet(files)
 
 
 StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler())
diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py
index 7ce3ac2c1a..22cb096ba8 100644
--- a/plugins/flytekit-huggingface/setup.py
+++ b/plugins/flytekit-huggingface/setup.py
@@ -38,4 +38,5 @@
         "Topic :: Software Development :: Libraries",
         "Topic :: Software Development :: Libraries :: Python Modules",
     ],
+    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
 )
diff --git a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
index 170fdc3789..5b65b2511c 100644
--- a/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
+++ b/plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
@@ -68,3 +68,15 @@ def test_datasets_renderer():
     df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
     dataset = datasets.Dataset.from_pandas(df)
     assert HuggingFaceDatasetRenderer().to_html(dataset) == str(dataset).replace("\n", "<br>")
+
+
+def test_parquet_to_datasets():
+    df = pd.DataFrame({"name": ["Alice"], "age": [10]})
+
+    @task
+    def create_sd() -> StructuredDataset:
+        return StructuredDataset(dataframe=df)
+
+    sd = create_sd()
+    dataset = sd.open(datasets.Dataset).all()
+    assert dataset.data == datasets.Dataset.from_pandas(df).data
diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
index 7b71f627f9..d20f33a1ab 100644
--- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
+++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
@@ -67,7 +67,6 @@ def decode(
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
             return pl.read_parquet(local_dir, columns=columns)
-        print("local_dir local_dir", local_dir)
         return pl.read_parquet(local_dir, use_pyarrow=True)
 
 
 StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 76e370025a..8cd914c05f 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -39,6 +39,8 @@ def encode(
         path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
         df = typing.cast(DataFrame, structured_dataset.dataframe)
         sc = ctx.user_space_params.spark_session.sparkContext
+        # Avoid generating SUCCESS files
+        sc._jsc.hadoopConfiguration().set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
         df.write.mode("overwrite").parquet(path=path)
         return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py
index d344eaa2ba..67d47cf6b1 100644
--- a/plugins/flytekit-spark/setup.py
+++ b/plugins/flytekit-spark/setup.py
@@ -34,4 +34,5 @@
         "Topic :: Software Development :: Libraries :: Python Modules",
     ],
     scripts=["scripts/flytekit_install_spark3.sh"],
+    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
 )

From 5c52efe1f851ff5a5cb8f3007096ff47419c315c Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Fri, 6 Jan 2023 15:09:30 -0800
Subject: [PATCH 4/7] nit

Signed-off-by: Kevin Su
---
 .../flytekit-spark/flytekitplugins/spark/sd_transformers.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 8cd914c05f..24f1bebf96 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -3,6 +3,7 @@
 import pandas as pd
 from pyspark.sql.dataframe import DataFrame
 
+import flytekit
 from flytekit import FlyteContext
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
@@ -38,7 +39,7 @@ def encode(
     ) -> literals.StructuredDataset:
         path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
         df = typing.cast(DataFrame, structured_dataset.dataframe)
-        sc = ctx.user_space_params.spark_session.sparkContext
+        sc = flytekit.current_context().spark_session
         # Avoid generating SUCCESS files
         sc._jsc.hadoopConfiguration().set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
         df.write.mode("overwrite").parquet(path=path)

From 8cc88703f47ff049232f11190066a9fb7506264d Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Fri, 6 Jan 2023 16:53:39 -0800
Subject: [PATCH 5/7] nit

Signed-off-by: Kevin Su
---
 .../flytekit-spark/flytekitplugins/spark/sd_transformers.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 24f1bebf96..42819b414c 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -1,9 +1,9 @@
 import typing
 
 import pandas as pd
+import pyspark as _pyspark
 from pyspark.sql.dataframe import DataFrame
 
-import flytekit
 from flytekit import FlyteContext
 from flytekit.models import literals
 from flytekit.models.literals import StructuredDatasetMetadata
@@ -39,9 +39,9 @@ def encode(
     ) -> literals.StructuredDataset:
         path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
         df = typing.cast(DataFrame, structured_dataset.dataframe)
-        sc = flytekit.current_context().spark_session
+        ss = _pyspark.sql.SparkSession.builder.getOrCreate()
         # Avoid generating SUCCESS files
-        sc._jsc.hadoopConfiguration().set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+        ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
         df.write.mode("overwrite").parquet(path=path)

From c429084f6d83252ceb8e01053a438cbd3eb84b54 Mon Sep 17 00:00:00 2001
From: Kevin Su
Date: Fri, 6 Jan 2023 22:52:30 -0800
Subject: [PATCH 6/7] Fix tests

Signed-off-by: Kevin Su
---
 .../flytekit-polars/flytekitplugins/polars/sd_transformers.py | 2 +-
 plugins/flytekit-polars/tests/test_polars_plugin_sd.py        | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
index d20f33a1ab..0b5bf8e577 100644
--- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
+++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
@@ -66,7 +66,7 @@ def decode(
         ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
         if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
             columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
-            return pl.read_parquet(local_dir, columns=columns)
+            return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True)
         return pl.read_parquet(local_dir, use_pyarrow=True)
diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
index f4308bf229..15a195e5d5 100644
--- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
+++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py
@@ -78,4 +78,4 @@ def create_sd() -> StructuredDataset:
 
     sd = create_sd()
     polars_df = sd.open(pl.DataFrame).all()
-    assert polars_df == pl.DataFrame(data)
+    assert pl.DataFrame(data).frame_equal(polars_df)

From 897115259332e817e585fccf1620572ebe8119bf Mon Sep 17 00:00:00 2001
From: Yee Hing Tong
Date: Mon, 9 Jan 2023 15:50:53 -0800
Subject: [PATCH 7/7] remove _pyspark alias

Signed-off-by: Yee Hing Tong
---
 .../flytekit-spark/flytekitplugins/spark/sd_transformers.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 42819b414c..386570be5c 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -1,7 +1,7 @@
 import typing
 
 import pandas as pd
-import pyspark as _pyspark
+import pyspark
 from pyspark.sql.dataframe import DataFrame
 
 from flytekit import FlyteContext
@@ -39,7 +39,7 @@ def encode(
     ) -> literals.StructuredDataset:
         path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
         df = typing.cast(DataFrame, structured_dataset.dataframe)
-        ss = _pyspark.sql.SparkSession.builder.getOrCreate()
+        ss = pyspark.sql.SparkSession.builder.getOrCreate()
         # Avoid generating SUCCESS files
         ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
         df.write.mode("overwrite").parquet(path=path)
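
Note (not part of the patches above): a minimal standalone sketch of the
directory-read behavior the polars patches converge on (patches 1, 2, and 6),
for anyone who wants to try it outside Flyte. It assumes polars, pandas, and
pyarrow are installed; the directory and part-file names below are made up
for illustration.

import os
import tempfile

import pandas as pd
import polars as pl

local_dir = tempfile.mkdtemp()
# A multipart StructuredDataset lands locally as a folder of parquet parts.
pd.DataFrame({"name": ["Alice"], "age": [5]}).to_parquet(os.path.join(local_dir, "00000"))
pd.DataFrame({"name": ["Bob"], "age": [7]}).to_parquet(os.path.join(local_dir, "00001"))

# With use_pyarrow=True the path is handled as a pyarrow dataset, so the
# whole folder is read without the "/*" glob that patch 1 appended.
df = pl.read_parquet(local_dir, use_pyarrow=True)
assert df.shape == (2, 2)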
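
Likewise for the huggingface decoder (patch 3): datasets.Dataset.from_parquet
accepts one or more file paths rather than a bare directory, hence the
os.scandir listing. A sketch under the same caveats (`datasets` and pandas
installed; layout illustrative):

import os
import tempfile

import datasets
import pandas as pd

local_dir = tempfile.mkdtemp()
pd.DataFrame({"name": ["Alice"], "age": [10]}).to_parquet(os.path.join(local_dir, "00000"))

# List every part file in the folder, then hand the list to from_parquet.
files = [item.path for item in os.scandir(local_dir)]
ds = datasets.Dataset.from_parquet(files)
assert ds.num_rows == 1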
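
And for the Spark encoder (patches 3 through 7), a sketch of the _SUCCESS
suppression, assuming a local pyspark installation and an illustrative output
path. Spark copies runtime SQL conf entries into the Hadoop configuration it
uses for SQL writes, which is why setting the committer flag on the session
should suffice here:

import tempfile

import pyspark

ss = pyspark.sql.SparkSession.builder.master("local[1]").getOrCreate()
# Stop the file output committer from emitting _SUCCESS marker files, so the
# directory-reading decoders above see only parquet parts in the folder.
ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")

df = ss.createDataFrame([("Alice", 5)], ["name", "age"])
df.write.mode("overwrite").parquet(tempfile.mkdtemp())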