Skip to content

Commit

Permalink
Merge pull request #1143 from facebookexperimental/feature/base_visua…
Browse files Browse the repository at this point in the history
…lizer

Update base_visualizer.py
  • Loading branch information
sumane81 authored Nov 16, 2024
2 parents 997c7cd + 09e0b7f commit 0178c0b
Showing 1 changed file with 157 additions and 52 deletions.
209 changes: 157 additions & 52 deletions python/src/robyn/visualization/base_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from typing import Dict, Optional, Tuple, Union, Any, List
# pyre-strict

import logging

from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple, Union, List
from pathlib import Path
from IPython.display import Image, display

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
import logging
import base64
import io

# Configure logger
logger = logging.getLogger(__name__)

class BaseVisualizer:

class BaseVisualizer(ABC):
"""
Base class for all Robyn visualization components.
Provides common plotting functionality and styling.
Expand All @@ -22,7 +30,7 @@ def __init__(self, style: str = "bmh"):
style: matplotlib style to use (default: "bmh")
"""
logger.info("Initializing BaseVisualizer with style: %s", style)

# Store style settings
self.style = style
self.default_figsize = (12, 8)
Expand All @@ -41,7 +49,14 @@ def __init__(self, style: str = "bmh"):
logger.debug("Color scheme initialized: %s", self.colors)

# Plot settings
self.font_sizes = {"title": 14, "subtitle": 12, "label": 12, "tick": 10, "annotation": 9, "legend": 10}
self.font_sizes = {
"title": 14,
"subtitle": 12,
"label": 12,
"tick": 10,
"annotation": 9,
"legend": 10,
}
logger.debug("Font sizes configured: %s", self.font_sizes)

