Skip to content

Commit

Permalink
Register dataframe renderers in structured dataset (#1140)
Browse files Browse the repository at this point in the history
* Register dataframe renderers in structured dataset

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix test

Signed-off-by: Kevin Su <[email protected]>

* more tests

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Aug 26, 2022
1 parent e4a1514 commit 3cf0639
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 28 deletions.
11 changes: 11 additions & 0 deletions flytekit/deck/renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional

import pandas
import pyarrow
from typing_extensions import Protocol, runtime_checkable


Expand All @@ -24,3 +25,13 @@ def __init__(self, max_rows: Optional[int] = None):
def to_html(self, df: pandas.DataFrame) -> str:
assert isinstance(df, pandas.DataFrame)
return df.to_html(max_rows=self._max_rows)


class ArrowRenderer:
"""
Render a Arrow dataframe as an HTML table.
"""

def to_html(self, df: pyarrow.Table) -> str:
assert isinstance(df, pyarrow.Table)
return df.to_string()
5 changes: 5 additions & 0 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pyarrow.parquet as pq

from flytekit import FlyteContext
from flytekit.deck import TopFrameRenderer
from flytekit.deck.renderer import ArrowRenderer
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
Expand Down Expand Up @@ -103,3 +105,6 @@ def decode(
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler())
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler())

StructuredDatasetTransformerEngine.register_renderer(pd.DataFrame, TopFrameRenderer())
StructuredDatasetTransformerEngine.register_renderer(pa.Table, ArrowRenderer())
31 changes: 10 additions & 21 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import collections
import importlib
import os
import types
import typing
Expand All @@ -13,20 +12,14 @@
import numpy as _np
import pandas as pd
import pyarrow as pa

from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence

if importlib.util.find_spec("pyspark") is not None:
import pyspark
if importlib.util.find_spec("polars") is not None:
import polars as pl

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, TypeAlias, get_args, get_origin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import DataPersistencePlugins, DiskPersistence
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.deck.renderer import Renderable
from flytekit.loggers import logger
from flytekit.models import literals
from flytekit.models import types as type_models
Expand Down Expand Up @@ -339,6 +332,7 @@ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
DEFAULT_FORMATS: Dict[Type, str] = {}

Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
Renderers: Dict[Type, Renderable] = {}

@staticmethod
def _finder(handler_map, df_type: Type, protocol: str, format: str):
Expand Down Expand Up @@ -385,6 +379,10 @@ def __init__(self):
# Instances of StructuredDataset opt-in to the ability of being cached.
self._hash_overridable = True

@classmethod
def register_renderer(cls, python_type: Type, renderer: Renderable):
cls.Renderers[python_type] = renderer

