-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Densityplot: add support for discrete variables #2878
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,8 +12,9 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m | |
colors='cycle', outline=True, hpd_markers='', shade=0., bw=4.5, figsize=None, | ||
textsize=12, plot_transformed=False, ax=None): | ||
""" | ||
Generates KDE plots truncated at their 100*(1-alpha)% credible intervals from a trace or list of | ||
traces. KDE plots are grouped per variable and colors assigned to models. | ||
Generates KDE plots for continuous variables and histograms for discretes ones. | ||
Plots are truncated at their 100*(1-alpha)% credible intervals. Plots are grouped | ||
per variable and colors assigned to models. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -32,11 +33,11 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m | |
Defaults to 'mean'. | ||
colors : list or string, optional | ||
List with valid matplotlib colors, one color per model. Alternative a string can be passed. | ||
If the string is `cycle `, it will automatically choose a color per model from matplolib's | ||
If the string is `cycle`, it will automatically choose a color per model from matplolib's | ||
cycle. If a single color is passed, e.g. 'k', 'C2' or 'red' this color will be used for all | ||
models. Defaults to 'C0' (blueish in most matplotlib styles) | ||
models. Defaults to `cycle`. | ||
outline : boolean | ||
Use a line to draw the truncated KDE and. Defaults to True | ||
Use a line to draw KDEs and histograms. Default to True | ||
hpd_markers : str | ||
A valid `matplotlib.markers` like 'v', used to indicate the limits of the hpd interval. | ||
Defaults to empty string (no marker). | ||
|
@@ -64,7 +65,7 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m | |
|
||
""" | ||
if point_estimate not in ('mean', 'median', None): | ||
raise ValueError("Point estimate should be 'mean' or 'median'") | ||
raise ValueError("Point Estimate should be 'mean', 'median' or None") | ||
|
||
if not isinstance(trace, (list, tuple)): | ||
trace = [trace] | ||
|
@@ -77,7 +78,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m | |
else: | ||
models = [''] | ||
elif len(models) != lenght_trace: | ||
raise ValueError("The number of names for the models does not match the number of models") | ||
raise ValueError( | ||
"The number of names for the models does not match the number of models") | ||
|
||
lenght_models = len(models) | ||
|
||
|
@@ -97,8 +99,8 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m | |
if figsize is None: | ||
figsize = (6, len(varnames) * 2) | ||
|
||
fig, kplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize) | ||
kplot = kplot.flatten() | ||
fig, dplot = plt.subplots(len(varnames), 1, squeeze=False, figsize=figsize) | ||
dplot = dplot.flatten() | ||
|
||
for v_idx, vname in enumerate(varnames): | ||
for t_idx, tr in enumerate(trace): | ||
|
@@ -108,23 +110,24 @@ def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='m | |
if k > 1: | ||
vec = np.split(vec.T.ravel(), k) | ||
for i in range(k): | ||
_kde_helper(vec[i], vname, colors[t_idx], bw, alpha, point_estimate, | ||
hpd_markers, outline, shade, kplot[v_idx]) | ||
_d_helper(vec[i], vname, colors[t_idx], bw, alpha, point_estimate, | ||
hpd_markers, outline, shade, dplot[v_idx]) | ||
|
||
else: | ||
_kde_helper(vec, vname, colors[t_idx], bw, alpha, point_estimate, | ||
hpd_markers, outline, shade, kplot[v_idx]) | ||
_d_helper(vec, vname, colors[t_idx], bw, alpha, point_estimate, | ||
hpd_markers, outline, shade, dplot[v_idx]) | ||
|
||
if lenght_trace > 1: | ||
for m_idx, m in enumerate(models): | ||
kplot[0].plot([], label=m, c=colors[m_idx]) | ||
kplot[0].legend(fontsize=textsize) | ||
dplot[0].plot([], label=m, c=colors[m_idx]) | ||
dplot[0].legend(fontsize=textsize) | ||
|
||
fig.tight_layout() | ||
|
||
return kplot | ||
return dplot | ||
|
||
|
||
def _kde_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline, shade, ax): | ||
def _d_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline, shade, ax): | ||
""" | ||
vec : array | ||
1D array from trace | ||
|
@@ -145,34 +148,43 @@ def _kde_helper(vec, vname, c, bw, alpha, point_estimate, hpd_markers, outline, | |
(opaque). Defaults to 0. | ||
ax : matplotlib axes | ||
""" | ||
density, l, u = fast_kde(vec, bw) | ||
x = np.linspace(l, u, len(density)) | ||
hpd_ = hpd(vec, alpha) | ||
cut = (x >= hpd_[0]) & (x <= hpd_[1]) | ||
|
||
xmin = x[cut][0] | ||
xmax = x[cut][-1] | ||
ymin = density[cut][0] | ||
ymax = density[cut][-1] | ||
|
||
if outline: | ||
ax.plot(x[cut], density[cut], color=c) | ||
ax.plot([xmin, xmin], [-0.5, ymin], color=c, ls='-') | ||
ax.plot([xmax, xmax], [-0.5, ymax], color=c, ls='-') | ||
if vec.dtype.kind == 'f': | ||
density, l, u = fast_kde(vec) | ||
x = np.linspace(l, u, len(density)) | ||
hpd_ = hpd(vec, alpha) | ||
cut = (x >= hpd_[0]) & (x <= hpd_[1]) | ||
|
||
xmin = x[cut][0] | ||
xmax = x[cut][-1] | ||
ymin = density[cut][0] | ||
ymax = density[cut][-1] | ||
|
||
if outline: | ||
ax.plot(x[cut], density[cut], color=c) | ||
ax.plot([xmin, xmin], [-ymin/100, ymin], color=c, ls='-') | ||
ax.plot([xmax, xmax], [-ymax/100, ymax], color=c, ls='-') | ||
|
||
if shade: | ||
ax.fill_between(x, density, where=cut, color=c, alpha=shade) | ||
|
||
else: | ||
xmin, xmax = hpd(vec, alpha) | ||
bins = range(xmin, xmax+1) | ||
if outline: | ||
ax.hist(vec, bins=bins, color=c, histtype='step') | ||
ax.hist(vec, bins=bins, color=c, alpha=shade) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That will only change the scale of y. But that scale is not plotted, the yticks are set to an empty list. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Fine with me. |
||
ax.set_xticks(bins) | ||
|
||
if hpd_markers: | ||
ax.plot(xmin, 0, 'v', color=c, markeredgecolor='k') | ||
ax.plot(xmax, 0, 'v', color=c, markeredgecolor='k') | ||
|
||
if shade: | ||
ax.fill_between(x, density, where=cut, color=c, alpha=shade) | ||
|
||
if point_estimate is not None: | ||
if point_estimate == 'mean': | ||
ps = np.mean(vec) | ||
if point_estimate == 'median': | ||
elif point_estimate == 'median': | ||
ps = np.median(vec) | ||
ax.plot(ps, 0, 'o', color=c, markeredgecolor='k') | ||
ax.plot(ps, -0.001, 'o', color=c, markeredgecolor='k') | ||
|
||
ax.set_yticks([]) | ||
ax.set_title(vname) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why
Estimate
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just a typo