Add sample_prior function #2876
@@ -11,11 +11,12 @@
 from .backends.base import BaseTrace, MultiTrace
 from .backends.ndarray import NDArray
 from .distributions import draw_values
 from .model import modelcontext, Point, all_continuous
 from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
                            BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
                            Slice, CompoundStep, arraystep)
-from .util import update_start_vals
+from .util import update_start_vals, is_transformed_name, get_untransformed_name, get_default_varnames
 from .vartypes import discrete_types
 from pymc3.step_methods.hmc import quadpotential
 from pymc3 import plots
@@ -25,7 +26,7 @@
 import sys
 sys.setrecursionlimit(10000)

-__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts']
+__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts', 'sample_generative']

 STEP_METHODS = (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
                 BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis)
@@ -1206,6 +1207,50 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
     return {k: np.asarray(v) for k, v in ppc.items()}


+def sample_generative(samples=500, model=None, vars=None, random_seed=None):
+    """Generate samples from the generative model.
+
+    Parameters
+    ----------
+    samples : int
+        Number of samples from the prior to generate. Defaults to 500.
+    model : Model (optional if in `with` context)
+    vars : iterable
+        Variables for which to compute prior samples.
+        Defaults to `model.named_vars`.
+    random_seed : int
+        Seed for the random number generator.
+
+    Returns
+    -------
+    dict
+        Dictionary with the variables as keys. The values are arrays of
+        prior samples.
+    """
+    model = modelcontext(model)
+
+    if vars is None:
+        vars = set(model.named_vars.keys())
+
+    if random_seed is not None:
+        np.random.seed(random_seed)
+
+    names = get_default_varnames(model.named_vars, include_transformed=False)
+    # draw_values fails with auto-transformed variables; transform them later!
+    values = draw_values([model[name] for name in names], size=samples)
Review comment (on the `draw_values` call): Wow, this is really efficient! However, is it certain that the values drawn for the children of a graph depend on the samples drawn for their parents? In the previous implementation, we always sample by evaluating a …

A simple example:

    X = theano.shared(np.arange(3))
    with pm.Model() as m:
        ind = pm.Categorical('i', np.ones(3)/3)
        x = pm.Deterministic('X', X[ind])
        prior = pm.sample_generative(10)
    prior
    {'X': array([0, 0, 2, 1, 2, 2, 1, 0, 0, 1]),
     'i': array([1, 0, 0, 2, 2, 0, 2, 1, 0, 0])}

Since `X` is just `np.arange(3)` indexed by `ind`, each entry of 'X' should equal the corresponding entry of 'i'; in the output above they disagree, so the children are not being conditioned on their parents' draws.

Author reply: This is a super helpful example! Let me take a look at it -- there's some work already to avoid some edge cases, and I would have thought this got caught.

Author reply: Caught the bug (will add tests for all this, too). Your example runs as desired now! I make sure I evaluate the params by making a dictionary of index integers to nodes (avoids the non-hashability of …).
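A minimal standalone sketch of the idea in that last reply, with made-up names (`draw_all`, `draw_one`) rather than the PR's actual internals: keying previously drawn values by each node's integer position lets a child be evaluated against its parents' draws even though the node objects themselves (theano variables, numpy arrays) are not reliably hashable.

    import numpy as np

    # Hypothetical illustration, not the PR's code: dict keys are the integer
    # positions of nodes in the input list, because the nodes themselves
    # (numpy arrays, theano variables) may not be hashable.
    def draw_all(nodes, draw_one):
        drawn = {}  # index -> value already drawn for nodes[i]
        for i, node in enumerate(nodes):
            # Hand every earlier draw to the sampler so a child can be
            # evaluated conditional on its parents' sampled values.
            drawn[i] = draw_one(node, dict(drawn))
        return [drawn[i] for i in range(len(nodes))]

    # Toy usage: "drawing" a node just adds the values drawn before it.
    print(draw_all([np.arange(3), np.ones(3)],
                   lambda node, parents: node + sum(parents.values())))
    # [array([0, 1, 2]), array([1., 2., 3.])]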
+
+    data = {k: v for k, v in zip(names, values)}
+
+    prior = {}
+    for var_name in vars:
+        if var_name in data:
+            prior[var_name] = data[var_name]
+        elif is_transformed_name(var_name):
+            untransformed = get_untransformed_name(var_name)
+            if untransformed in data:
+                prior[var_name] = model[untransformed].transformation.forward(data[untransformed]).eval()
+    return prior
+
+
 def init_nuts(init='auto', chains=1, n_init=500000, model=None,
               random_seed=None, progressbar=True, **kwargs):
     """Set up the mass matrix initialization for NUTS.
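For context, a quick usage sketch of the new function, based on its docstring and the reviewer example above; the model here is invented for illustration.

    import pymc3 as pm

    # Toy model; sample_generative draws every variable from the prior
    # instead of running MCMC.
    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0., sd=1.)
        sd = pm.HalfNormal('sd', sd=1.)
        obs = pm.Normal('obs', mu=mu, sd=sd, shape=10)
        prior = pm.sample_generative(samples=200, random_seed=42)

    print(prior['mu'].shape)     # (200,)
    print(sorted(prior.keys()))  # also contains 'sd_log__', filled in via
                                 # transformation.forward(...).eval()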
Review comment (on the `draw_values` call): The `size` seems only to be imposed on `param`s that have a `random` method, and we are hoping the contents of `values` end up the right size. Shouldn't there be some enforcement of `size` for the `numbers.Number`, `np.ndarray`, `tt.TensorConstant`, `tt.sharedvar.SharedVariable`, and `tt.TensorVariable`-in-`point` cases, so we can be sure that `values` will in fact have the desired output size?

Author reply: I am relying here on `theano` catching those sorts of errors, and giving more informative errors than I could. I am running this on a few different models to make sure it gives reasonable results, but so far those sorts of inputs get broadcast in a sensible manner.
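For intuition on the broadcasting that reply leans on, a self-contained numpy sketch (no pymc3 or theano involved): inputs drawn as-is, with no `size` applied, still combine correctly with sized draws because broadcasting aligns the shapes.

    import numpy as np

    samples = 5
    mu = 2.0                              # scalar constant; no size enforced
    noise = np.random.randn(samples)      # shape (5,), from a `random` method
    print((mu + noise).shape)             # (5,): the constant broadcasts

    weights = np.array([1.0, 0.5, 0.25])  # fixed vector, shape (3,)
    draws = np.random.randn(samples, 3)   # shape (5, 3)
    print((weights * draws).shape)        # (5, 3): row-wise broadcast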