diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
index 3c70fb8f60..55d835efb7 100644
--- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
+++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
@@ -1,8 +1,15 @@
+import base64
+from io import BytesIO
+from typing import Union
+
import markdown
-import pandas
+import pandas as pd
import plotly.express as px
+from PIL import Image
from ydata_profiling import ProfileReport
+from flytekit.types.file import FlyteFile
+
class FrameProfilingRenderer:
"""
@@ -12,8 +19,8 @@ class FrameProfilingRenderer:
def __init__(self, title: str = "Pandas Profiling Report"):
self._title = title
- def to_html(self, df: pandas.DataFrame) -> str:
- assert isinstance(df, pandas.DataFrame)
+ def to_html(self, df: pd.DataFrame) -> str:
+ assert isinstance(df, pd.DataFrame)
profile = ProfileReport(df, title=self._title)
return profile.to_html()
@@ -45,6 +52,33 @@ class BoxRenderer:
def __init__(self, column_name):
self._column_name = column_name
- def to_html(self, df: pandas.DataFrame) -> str:
+ def to_html(self, df: pd.DataFrame) -> str:
fig = px.box(df, y=self._column_name)
return fig.to_html()
+
+
+class ImageRenderer:
+ """Converts a FlyteFile or PIL.Image.Image object to an HTML string with the image data
+ represented as a base64-encoded string.
+ """
+
+ def to_html(cls, image_src: Union[FlyteFile, Image.Image]) -> str:
+ img = cls._get_image_object(image_src)
+ return cls._image_to_html_string(img)
+
+ @staticmethod
+ def _get_image_object(image_src: Union[FlyteFile, Image.Image]) -> Image.Image:
+ if isinstance(image_src, FlyteFile):
+ local_path = image_src.download()
+ return Image.open(local_path)
+ elif isinstance(image_src, Image.Image):
+ return image_src
+ else:
+ raise ValueError("Unsupported image source type")
+
+ @staticmethod
+ def _image_to_html_string(img: Image.Image) -> str:
+ buffered = BytesIO()
+ img.save(buffered, format="PNG")
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
+ return f''
diff --git a/plugins/flytekit-deck-standard/tests/test_renderer.py b/plugins/flytekit-deck-standard/tests/test_renderer.py
index 8761fb6655..900543dd0b 100644
--- a/plugins/flytekit-deck-standard/tests/test_renderer.py
+++ b/plugins/flytekit-deck-standard/tests/test_renderer.py
@@ -1,6 +1,12 @@
+import tempfile
+
import markdown
import pandas as pd
-from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer
+import pytest
+from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, ImageRenderer, MarkdownRenderer
+from PIL import Image
+
+from flytekit.types.file import FlyteFile, JPEGImageFile, PNGImageFile
df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})
@@ -19,3 +25,29 @@ def test_markdown_renderer():
def test_box_renderer():
renderer = BoxRenderer("Name")
assert "Plotlyconfig = {Mathjaxconfig: 'Local'}" in renderer.to_html(df).title()
+
+
+def create_simple_image(fmt: str):
+ """Create a simple PNG image using PIL"""
+ img = Image.new("RGB", (100, 100), color="black")
+ tmp = tempfile.mktemp()
+ img.save(tmp, fmt)
+ return tmp
+
+
+png_image = create_simple_image(fmt="png")
+jpeg_image = create_simple_image(fmt="jpeg")
+
+
+@pytest.mark.parametrize(
+ "image_src",
+ [
+ FlyteFile(path=png_image),
+ JPEGImageFile(path=jpeg_image),
+ PNGImageFile(path=png_image),
+ Image.open(png_image),
+ ],
+)
+def test_image_renderer(image_src):
+ renderer = ImageRenderer()
+ assert "