Skip to content

Commit

Permalink
Read polar dataframe without copying to local (#1618)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored and eapolinario committed May 16, 2023
1 parent 33a2f32 commit b2fb4bf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.basic_dfs import get_storage_options
from flytekit.types.structured.structured_dataset import (
PARQUET,
StructuredDataset,
Expand Down Expand Up @@ -62,12 +63,12 @@ def decode(
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pl.DataFrame:
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
uri = flyte_value.uri
kwargs = get_storage_options(ctx.file_access.data_config, uri)
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pl.read_parquet(local_dir, columns=columns, use_pyarrow=True)
return pl.read_parquet(local_dir, use_pyarrow=True)
return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs)
return pl.read_parquet(uri, use_pyarrow=True, storage_options=kwargs)


StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
Expand Down
12 changes: 12 additions & 0 deletions plugins/flytekit-polars/tests/test_polars_plugin_sd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import tempfile

import pandas as pd
import polars as pl
from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer
Expand Down Expand Up @@ -79,3 +81,13 @@ def create_sd() -> StructuredDataset:
sd = create_sd()
polars_df = sd.open(pl.DataFrame).all()
assert pl.DataFrame(data).frame_equal(polars_df)

tmp = tempfile.mktemp()
pl.DataFrame(data).write_parquet(tmp)

@task
def t1(sd: StructuredDataset) -> pl.DataFrame:
return sd.open(pd.DataFrame).all()

sd = StructuredDataset(uri=tmp)
t1(sd=sd).frame_equal(polars_df)

0 comments on commit b2fb4bf

Please sign in to comment.