Skip to content

Commit

Permalink
Fixing rendering problem (qutip#2442)
Browse files Browse the repository at this point in the history
Resolved a rendering issue in the process matrix visualization. Previously, the code did not utilize matplotlib's built-in z-sorting mechanism. Experiments with various z-sort configurations (min, max, average) yielded inconsistent results across different charts. The solution was inspired by a Stack Overflow discussion (https://stackoverflow.com/questions/18602660/matplotlib-bar3d-clipping-problems). By adjusting the calculation of camera coordinates and incorporating minor modifications from the suggested approach, the rendering issue has been successfully addressed.
  • Loading branch information
anushkrishnav authored Jun 8, 2024
1 parent 449c0cc commit c874c4a
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 59 deletions.
3 changes: 3 additions & 0 deletions doc/changes/2400.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Bug Fix in Process Matrix Rendering

Resolved a rendering issue in the process matrix visualization. Previously, the code did not utilize matplotlib's built-in z-sorting mechanism. Experiments with various z-sort configurations (min, max, average) yielded inconsistent results across different charts. The solution was inspired by a Stack Overflow discussion (https://stackoverflow.com/questions/18602660/matplotlib-bar3d-clipping-problems). By adjusting the calculation of camera coordinates and incorporating minor modifications from the suggested approach, the rendering issue has been successfully addressed.
196 changes: 137 additions & 59 deletions qutip/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,15 @@
'plot_spin_distribution', 'complex_array_to_rgb',
'plot_qubism', 'plot_schmidt']

import warnings
import itertools as it
import numpy as np
from numpy import pi, array, sin, cos, angle, log2, sqrt

from packaging.version import parse as parse_version

from . import (
Qobj, isket, ket2dm, tensor, vector_to_operator, to_super, settings
Qobj, isket, ket2dm, tensor, vector_to_operator, settings
)
from .core.dimensions import flatten
from .core.superop_reps import _to_superpauli, isqubitdims
from .wigner import wigner
from .matplotlib_utilities import complex_phase_cmap
Expand Down Expand Up @@ -670,10 +668,54 @@ def _get_matrix_components(option, M, argument):
f"{option} for {argument}")


def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,
bar_style='real', color_limits=None, color_style='real',
options=None, *, cmap=None, colorbar=True,
fig=None, ax=None):
def sph2cart(r, theta, phi):
"""spherical to cartesian transformation."""
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
return x, y, z


def sphview(ax):
"""
returns the camera position for 3D axes in spherical coordinates."""
xlim = ax.get_xlim()
ylim = ax.get_ylim()
zlim = ax.get_zlim()
# Compute based on the plots xyz limits.
r = 0.5 * np.sqrt(
(xlim[1] - xlim[0]) ** 2 +
(ylim[1] - ylim[0]) ** 2 +
(zlim[1] - zlim[0]) ** 2
)
theta, phi = np.radians((90 - ax.elev, ax.azim))
return r, theta, phi


def get_camera_position(ax):
"""
returns the camera position for 3D axes in cartesian coordinates
as a 3d numpy array.
"""
r, theta, phi = sphview(ax)
return np.array(sph2cart(r, theta, phi), ndmin=3).T


def matrix_histogram(
M,
x_basis=None,
y_basis=None,
limits=None,
bar_style="real",
color_limits=None,
color_style="real",
options=None,
*,
cmap=None,
colorbar=True,
fig=None,
ax=None,
):
"""
Draw a histogram for the matrix M, with the given x and y labels and title.
Expand Down Expand Up @@ -791,11 +833,20 @@ def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,
"""

# default options
default_opts = {'zticks': None, 'bars_spacing': 0.2,
'bars_alpha': 1., 'bars_lw': 0.5, 'bars_edgecolor': 'k',
'shade': True, 'azim': -35, 'elev': 35, 'stick': False,
'cbar_pad': 0.04, 'cbar_to_z': False, 'threshold': None}
default_opts = {
"zticks": None,
"bars_spacing": 0.3,
"bars_alpha": 1.0,
"bars_lw": 0.7,
"bars_edgecolor": "k",
"shade": True,
"azim": -60,
"elev": 30,
"stick": False,
"cbar_pad": 0.04,
"cbar_to_z": False,
"threshold": None,
}

