Skip to content

Commit

Permalink
Read structured dataset from a folder (#1406)
Browse files Browse the repository at this point in the history
* Read polars dataframe in a folder

Signed-off-by: Kevin Su <[email protected]>

* Read polars dataframe in a folder

Signed-off-by: Kevin Su <[email protected]>

* Load huggingface and spark plugins implicitly

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* Fix tests

Signed-off-by: Kevin Su <[email protected]>

* remove _pyspark alias

Signed-off-by: Yee Hing Tong <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
  • Loading branch information
pingsutw and wild-endeavor authored Jan 10, 2023
1 parent fcf6dce commit 1311ea4
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import typing

import datasets
Expand Down Expand Up @@ -59,12 +60,11 @@ def decode(
) -> datasets.Dataset:
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
path = f"{local_dir}/00000"

files = [item.path for item in os.scandir(local_dir)]
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return datasets.Dataset.from_parquet(path, columns=columns)
return datasets.Dataset.from_parquet(path)
return datasets.Dataset.from_parquet(files, columns=columns)
return datasets.Dataset.from_parquet(files)


StructuredDatasetTransformerEngine.register(HuggingFaceDatasetToParquetEncodingHandler())
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-huggingface/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
12 changes: 12 additions & 0 deletions plugins/flytekit-huggingface/tests/test_huggingface_plugin_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,15 @@ def test_datasets_renderer():
df = pd.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
dataset = datasets.Dataset.from_pandas(df)
assert HuggingFaceDatasetRenderer().to_html(dataset) == str(dataset).replace("\n", "<br>")


def test_parquet_to_datasets():
    """Round-trip a pandas DataFrame through a StructuredDataset task and read it back as a HuggingFace Dataset."""
    source_df = pd.DataFrame({"name": ["Alice"], "age": [10]})

    @task
    def create_sd() -> StructuredDataset:
        # Wrap the pandas frame; the plugin encodes it to parquet under the hood.
        return StructuredDataset(dataframe=source_df)

    result = create_sd()
    # Decode the multipart parquet directory back into a datasets.Dataset.
    loaded = result.open(datasets.Dataset).all()
    assert loaded.data == datasets.Dataset.from_pandas(source_df).data
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,10 @@ def decode(
) -> pl.DataFrame:
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
path = f"{local_dir}/00000"
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pl.read_parquet(path, columns=columns)
return pl.read_parquet(path)
return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True)
return pl.read_parquet(local_dir, use_pyarrow=True)


StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
Expand Down
13 changes: 13 additions & 0 deletions plugins/flytekit-polars/tests/test_polars_plugin_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ def test_polars_renderer():
assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame(
df.describe().transpose(), columns=df.describe().columns
).to_html(index=False)


def test_parquet_to_polars():
    """Round-trip a pandas DataFrame through a StructuredDataset task and read it back as a polars DataFrame."""
    data = {"name": ["Alice"], "age": [5]}

    @task
    def create_sd() -> StructuredDataset:
        # Build the pandas frame inside the task and hand it to the transformer engine.
        return StructuredDataset(dataframe=pd.DataFrame(data=data))

    result = create_sd()
    # Decode the stored parquet back into a polars DataFrame.
    loaded = result.open(pl.DataFrame).all()
    assert pl.DataFrame(data).frame_equal(loaded)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing

import pandas as pd
import pyspark
from pyspark.sql.dataframe import DataFrame

from flytekit import FlyteContext
Expand Down Expand Up @@ -38,7 +39,10 @@ def encode(
) -> literals.StructuredDataset:
path = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
df = typing.cast(DataFrame, structured_dataset.dataframe)
df.write.mode("overwrite").parquet(path)
ss = pyspark.sql.SparkSession.builder.getOrCreate()
# Avoid generating SUCCESS files
ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
df.write.mode("overwrite").parquet(path=path)
return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))


Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-spark/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@
"Topic :: Software Development :: Libraries :: Python Modules",
],
scripts=["scripts/flytekit_install_spark3.sh"],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)

0 comments on commit 1311ea4

Please sign in to comment.