# Default alpha values
Expand Down Expand Up @@ -99,7 +114,9 @@ def create_figure(
logger.debug("Using figure size: %s", figsize)

try:
self.current_figure, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
self.current_figure, axes = plt.subplots(
nrows=nrows, ncols=ncols, figsize=figsize
)
if nrows == ncols == 1:
self.current_axes = axes
else:
Expand Down Expand Up @@ -136,8 +153,13 @@ def setup_axis(
yticklabels: Optional list of y-axis tick labels
rotation: Rotation angle for tick labels
"""
logger.debug("Setting up axis with title: %s, xlabel: %s, ylabel: %s", title, xlabel, ylabel)

logger.debug(
"Setting up axis with title: %s, xlabel: %s, ylabel: %s",
title,
xlabel,
ylabel,
)

try:
if title:
ax.set_title(title, fontsize=self.font_sizes["title"])
Expand All @@ -155,7 +177,9 @@ def setup_axis(

if xticklabels is not None:
logger.debug("Setting x-tick labels with rotation: %d", rotation)
ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=self.font_sizes["tick"])
ax.set_xticklabels(
xticklabels, rotation=rotation, fontsize=self.font_sizes["tick"]
)
if yticklabels is not None:
logger.debug("Setting y-tick labels")
ax.set_yticklabels(yticklabels, fontsize=self.font_sizes["tick"])
Expand All @@ -168,7 +192,13 @@ def setup_axis(
raise

def add_percentage_annotation(
self, ax: plt.Axes, x: float, y: float, percentage: float, va: str = "bottom", ha: str = "center"
self,
ax: plt.Axes,
x: float,
y: float,
percentage: float,
va: str = "bottom",
ha: str = "center",
) -> None:
"""
Add a percentage change annotation to the plot.
Expand All @@ -181,9 +211,16 @@ def add_percentage_annotation(
va: Vertical alignment
ha: Horizontal alignment
"""
logger.debug("Adding percentage annotation at (x=%f, y=%f) with value: %f%%", x, y, percentage)
logger.debug(
"Adding percentage annotation at (x=%f, y=%f) with value: %f%%",
x,
y,
percentage,
)
try:
color = self.colors["positive"] if percentage >= 0 else self.colors["negative"]
color = (
self.colors["positive"] if percentage >= 0 else self.colors["negative"]
)
ax.text(
x,
y,
Expand All @@ -199,7 +236,9 @@ def add_percentage_annotation(
logger.error("Failed to add percentage annotation: %s", str(e))
raise

def add_legend(self, ax: plt.Axes, loc: str = "best", title: Optional[str] = None) -> None:
def add_legend(
self, ax: plt.Axes, loc: str = "best", title: Optional[str] = None
) -> None:
"""
Add a formatted legend to the plot.
Expand All @@ -210,24 +249,34 @@ def add_legend(self, ax: plt.Axes, loc: str = "best", title: Optional[str] = Non
"""
logger.debug("Adding legend with location: %s and title: %s", loc, title)
try:
legend = ax.legend(fontsize=self.font_sizes["legend"], loc=loc, framealpha=self.alpha["annotation"])
legend = ax.legend(
fontsize=self.font_sizes["legend"],
loc=loc,
framealpha=self.alpha["annotation"],
)
if title:
legend.set_title(title, prop={"size": self.font_sizes["legend"]})
logger.debug("Legend added successfully")
except Exception as e:
logger.error("Failed to add legend: %s", str(e))
raise

def finalize_figure(self, tight_layout: bool = True, adjust_spacing: bool = False) -> None:
def finalize_figure(
self, tight_layout: bool = True, adjust_spacing: bool = False
) -> None:
"""
Apply final formatting to the current figure.
Args:
tight_layout: Whether to apply tight_layout
adjust_spacing: Whether to adjust subplot spacing
"""
logger.info("Finalizing figure with tight_layout=%s, adjust_spacing=%s", tight_layout, adjust_spacing)

logger.info(
"Finalizing figure with tight_layout=%s, adjust_spacing=%s",
tight_layout,
adjust_spacing,
)

if self.current_figure is None:
logger.warning("No current figure to finalize")
return
Expand All @@ -236,43 +285,14 @@ def finalize_figure(self, tight_layout: bool = True, adjust_spacing: bool = Fals
if tight_layout:
self.current_figure.tight_layout(pad=self.spacing["tight_layout_pad"])
if adjust_spacing:
self.current_figure.subplots_adjust(hspace=self.spacing["subplot_adjust_hspace"])
self.current_figure.subplots_adjust(
hspace=self.spacing["subplot_adjust_hspace"]
)
logger.debug("Figure finalization completed successfully")
except Exception as e:
logger.error("Failed to finalize figure: %s", str(e))
raise

def save_plot(self, filename: Union[str, Path], dpi: int = 300, cleanup: bool = True) -> None:
"""
Save the current plot to a file.
Args:
filename: Path to save the plot
dpi: Resolution for saved plot
cleanup: Whether to close the plot after saving
"""
logger.info("Saving plot to: %s with DPI: %d", filename, dpi)

if self.current_figure is None:
error_msg = "No current figure to save"
logger.error(error_msg)
raise ValueError(error_msg)

try:
filepath = Path(filename)
filepath.parent.mkdir(parents=True, exist_ok=True)
logger.debug("Created directory structure for: %s", filepath.parent)

self.current_figure.savefig(filepath, dpi=dpi, bbox_inches="tight", facecolor="white", edgecolor="none")
logger.info("Plot saved successfully to: %s", filepath)

if cleanup:
logger.debug("Cleaning up after save")
self.cleanup()
except Exception as e:
logger.error("Failed to save plot to %s: %s", filename, str(e))
raise

def cleanup(self) -> None:
"""Close the current plot and clear matplotlib memory."""
logger.debug("Performing cleanup")
Expand All @@ -282,4 +302,89 @@ def cleanup(self) -> None:
self.current_axes = None
logger.debug("Cleanup completed successfully")
else:
logger.debug("No figure to clean up")
logger.debug("No figure to clean up")

@staticmethod
def export_plots_base64(
export_location: Union[str, Path], plots: Dict[str, str], dpi: int = 300
) -> None:
logger.info("Exporting base64 plots to: %s", export_location)
export_path = Path(export_location)
export_path.mkdir(parents=True, exist_ok=True)

for plot_name, base64_str in plots.items():
filename = export_path / f"{plot_name}.png"
logger.debug("Saving base64 plot: %s to %s", plot_name, filename)
try:
image_data = base64.b64decode(base64_str)
with open(filename, "wb") as f:
f.write(image_data)
logger.info("Base64 plot %s saved successfully", plot_name)
except Exception as e:
logger.error("Failed to save base64 plot %s: %s", plot_name, str(e))
raise
pass

@staticmethod
def export_plots_fig(
export_location: Union[str, Path], plots: Dict[str, plt.Figure], dpi: int = 300
) -> None:
"""
Save multiple plots to the specified location.
Args:
export_location: Directory to save the plots
plots: Dictionary of plot names and their corresponding figures
dpi: Resolution for saved plots
"""
logger.info("Saving multiple plots to: %s", export_location)
export_path = Path(export_location)
export_path.mkdir(parents=True, exist_ok=True)

for plot_name, fig in plots.items():
filename = export_path / f"{plot_name}.png"
logger.debug("Saving plot: %s to %s", plot_name, filename)
try:
fig.savefig(
filename,
dpi=dpi,
bbox_inches="tight",
facecolor="white",
edgecolor="none",
)
logger.info("Plot %s saved successfully", plot_name)
except Exception as e:
logger.error("Failed to save plot %s: %s", plot_name, str(e))
raise

@abstractmethod
def plot_all(
self, display_plots: bool = True, export_location: Union[str, Path] = None
) -> None:
pass

@staticmethod
def display_plot(plot_collect: Dict[str, plt.Figure]) -> None:
"""Display the plot."""
for plot_name, fig in plot_collect.items():
fig.show()

@staticmethod
def _display_base64_image(base64_image: str):
"""Helper method to display a base64-encoded image."""
display(Image(data=base64.b64decode(base64_image)))

def convert_plot_to_base64(self, fig: plt.Figure) -> str:
logger.debug("Converting plot to base64")
try:
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
buffer.seek(0)
image_png = buffer.getvalue()
buffer.close()
graphic = base64.b64encode(image_png)
logger.debug("Successfully converted plot to base64")
return graphic.decode("utf-8")
except Exception as e:
logger.error("Failed to convert plot to base64: %s", str(e), exc_info=True)
raise

0 comments on commit 0178c0b

Please sign in to comment.