# update default_opts from input options
if options is None:
Expand All @@ -804,16 +855,18 @@ def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,
if isinstance(options, dict):
# check if keys in options dict are valid
if set(options) - set(default_opts):
raise ValueError("invalid key(s) found in options: "
f"{', '.join(set(options) - set(default_opts))}")
raise ValueError(
"invalid key(s) found in options: "
f"{', '.join(set(options) - set(default_opts))}"
)
else:
# updating default options
default_opts.update(options)
options = default_opts
else:
raise ValueError("options must be a dictionary")

fig, ax = _is_fig_and_ax(fig, ax, projection='3d')
fig, ax = _is_fig_and_ax(fig, ax, projection="3d")

if not isinstance(M, list):
Ms = [M]
Expand All @@ -822,8 +875,7 @@ def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,

_equal_shape(Ms)

for i in range(len(Ms)):
M = Ms[i]
for i, M in enumerate(Ms):
if isinstance(M, Qobj):
if x_basis is None:
x_basis = list(_cb_labels([M.shape[0]])[0])
Expand All @@ -832,10 +884,9 @@ def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,
# extract matrix data from Qobj
M = M.full()

bar_M = _get_matrix_components(bar_style, M, 'bar_style')
bar_M = _get_matrix_components(bar_style, M, "bar_style")

if isinstance(limits, list) and \
len(limits) == 2:
if isinstance(limits, list) and len(limits) == 2:
z_min = limits[0]
z_max = limits[1]
else:
Expand All @@ -846,19 +897,18 @@ def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,
z_min -= 0.1
z_max += 0.1

color_M = _get_matrix_components(color_style, M, 'color_style')
color_M = _get_matrix_components(color_style, M, "color_style")

if isinstance(color_limits, list) and \
len(color_limits) == 2:
if isinstance(color_limits, list) and len(color_limits) == 2:
c_min = color_limits[0]
c_max = color_limits[1]
else:
if color_style == 'phase':
if color_style == "phase":
c_min = -pi
c_max = pi
else:
c_min = min(color_M) if i == 0 else min(min(color_M), c_min)
c_max = min(color_M) if i == 0 else max(max(color_M), c_max)
c_max = max(color_M) if i == 0 else max(max(color_M), c_max)

if c_min == c_max:
c_min -= 0.1
Expand All @@ -868,66 +918,93 @@ def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,

if cmap is None:
# change later
if color_style == 'phase':
if color_style == "phase":
cmap = _cyclic_cmap()
else:
cmap = _sequential_cmap()

artist_list = list()

ax.view_init(azim=options['azim'], elev=options['elev'])

camera = get_camera_position(ax)
for M in Ms:

if isinstance(M, Qobj):
M = M.full()

bar_M = _get_matrix_components(bar_style, M, 'bar_style')
color_M = _get_matrix_components(color_style, M, 'color_style')
bar_M = _get_matrix_components(bar_style, M, "bar_style")
color_M = _get_matrix_components(color_style, M, "color_style")

n = np.size(M)
xpos, ypos = np.meshgrid(range(M.shape[0]), range(M.shape[1]))
xpos = xpos.T.flatten() + 0.5
ypos = ypos.T.flatten() + 0.5
zpos = np.zeros(n)
dx = dy = (1 - options['bars_spacing']) * np.ones(n)
dx = dy = (1 - options["bars_spacing"]) * np.ones(n)
colors = cmap(norm(color_M))

colors[:, 3] = options['bars_alpha']
colors[:, 3] = options["bars_alpha"]

if options['threshold'] is not None:
colors[:, 3] *= 1 * (bar_M >= options['threshold'])
if options["threshold"] is not None:
colors[:, 3] *= 1 * (bar_M >= options["threshold"])

