Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flytekitplugin pandera update: use entrypoint and structured dataset #2821

Merged
merged 11 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
:template: custom.rst
:toctree: generated/

PanderaTransformer
PanderaPandasTransformer
PandasReportRenderer
ValidationConfig
"""

from .schema import PanderaTransformer
from .config import ValidationConfig
from .pandas_renderer import PandasReportRenderer
from .pandas_transformer import PanderaPandasTransformer
10 changes: 10 additions & 0 deletions plugins/flytekit-pandera/flytekitplugins/pandera/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Pandera validation configuration."""

from dataclasses import dataclass
from typing import Literal


@dataclass
class ValidationConfig:
# determine how to handle validation errors in the Flyte type transformer
on_error: Literal["raise", "warn"] = "raise"
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import TYPE_CHECKING

from flytekit import lazy_module

if TYPE_CHECKING:
import pandas

import pandera
else:
pandas = lazy_module("pandas")
pandera = lazy_module("pandera")


class PandasReportRenderer:
def __init__(self, title: str = "Pandera Error Report"):
self._title = title

def to_html(self, error: pandera.errors.SchemaErrors) -> str:
error.failure_cases.groupby(["schema_context"])
html = (
error.failure_cases.set_index(["schema_context", "column", "check"])
.drop(["check_number"], axis="columns")[["index", "failure_case"]]
.to_html()
)
return html
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import typing
import warnings
from typing import TYPE_CHECKING, Type, Union

from flytekit import Deck, FlyteContext, lazy_module
from flytekit.extend import TypeEngine, TypeTransformer
from flytekit.models.literals import Literal
from flytekit.models.types import LiteralType, SchemaType
from flytekit.types.structured import StructuredDataset
from flytekit.types.structured.structured_dataset import StructuredDatasetTransformerEngine, get_supported_types

from .config import ValidationConfig
from .pandas_renderer import PandasReportRenderer

if TYPE_CHECKING:
import pandas

import pandera
else:
pandas = lazy_module("pandas")
pandera = lazy_module("pandera")


T = typing.TypeVar("T")


class PanderaPandasTransformer(TypeTransformer[pandera.typing.DataFrame]):
_SUPPORTED_TYPES: typing.Dict[type, SchemaType.SchemaColumn.SchemaColumnType] = get_supported_types()
_VALIDATION_MEMO = set()

def __init__(self):
super().__init__("Pandera Transformer", pandera.typing.DataFrame) # type: ignore
self._sd_transformer = StructuredDatasetTransformerEngine()

def _get_pandera_schema(self, t: Type[pandera.typing.DataFrame]):
config = ValidationConfig()
if typing.get_origin(t) is typing.Annotated:
t, *args = typing.get_args(t)
# get pandera config
for arg in args:
if isinstance(arg, ValidationConfig):
config = arg
break

try:
type_args = typing.get_args(t)
except AttributeError:
# for python < 3.8
type_args = getattr(t, "__args__", None)
eapolinario marked this conversation as resolved.
Show resolved Hide resolved

if type_args:
schema_model, *_ = type_args
schema = schema_model.to_schema()
else:
schema = pandera.DataFrameSchema() # type: ignore
return schema, config

@staticmethod
def _get_pandas_type(pandera_dtype: pandera.dtypes.DataType):
return pandera_dtype.type.type

def _get_col_dtypes(self, t: Type[pandera.typing.DataFrame]):
schema, _ = self._get_pandera_schema(t)
return {k: self._get_pandas_type(v.dtype) for k, v in schema.columns.items()}

def get_literal_type(self, t: Type[pandera.typing.DataFrame]) -> LiteralType:
if typing.get_origin(t) is typing.Annotated:
t, _ = typing.get_args(t)
return self._sd_transformer.get_literal_type(t)

def assert_type(self, t: Type[T], v: T):
if not hasattr(t, "__origin__") and not isinstance(v, (t, pandas.DataFrame)):
raise TypeError(f"Type of Val '{v}' is not an instance of {t}")

def to_literal(
self,
ctx: FlyteContext,
python_val: Union[pandas.DataFrame, StructuredDataset],
python_type: Type[pandera.typing.DataFrame],
expected: LiteralType,
) -> Literal:
assert isinstance(
python_val, (pandas.DataFrame, StructuredDataset)
), f"Only Pandas Dataframe object can be returned from a task, returned object type {type(python_val)}"

if isinstance(python_val, StructuredDataset):
lv = self._sd_transformer.to_literal(ctx, python_val, pandas.DataFrame, expected)
python_val = self._sd_transformer.to_python_value(ctx, lv, pandas.DataFrame)

schema, config = self._get_pandera_schema(python_type)
renderer = PandasReportRenderer(title=f"Pandera Report: {schema.name}")
try:
val = schema.validate(python_val, lazy=True)
except (pandera.errors.SchemaError, pandera.errors.SchemaErrors) as exc:
if config.on_error == "raise":
raise exc
elif config.on_error == "warn":
warnings.warn(str(exc), RuntimeWarning)
html = renderer.to_html(exc)
Deck(renderer._title, html)
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError(f"Invalid on_error value: {config.on_error}")
val = python_val

lv = self._sd_transformer.to_literal(ctx, val, pandas.DataFrame, expected)

# In cases where a task is being called locally, this method will be invoked to convert the python input value
# to a Flyte literal, which is then deserialized back to a python value. In such cases, we can cache the
# structured dataset uri and schema name to avoid repeating the validation process in the subsequent
# to_python_value call.
self._VALIDATION_MEMO.add((lv.scalar.structured_dataset.uri, schema.name))
return lv

def to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[pandas.DataFrame]
) -> pandera.typing.DataFrame:
if not (lv and lv.scalar and lv.scalar.structured_dataset):
raise AssertionError("Can only convert a literal structured dataset to a pandera schema")

df = self._sd_transformer.to_python_value(ctx, lv, pandas.DataFrame)
schema, config = self._get_pandera_schema(expected_python_type)

if (lv.scalar.structured_dataset.uri, schema.name) in self._VALIDATION_MEMO:
return df

renderer = PandasReportRenderer(title=f"Pandera Report: {schema.name}")
try:
val = schema.validate(df, lazy=True)
except (pandera.errors.SchemaError, pandera.errors.SchemaErrors) as exc:
if config.on_error == "raise":
raise exc
elif config.on_error == "warn":
print(ctx.execution_state)
warnings.warn(str(exc), RuntimeWarning)
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
html = renderer.to_html(exc)
Deck(renderer._title, html)
else:
raise ValueError(f"Invalid on_error value: {config.on_error}")
val = df

return val


TypeEngine.register(PanderaPandasTransformer())
101 changes: 0 additions & 101 deletions plugins/flytekit-pandera/flytekitplugins/pandera/schema.py

This file was deleted.

3 changes: 2 additions & 1 deletion plugins/flytekit-pandera/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pandera>=0.7.1", "pandas"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pandera>=0.7.1", "pandas", "great_tables"]

__version__ = "0.0.0+develop"

Expand Down Expand Up @@ -32,4 +32,5 @@
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
Loading
Loading