@classmethod
def register(cls, h: Handlers, default_for_type: Optional[bool] = False, override: Optional[bool] = False):
"""
Expand Down Expand Up @@ -698,19 +696,10 @@ def to_html(self, ctx: FlyteContext, python_val: typing.Any, expected_python_typ
else:
df = python_val

if isinstance(df, pd.DataFrame):
return df.describe().to_html()
elif isinstance(df, pa.Table):
return df.to_string()
elif isinstance(df, _np.ndarray):
return pd.DataFrame(df).describe().to_html()
elif importlib.util.find_spec("pyspark") is not None and isinstance(df, pyspark.sql.DataFrame):
return pd.DataFrame(df.schema, columns=["StructField"]).to_html()
elif importlib.util.find_spec("polars") is not None and isinstance(df, pl.DataFrame):
describe_df = df.describe()
return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False)
if type(df) in self.Renderers:
return self.Renderers[type(df)].to_html(df)
else:
raise NotImplementedError("Conversion to html string should be implemented")
raise NotImplementedError(f"Could not find a renderer for {type(df)} in {self.Renderers}")

def open_as(
self,
Expand Down
13 changes: 13 additions & 0 deletions plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

import pandas as pd
import polars as pl

from flytekit import FlyteContext
Expand All @@ -15,6 +16,17 @@
)


class PolarsDataFrameRenderer:
"""
The Polars DataFrame summary statistics are rendered as an HTML table.
"""

def to_html(self, df: pl.DataFrame) -> str:
assert isinstance(df, pl.DataFrame)
describe_df = df.describe()
return pd.DataFrame(describe_df.transpose(), columns=describe_df.columns).to_html(index=False)


class PolarsDataFrameToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(pl.DataFrame, None, PARQUET)
Expand Down Expand Up @@ -61,3 +73,4 @@ def decode(

StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(pl.DataFrame, PolarsDataFrameRenderer())
9 changes: 9 additions & 0 deletions plugins/flytekit-polars/tests/test_polars_plugin_sd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pandas as pd
import polars as pl
from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer
from typing_extensions import Annotated

from flytekit import kwtypes, task, workflow
Expand Down Expand Up @@ -57,3 +59,10 @@ def wf() -> full_schema:

result = wf()
assert result is not None


def test_polars_renderer():
df = pl.DataFrame({"col1": [1, 3, 2], "col2": list("abc")})
assert PolarsDataFrameRenderer().to_html(df) == pd.DataFrame(
df.describe().transpose(), columns=df.describe().columns
).to_html(index=False)
12 changes: 12 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

import pandas as pd
from pyspark.sql.dataframe import DataFrame

from flytekit import FlyteContext
Expand All @@ -15,6 +16,16 @@
)


class SparkDataFrameRenderer:
"""
Render a Spark dataframe schema as an HTML table.
"""

def to_html(self, df: DataFrame) -> str:
assert isinstance(df, DataFrame)
return pd.DataFrame(df.schema, columns=["StructField"]).to_html()


class SparkToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(DataFrame, None, PARQUET)
Expand Down Expand Up @@ -50,3 +61,4 @@ def decode(

StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
15 changes: 15 additions & 0 deletions tests/flytekit/unit/core/test_structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,18 @@ def test_protocol_detection():

protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo")
assert protocol == "bq"


def test_register_renderers():
class DummyRenderer:
def to_html(self, input: str) -> str:
return "hello " + input

renderers = StructuredDatasetTransformerEngine.Renderers
StructuredDatasetTransformerEngine.register_renderer(str, DummyRenderer())
assert renderers[str].to_html("flyte") == "hello flyte"
assert pd.DataFrame in renderers
assert pa.Table in renderers

with pytest.raises(NotImplementedError, match="Could not find a renderer for <class 'int'> in"):
StructuredDatasetTransformerEngine().to_html(FlyteContextManager.current_context(), 3, int)
11 changes: 7 additions & 4 deletions tests/flytekit/unit/deck/test_renderer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pandas as pd
import pyarrow as pa

from flytekit.deck.renderer import TopFrameRenderer
from flytekit.deck.renderer import ArrowRenderer, TopFrameRenderer


def test_frame_profiling_renderer():
def test_renderer():
df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})
renderer = TopFrameRenderer()
assert renderer.to_html(df) == df.to_html()
pa_df = pa.Table.from_pandas(df)

assert TopFrameRenderer().to_html(df) == df.to_html()
assert ArrowRenderer().to_html(pa_df) == pa_df.to_string()
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,15 @@ def decode(
StructuredDatasetTransformerEngine.register(MockBQDecodingHandlers(), False, True)


class NumpyRenderer:
"""
The Polars DataFrame summary statistics are rendered as an HTML table.
"""

def to_html(self, array: np.ndarray) -> str:
return pd.DataFrame(array).describe().to_html()


@pytest.fixture(autouse=True)
def numpy_type():
class NumpyEncodingHandlers(StructuredDatasetEncoder):
Expand Down Expand Up @@ -101,9 +110,9 @@ def decode(
table = pq.read_table(local_dir)
return table.to_pandas().to_numpy()

for protocol in ["/", "s3"]:
StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray, protocol, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray, protocol, PARQUET))
StructuredDatasetTransformerEngine.register(NumpyEncodingHandlers(np.ndarray))
StructuredDatasetTransformerEngine.register(NumpyDecodingHandlers(np.ndarray))
StructuredDatasetTransformerEngine.register_renderer(np.ndarray, NumpyRenderer())


@task
Expand Down

0 comments on commit 3cf0639

Please sign in to comment.