diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index d2b44e0b65..55d835efb7 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -1,7 +1,14 @@ +import base64 +from io import BytesIO +from typing import Union + import markdown -import pandas +import pandas as pd import plotly.express as px -from pandas_profiling import ProfileReport +from PIL import Image +from ydata_profiling import ProfileReport + +from flytekit.types.file import FlyteFile class FrameProfilingRenderer: @@ -12,8 +19,8 @@ class FrameProfilingRenderer: def __init__(self, title: str = "Pandas Profiling Report"): self._title = title - def to_html(self, df: pandas.DataFrame) -> str: - assert isinstance(df, pandas.DataFrame) + def to_html(self, df: pd.DataFrame) -> str: + assert isinstance(df, pd.DataFrame) profile = ProfileReport(df, title=self._title) return profile.to_html() @@ -45,6 +52,33 @@ class BoxRenderer: def __init__(self, column_name): self._column_name = column_name - def to_html(self, df: pandas.DataFrame) -> str: + def to_html(self, df: pd.DataFrame) -> str: fig = px.box(df, y=self._column_name) return fig.to_html() + + +class ImageRenderer: + """Converts a FlyteFile or PIL.Image.Image object to an HTML string with the image data + represented as a base64-encoded string. + """ + + def to_html(cls, image_src: Union[FlyteFile, Image.Image]) -> str: + img = cls._get_image_object(image_src) + return cls._image_to_html_string(img) + + @staticmethod + def _get_image_object(image_src: Union[FlyteFile, Image.Image]) -> Image.Image: + if isinstance(image_src, FlyteFile): + local_path = image_src.download() + return Image.open(local_path) + elif isinstance(image_src, Image.Image): + return image_src + else: + raise ValueError("Unsupported image source type") + + @staticmethod + def _image_to_html_string(img: Image.Image) -> str: + buffered = BytesIO() + img.save(buffered, format="PNG") + img_base64 = base64.b64encode(buffered.getvalue()).decode() + return f'Rendered Image' diff --git a/plugins/flytekit-deck-standard/tests/test_renderer.py b/plugins/flytekit-deck-standard/tests/test_renderer.py index 79eb7e877d..0544acff5d 100644 --- a/plugins/flytekit-deck-standard/tests/test_renderer.py +++ b/plugins/flytekit-deck-standard/tests/test_renderer.py @@ -1,6 +1,12 @@ +import tempfile + import markdown import pandas as pd -from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer +import pytest +from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, ImageRenderer, MarkdownRenderer +from PIL import Image + +from flytekit.types.file import FlyteFile, JPEGImageFile, PNGImageFile df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]}) @@ -19,3 +25,29 @@ def test_markdown_renderer(): def test_box_renderer(): renderer = BoxRenderer("Name") assert "Plotlyconfig = {Mathjaxconfig: 'Local'}" in renderer.to_html(df).title() + + +def create_simple_image(fmt: str): + """Create a simple PNG image using PIL""" + img = Image.new("RGB", (100, 100), color="black") + tmp = tempfile.mktemp() + img.save(tmp, fmt) + return tmp + + +png_image = create_simple_image(fmt="png") +jpeg_image = create_simple_image(fmt="jpeg") + + +@pytest.mark.parametrize( + "image_src", + [ + FlyteFile(path=png_image), + JPEGImageFile(path=jpeg_image), + PNGImageFile(path=png_image), + Image.open(png_image), + ], +) +def test_image_renderer(image_src): + renderer = ImageRenderer() + assert "