Skip to content

Commit

Permalink
Add polars plugin (#1061)
Browse files Browse the repository at this point in the history
* add polars plugin

Signed-off-by: Robin Kahlow <[email protected]>

* support for older polars versions, add info about what polars is

Signed-off-by: Robin Kahlow <[email protected]>

* run make fmt

Signed-off-by: Robin Kahlow <[email protected]>

* structured dataset instead of schema transformer

Signed-off-by: Robin Kahlow <[email protected]>

* polars html describe only

Signed-off-by: Robin Kahlow <[email protected]>

* set polars min to 0.7.13 (.describe() added)

Signed-off-by: Robin Kahlow <[email protected]>

* set polars min to 0.8.27 (.transpose() added)

Signed-off-by: Robin Kahlow <[email protected]>

* add gcs, fix encode local/remote dir

Signed-off-by: Robin Kahlow <[email protected]>

* black and isort

Signed-off-by: Robin Kahlow <[email protected]>

* add polars plugin to pythonbuild.yml

Signed-off-by: Robin Kahlow <[email protected]>
  • Loading branch information
RobinKa authored Jun 15, 2022
1 parent 2b4b15a commit 87d1390
Show file tree
Hide file tree
Showing 10 changed files with 393 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ jobs:
- flytekit-modin
- flytekit-pandera
- flytekit-papermill
- flytekit-polars
- flytekit-snowflake
- flytekit-spark
- flytekit-sqlalchemy
Expand Down
5 changes: 5 additions & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

if importlib.util.find_spec("pyspark") is not None:
import pyspark
if importlib.util.find_spec("polars") is not None:
import polars as pl
from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, TypeAlias, get_args, get_origin
Expand Down Expand Up @@ -647,6 +649,9 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ
return pd.DataFrame(df).describe().to_html()
elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame):
return pd.DataFrame(df.schema, columns=["StructField"]).to_html()
elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame):
describe_df = df.describe()
return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False)
else:
raise NotImplementedError("Conversion to html string should be implemented")

