Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Densityplot: add support for discrete variables #2878

Merged
merged 3 commits into from
Mar 2, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 48 additions & 36 deletions pymc3/plots/densityplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
colors='cycle', outline=True, hpd_markers='', shade=0., bw=4.5, figsize=None,
textsize=12, plot_transformed=False, ax=None):
"""
Generates KDE plots truncated at their 100*(1-alpha)% credible intervals from a trace or list of
traces. KDE plots are grouped per variable and colors assigned to models.
Generates KDE plots for continuous variables and histograms for discrete ones.
Plots are truncated at their 100*(1-alpha)% credible intervals. Plots are grouped
per variable and colors assigned to models.

Parameters
----------
Expand All @@ -32,11 +33,11 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
Defaults to 'mean'.
colors : list or string, optional
List with valid matplotlib colors, one color per model. Alternatively, a string can be passed.
If the string is `cycle `, it will automatically choose a color per model from matplolib's
If the string is `cycle`, it will automatically choose a color per model from matplotlib's
cycle. If a single color is passed, e.g. 'k', 'C2' or 'red' this color will be used for all
models. Defaults to 'C0' (blueish in most matplotlib styles)
models. Defaults to `cycle`.
outline : boolean
Use a line to draw the truncated KDE and. Defaults to True
Use a line to draw KDEs and histograms. Defaults to True.
hpd_markers : str
A valid `matplotlib.markers` like 'v', used to indicate the limits of the hpd interval.
Defaults to empty string (no marker).
Expand Down Expand Up @@ -64,7 +65,7 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m

"""
if point_estimate not in ('mean', 'median', None):
raise ValueError("Point estimate should be 'mean' or 'median'")
raise ValueError("Point Estimate should be 'mean', 'median' or None")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why Estimate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a typo


if not isinstance(trace, (list, tuple)):
trace = [trace]
Expand All @@ -77,7 +78,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
else:
models = ['']
elif len(models) != lenght_trace:
raise ValueError("The number of names for the models does not match the number of models")
raise ValueError(
"The number of names for the models does not match the number of models")

lenght_models = len(models)

Expand All @@ -97,8 +99,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
if figsize is None:
figsize = (6, len(varnames) * 2)

fig, kplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize)
kplot = kplot.flatten()
fig, dplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize)
dplot = dplot.flatten()

for v_idx, vname in enumerate(varnames):
for t_idx, tr in enumerate(trace):
Expand All @@ -108,23 +110,24 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m
if k > 1:
vec = np.split(vec.T.ravel(), k)
for i in range(k):
_kde_helper(vec[i], vname, colors[t_idx], bw, alpha, point_estimate,
hpd_markers, outline, shade, kplot[v_idx])
_d_helper(vec[i], vname, colors[t_idx], bw, alpha, point_estimate,
hpd_markers, outline, shade, dplot[v_idx])

else:
_kde_helper(vec, vname, colors[t_idx], bw, alpha, point_estimate,
hpd_markers, outline, shade, kplot[v_idx])
_d_helper(vec, vname, colors[t_idx], bw, alpha, point_estimate,
hpd_markers, outline, shade, dplot[v_idx])

if lenght_trace > 1:
for m_idx, m in enumerate(models):
kplot[0].plot([], label=m, c=colors[m_idx])
kplot[0].legend(fontsize=textsize)
dplot[0].plot([], label=m, c=colors[m_idx])
dplot[0].legend(fontsize=textsize)

fig.tight_layout()

return kplot
return dplot


def _kde_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline, shade, ax):
def _d_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline, shade, ax):
"""
vec : array
1D array from trace
Expand All @@ -145,34 +148,43 @@ def _kde_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline,
(opaque). Defaults to 0.
ax : matplotlib axes
"""
density, l, u = fast_kde(vec, bw)
x = np.linspace(l, u, len(density))
hpd_ = hpd(vec, alpha)
cut = (x >= hpd_[0]) & (x <= hpd_[1])

xmin = x[cut][0]
xmax = x[cut][-1]
ymin = density[cut][0]
ymax = density[cut][-1]

if outline:
ax.plot(x[cut], density[cut], color=c)
ax.plot([xmin, xmin], [-0.5, ymin], color=c, ls='-')
ax.plot([xmax, xmax], [-0.5, ymax], color=c, ls='-')
if vec.dtype.kind == 'f':
density, l, u = fast_kde(vec)
x = np.linspace(l, u, len(density))
hpd_ = hpd(vec, alpha)
cut = (x >= hpd_[0]) & (x <= hpd_[1])

xmin = x[cut][0]
xmax = x[cut][-1]
ymin = density[cut][0]
ymax = density[cut][-1]

if outline:
ax.plot(x[cut], density[cut], color=c)
ax.plot([xmin, xmin], [-ymin/100, ymin], color=c, ls='-')
ax.plot([xmax, xmax], [-ymax/100, ymax], color=c, ls='-')

if shade:
ax.fill_between(x, density, where=cut, color=c, alpha=shade)

else:
xmin, xmax = hpd(vec, alpha)
bins = range(xmin, xmax+1)
if outline:
ax.hist(vec, bins=bins, color=c, histtype='step')
ax.hist(vec, bins=bins, color=c, alpha=shade)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to set `density=True` instead of leaving it at the default `None`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will only change the scale of y. But that scale is not plotted, the yticks are set to an empty list.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Fine with me.

ax.set_xticks(bins)

if hpd_markers:
ax.plot(xmin, 0, 'v', color=c, markeredgecolor='k')
ax.plot(xmax, 0, 'v', color=c, markeredgecolor='k')

if shade:
ax.fill_between(x, density, where=cut, color=c, alpha=shade)

if point_estimate is not None:
if point_estimate == 'mean':
ps = np.mean(vec)
if point_estimate == 'median':
elif point_estimate == 'median':
ps = np.median(vec)
ax.plot(ps, 0, 'o', color=c, markeredgecolor='k')
ax.plot(ps, -0.001, 'o', color=c, markeredgecolor='k')

ax.set_yticks([])
ax.set_title(vname)
Expand Down