Skip to content

Commit

Permalink
BUG: fix orientation of colorbar in multiplanel plot (#4981)
Browse files Browse the repository at this point in the history
  • Loading branch information
xshaokun authored Sep 13, 2024
1 parent 8170d00 commit e22ab1f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
24 changes: 20 additions & 4 deletions yt/visualization/base_plot_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,20 @@ def _set_axes(self) -> None:
self.image.axes.set_facecolor(self.colorbar_handler.background_color)

self.cax.tick_params(which="both", direction="in")
self.cb = self.figure.colorbar(self.image, self.cax)

# For creating a multipanel plot by ImageGrid
# we may need the location keyword, which requires Matplotlib >= 3.7.0
cb_location = getattr(self.cax, "orientation", None)
if matplotlib.__version_info__ >= (3, 7):
self.cb = self.figure.colorbar(self.image, self.cax, location=cb_location)
else:
if cb_location in ["top", "bottom"]:
warnings.warn(
"Cannot properly set the orientation of colorbar. "
"Consider upgrading matplotlib to version 3.7 or newer",
stacklevel=6,
)
self.cb = self.figure.colorbar(self.image, self.cax)

cb_axis: Axis
if self.cb.orientation == "vertical":
Expand Down Expand Up @@ -526,9 +539,12 @@ def _toggle_colorbar(self, choice: bool):

def _get_labels(self):
labels = super()._get_labels()
cbax = self.cb.ax
labels += cbax.yaxis.get_ticklabels()
labels += [cbax.yaxis.label, cbax.yaxis.get_offset_text()]
if getattr(self.cb, "orientation", "vertical") == "horizontal":
cbaxis = self.cb.ax.xaxis
else:
cbaxis = self.cb.ax.yaxis
labels += cbaxis.get_ticklabels()
labels += [cbaxis.label, cbaxis.get_offset_text()]
return labels

def hide_axes(self, *, draw_frame=None):
Expand Down
34 changes: 34 additions & 0 deletions yt/visualization/tests/test_image_comp_2D_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,40 @@ def test_particleprojectionplot_set_colorbar_properties():
return p.plots[field].figure


class TestMultipanelPlot:
@classmethod
def setup_class(cls):
cls.fields = [
("gas", "density"),
("gas", "velocity_x"),
("gas", "velocity_y"),
("gas", "velocity_magnitude"),
]
cls.ds = fake_random_ds(16)

@pytest.mark.skipif(
mpl.__version_info__ < (3, 7),
reason="colorbar cannot currently be set horizontal in multi-panel plot with matplotlib older than 3.7.0",
)
@pytest.mark.parametrize("cbar_location", ["top", "bottom", "left", "right"])
@pytest.mark.mpl_image_compare
def test_multipanelplot_colorbar_orientation_simple(self, cbar_location):
p = SlicePlot(self.ds, "z", self.fields)
return p.export_to_mpl_figure((2, 2), cbar_location=cbar_location)

@pytest.mark.parametrize("cbar_location", ["top", "bottom"])
def test_multipanelplot_colorbar_orientation_warning(self, cbar_location):
p = SlicePlot(self.ds, "z", self.fields)
if mpl.__version_info__ < (3, 7):
with pytest.warns(
UserWarning,
match="Cannot properly set the orientation of colorbar.",
):
p.export_to_mpl_figure((2, 2), cbar_location=cbar_location)
else:
p.export_to_mpl_figure((2, 2), cbar_location=cbar_location)


class TestProfilePlot:
@classmethod
def setup_class(cls):
Expand Down

0 comments on commit e22ab1f

Please sign in to comment.