Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ImageRender class for Flyte Decks #1599

Merged
merged 1 commit into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import base64
from io import BytesIO
from typing import Union

import markdown
import pandas
import pandas as pd
import plotly.express as px
from PIL import Image
from ydata_profiling import ProfileReport

from flytekit.types.file import FlyteFile


class FrameProfilingRenderer:
"""
Expand All @@ -12,8 +19,8 @@ class FrameProfilingRenderer:
def __init__(self, title: str = "Pandas Profiling Report"):
self._title = title

def to_html(self, df: pandas.DataFrame) -> str:
assert isinstance(df, pandas.DataFrame)
def to_html(self, df: pd.DataFrame) -> str:
assert isinstance(df, pd.DataFrame)
profile = ProfileReport(df, title=self._title)
return profile.to_html()

Expand Down Expand Up @@ -45,6 +52,33 @@ class BoxRenderer:
def __init__(self, column_name):
self._column_name = column_name

def to_html(self, df: pandas.DataFrame) -> str:
def to_html(self, df: pd.DataFrame) -> str:
fig = px.box(df, y=self._column_name)
return fig.to_html()


class ImageRenderer:
"""Converts a FlyteFile or PIL.Image.Image object to an HTML string with the image data
represented as a base64-encoded string.
"""

def to_html(cls, image_src: Union[FlyteFile, Image.Image]) -> str:
img = cls._get_image_object(image_src)
return cls._image_to_html_string(img)

@staticmethod
def _get_image_object(image_src: Union[FlyteFile, Image.Image]) -> Image.Image:
if isinstance(image_src, FlyteFile):
local_path = image_src.download()
return Image.open(local_path)
elif isinstance(image_src, Image.Image):
return image_src
else:
raise ValueError("Unsupported image source type")

@staticmethod
def _image_to_html_string(img: Image.Image) -> str:
buffered = BytesIO()
img.save(buffered, format="PNG")
img_base64 = base64.b64encode(buffered.getvalue()).decode()
return f'<img src="data:image/png;base64,{img_base64}" alt="Rendered Image" />'
34 changes: 33 additions & 1 deletion plugins/flytekit-deck-standard/tests/test_renderer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import tempfile

import markdown
import pandas as pd
from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer
import pytest
from flytekitplugins.deck.renderer import BoxRenderer, FrameProfilingRenderer, ImageRenderer, MarkdownRenderer
from PIL import Image

from flytekit.types.file import FlyteFile, JPEGImageFile, PNGImageFile

df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [1, 22]})

Expand All @@ -19,3 +25,29 @@ def test_markdown_renderer():
def test_box_renderer():
renderer = BoxRenderer("Name")
assert "Plotlyconfig = {Mathjaxconfig: 'Local'}" in renderer.to_html(df).title()


def create_simple_image(fmt: str):
"""Create a simple PNG image using PIL"""
img = Image.new("RGB", (100, 100), color="black")
tmp = tempfile.mktemp()
img.save(tmp, fmt)
return tmp


png_image = create_simple_image(fmt="png")
jpeg_image = create_simple_image(fmt="jpeg")


@pytest.mark.parametrize(
"image_src",
[
FlyteFile(path=png_image),
JPEGImageFile(path=jpeg_image),
PNGImageFile(path=png_image),
Image.open(png_image),
],
)
def test_image_renderer(image_src):
renderer = ImageRenderer()
assert "<img" in renderer.to_html(image_src)