idx, = np.where(bar_M < options['threshold'])
(idx,) = np.where(bar_M < options["threshold"])
bar_M[idx] = 0

artist = ax.bar3d(xpos, ypos, zpos, dx, dy, bar_M, color=colors,
edgecolors=options['bars_edgecolor'],
linewidths=options['bars_lw'],
shade=options['shade'])
artist_list.append([artist])
temp_xpos = xpos.reshape(M.shape)
temp_ypos = ypos.reshape(M.shape)
temp_zpos = zpos.reshape(M.shape)

# calculating z_order for each bar based on its position
# The sorting issue was fixed by making minor change to
# https://stackoverflow.com/questions/18602660/matplotlib-bar3d-clipping-problems
z_order = (
np.multiply(
[
temp_xpos, temp_ypos, temp_zpos], camera
).sum(0).flatten()
)

for i, uxpos in enumerate(xpos):
artist = ax.bar3d(
uxpos,
ypos[i],
zpos[i],
dx[i],
dy[i],
bar_M[i],
color=colors[i],
edgecolors=options["bars_edgecolor"],
linewidths=options["bars_lw"],
shade=options["shade"],
)
# Setting the z-order for rendering
artist._sort_zpos = z_order[i]
artist_list.append([artist])

if len(Ms) == 1:
output = ax
else:
output = animation.ArtistAnimation(fig, artist_list, interval=50,
blit=True, repeat_delay=1000)
output = animation.ArtistAnimation(
fig, artist_list, interval=50, blit=True, repeat_delay=1000
)

# remove vertical lines on xz and yz plane
ax.yaxis._axinfo["grid"]['linewidth'] = 0
ax.xaxis._axinfo["grid"]['linewidth'] = 0
ax.yaxis._axinfo["grid"]["linewidth"] = 0
ax.xaxis._axinfo["grid"]["linewidth"] = 0

# x axis
_update_xaxis(options['bars_spacing'], M, ax, x_basis)
_update_xaxis(options["bars_spacing"], M, ax, x_basis)

# y axis
_update_yaxis(options['bars_spacing'], M, ax, y_basis)
_update_yaxis(options["bars_spacing"], M, ax, y_basis)

# z axis
_update_zaxis(ax, z_min, z_max, options['zticks'])
_update_zaxis(ax, z_min, z_max, options["zticks"])

# stick to xz and yz plane
_stick_to_planes(options['stick'],
options['azim'], ax, M,
options['bars_spacing'])
ax.view_init(azim=options['azim'], elev=options['elev'])
_stick_to_planes(options["stick"], options["azim"], ax, M, options["bars_spacing"])

# removing margins
_remove_margins(ax.xaxis)
Expand All @@ -936,22 +1013,23 @@ def matrix_histogram(M, x_basis=None, y_basis=None, limits=None,

# color axis
if colorbar:
cax, kw = mpl.colorbar.make_axes(ax, shrink=.75,
pad=options['cbar_pad'])
cax, kw = mpl.colorbar.make_axes(
ax, shrink=0.75, pad=options["cbar_pad"])
cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm)

if color_style == 'real':
cb.set_label('real')
elif color_style == 'img':
cb.set_label('imaginary')
elif color_style == 'abs':
cb.set_label('absolute')
if color_style == "real":
cb.set_label("real")
elif color_style == "img":
cb.set_label("imaginary")
elif color_style == "abs":
cb.set_label("absolute")
else:
cb.set_label('arg')
cb.set_label("arg")
if color_limits is None:
cb.set_ticks([-pi, -pi / 2, 0, pi / 2, pi])
cb.set_ticklabels(
(r'$-\pi$', r'$-\pi/2$', r'$0$', r'$\pi/2$', r'$\pi$'))
(r"$-\pi$", r"$-\pi/2$", r"$0$", r"$\pi/2$", r"$\pi$")
)

return fig, output

Expand Down

0 comments on commit c874c4a

Please sign in to comment.