Expand Down
10 changes: 10 additions & 0 deletions plugins/flytekit-polars/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Flytekit Polars Plugin
[Polars](https://github.com/pola-rs/polars) is a blazingly fast DataFrames library implemented in Rust using Apache Arrow Columnar Format as memory model.

This plugin supports `polars.DataFrame` as a data type with [StructuredDataset](https://docs.flyte.org/projects/cookbook/en/latest/auto/core/type_system/structured_dataset.html).

To install the plugin, run the following command:

```bash
pip install flytekitplugins-polars
```
14 changes: 14 additions & 0 deletions plugins/flytekit-polars/flytekitplugins/polars/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
.. currentmodule:: flytekitplugins.polars
This package contains things that are useful when extending Flytekit.
.. autosummary::
:template: custom.rst
:toctree: generated/
PolarsDataFrameToParquetEncodingHandler
ParquetToPolarsDataFrameDecodingHandler
"""

from .sd_transformers import ParquetToPolarsDataFrameDecodingHandler, PolarsDataFrameToParquetEncodingHandler
73 changes: 73 additions & 0 deletions plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import typing

import polars as pl

from flytekit import FlyteContext
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
GCS,
LOCAL,
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
StructuredDatasetTransformerEngine,
)


class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self, protocol: str):
super().__init__(pl.DataFrame, protocol, PARQUET)

def encode(
self,
ctx: FlyteContext,
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
df = typing.cast(pl.DataFrame, structured_dataset.dataframe)

local_dir = ctx.file_access.get_random_local_directory()
local_path = f"{local_dir}/00000"

# Polars 0.13.12 deprecated to_parquet in favor of write_parquet
if hasattr(df, "write_parquet"):
df.write_parquet(local_path)
else:
df.to_parquet(local_path)
remote_dir = typing.cast(str, structured_dataset.uri) or ctx.file_access.get_random_remote_directory()
ctx.file_access.upload_directory(local_dir, remote_dir)
return literals.StructuredDataset(uri=remote_dir, metadata=StructuredDatasetMetadata(structured_dataset_type))


class ParquetToPolarsDataFrameDecodingHandler(StructuredDatasetDecoder):
def __init__(self, protocol: str):
super().__init__(pl.DataFrame, protocol, PARQUET)

def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> pl.DataFrame:
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(flyte_value.uri, local_dir, is_multipart=True)
path = f"{local_dir}/00000"
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return pl.read_parquet(path, columns=columns)
return pl.read_parquet(path)


for protocol in [LOCAL, S3]:
StructuredDatasetTransformerEngine.register(
PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=True
)
StructuredDatasetTransformerEngine.register(
ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=True
)
StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler(GCS), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler(GCS), default_for_type=False)
2 changes: 2 additions & 0 deletions plugins/flytekit-polars/requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.
-e file:.#egg=flytekitplugins-polars
186 changes: 186 additions & 0 deletions plugins/flytekit-polars/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile requirements.in
#
-e file:.#egg=flytekitplugins-polars
# via -r requirements.in
arrow==1.2.2
# via jinja2-time
binaryornot==0.4.4
# via cookiecutter
certifi==2021.10.8
# via requests
cffi==1.15.0
# via cryptography
chardet==4.0.0
# via binaryornot
charset-normalizer==2.0.12
# via requests
click==8.1.2
# via
# cookiecutter
# flytekit
cloudpickle==2.0.0
# via flytekit
cookiecutter==1.7.3
# via flytekit
croniter==1.3.4
# via flytekit
cryptography==36.0.2
# via secretstorage
dataclasses-json==0.5.7
# via flytekit
decorator==5.1.1
# via retry
deprecated==1.2.13
# via flytekit
diskcache==5.4.0
# via flytekit
docker==5.0.3
# via flytekit
docker-image-py==0.1.12
# via flytekit
docstring-parser==0.13
# via flytekit
flyteidl==1.0.1
# via flytekit
flytekit==1.1.0b2
# via flytekitplugins-polars
googleapis-common-protos==1.56.0
# via
# flyteidl
# grpcio-status
grpcio==1.43.0
# via
# flytekit
# grpcio-status
grpcio-status==1.43.0
# via flytekit
idna==3.3
# via requests
importlib-metadata==4.11.3
# via keyring
jeepney==0.8.0
# via
# keyring
# secretstorage
jinja2==3.1.1
# via
# cookiecutter
# jinja2-time
jinja2-time==0.2.0
# via cookiecutter
keyring==23.5.0
# via flytekit
markupsafe==2.1.1
# via jinja2
marshmallow==3.15.0
# via
# dataclasses-json
# marshmallow-enum
# marshmallow-jsonschema
marshmallow-enum==1.5.1
# via dataclasses-json
marshmallow-jsonschema==0.13.0
# via flytekit
mypy-extensions==0.4.3
# via typing-inspect
natsort==8.1.0
# via flytekit
numpy==1.22.3
# via
# pandas
# polars
# pyarrow
packaging==21.3
# via marshmallow
pandas==1.4.1
# via flytekit
polars==0.13.44
# via flytekitplugins-polars
poyo==0.5.0
# via cookiecutter
protobuf==3.20.1
# via
# flyteidl
# flytekit
# googleapis-common-protos
# grpcio-status
# protoc-gen-swagger
protoc-gen-swagger==0.1.0
# via flyteidl
py==1.11.0
# via retry
pyarrow==6.0.1
# via flytekit
pycparser==2.21
# via cffi
pyparsing==3.0.8
# via packaging
python-dateutil==2.8.2
# via
# arrow
# croniter
# flytekit
# pandas
python-json-logger==2.0.2
# via flytekit
python-slugify==6.1.1
# via cookiecutter
pytimeparse==1.1.8
# via flytekit
pytz==2022.1
# via
# flytekit
# pandas
pyyaml==6.0
# via flytekit
regex==2022.3.15
# via docker-image-py
requests==2.27.1
# via
# cookiecutter
# docker
# flytekit
# responses
responses==0.20.0
# via flytekit
retry==0.9.2
# via flytekit
secretstorage==3.3.2
# via keyring
six==1.16.0
# via
# cookiecutter
# grpcio
# python-dateutil
sortedcontainers==2.4.0
# via flytekit
statsd==3.3.0
# via flytekit
text-unidecode==1.3
# via python-slugify
typing-extensions==4.2.0
# via
# flytekit
# polars
# typing-inspect
typing-inspect==0.7.1
# via dataclasses-json
urllib3==1.26.9
# via
# flytekit
# requests
# responses
websocket-client==1.3.2
# via docker
wheel==0.37.1
# via flytekit
wrapt==1.14.0
# via
# deprecated
# flytekit
zipp==3.8.0
# via importlib-metadata
38 changes: 38 additions & 0 deletions plugins/flytekit-polars/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from setuptools import setup

PLUGIN_NAME = "polars"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = [
"flytekit>=1.1.0b0,<1.2.0",
"polars>=0.8.27",
]

__version__ = "0.0.0+develop"

setup(
name=microlib_name,
version=__version__,
author="Robin Kahlow",
description="Polars plugin for flytekit",
namespace_packages=["flytekitplugins"],
packages=[f"flytekitplugins.{PLUGIN_NAME}"],
install_requires=plugin_requires,
license="apache2",
python_requires=">=3.7",
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
)
Empty file.
Loading

0 comments on commit 87d1390

Please sign in to comment.