From c7d69fa72acaced90955625715838aae2719abf0 Mon Sep 17 00:00:00 2001 From: peridotml <106936600+peridotml@users.noreply.github.com> Date: Wed, 19 Apr 2023 14:38:27 -0700 Subject: [PATCH] add ImageRender class for Flyte Decks (#1599) Signed-off-by: esad Signed-off-by: Fabio Graetz --- .../flytekitplugins/deck/renderer.py | 42 +++++++++++++++++-- .../tests/test_renderer.py | 34 ++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index 3c70fb8f60..55d835efb7 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -1,8 +1,15 @@ +import base64 +from io import BytesIO +from typing import Union + import markdown -import pandas +import pandas as pd import plotly.express as px +from PIL import Image from ydata_profiling import ProfileReport +from flytekit.types.file import FlyteFile + class FrameProfilingRenderer: """ @@ -12,8 +19,8 @@ class FrameProfilingRenderer: def __init__(self, title: str = "Pandas Profiling Report"): self._title = title - def to_html(self, df: pandas.DataFrame) -> str: - assert isinstance(df, pandas.DataFrame) + def to_html(self, df: pd.DataFrame) -> str: + assert isinstance(df, pd.DataFrame) profile = ProfileReport(df, title=self._title) return profile.to_html() @@ -45,6 +52,33 @@ class BoxRenderer: def __init__(self, column_name): self._column_name = column_name - def to_html(self, df: pandas.DataFrame) -> str: + def to_html(self, df: pd.DataFrame) -> str: fig = px.box(df, y=self._column_name) return fig.to_html() + + +class ImageRenderer: + """Converts a FlyteFile or PIL.Image.Image object to an HTML string with the image data + represented as a base64-encoded string. + """ + + def to_html(cls, image_src: Union[FlyteFile, Image.Image]) -> str: + img = cls._get_image_object(image_src) + return cls._image_to_html_string(img) + + @staticmethod + def _get_image_object(image_src: Union[FlyteFile, Image.Image]) -> Image.Image: + if isinstance(image_src, FlyteFile): + local_path = image_src.download() + return Image.open(local_path) + elif isinstance(image_src, Image.Image): + return image_src + else: + raise ValueError("Unsupported image source type") + + @staticmethod + def _image_to_html_string(img: Image.Image) -> str: + buffered = BytesIO() + img.save(buffered, format="PNG") + img_base64 = base64.b64encode(buffered.getvalue()).decode() + return f'Rendered Image' diff --git a/plugins/flytekit-deck-standard/tests/test_renderer.py b/plugins/flytekit-deck-standard/tests/test_renderer.py index 8761fb6655..900543dd0b 100644 --- a/plugins/flytekit-deck-standard/tests/test_renderer.py +++ b/plugins/flytekit-deck-standard/tests/test_renderer.py @@ -1,6 +1,12 @@ +import tempfile + import markdown import pandas as pd -from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer +import pytest +from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, ImageRenderer, MarkdownRenderer +from PIL import Image + +from flytekit.types.file import FlyteFile, JPEGImageFile, PNGImageFile df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]}) @@ -19,3 +25,29 @@ def test_markdown_renderer(): def test_box_renderer(): renderer = BoxRenderer("Name") assert "Plotlyconfig = {Mathjaxconfig: 'Local'}" in renderer.to_html(df).title() + + +def create_simple_image(fmt: str): + """Create a simple PNG image using PIL""" + img = Image.new("RGB", (100, 100), color="black") + tmp = tempfile.mktemp() + img.save(tmp, fmt) + return tmp + + +png_image = create_simple_image(fmt="png") +jpeg_image = create_simple_image(fmt="jpeg") + + +@pytest.mark.parametrize( + "image_src", + [ + FlyteFile(path=png_image), + JPEGImageFile(path=jpeg_image), + PNGImageFile(path=png_image), + Image.open(png_image), + ], +) +def test_image_renderer(image_src): + renderer = ImageRenderer() + assert "