-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add sample_prior function #2876
Changes from 8 commits
e99207d
6e75ed6
99f71a3
8c0d9b3
ea3026b
16c0b3a
9a80d0d
b45d6e7
4541d16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -210,7 +210,7 @@ def random(self, *args, **kwargs): | |
|
||
|
||
|
||
def draw_values(params, point=None): | ||
def draw_values(params, point=None, size=None): | ||
""" | ||
Draw (fix) parameter values. Handles a number of cases: | ||
|
||
|
@@ -254,7 +254,7 @@ def draw_values(params, point=None): | |
|
||
# Init givens and the stack of nodes to try to `_draw_value` from | ||
givens = {} | ||
stored = set([]) # Some nodes | ||
stored = set() # Some nodes | ||
stack = list(leaf_nodes.values()) # A queue would be more appropriate | ||
while stack: | ||
next_ = stack.pop(0) | ||
|
@@ -279,13 +279,14 @@ def draw_values(params, point=None): | |
# The named node's children givens values must also be taken | ||
# into account. | ||
children = named_nodes_children[next_] | ||
temp_givens = [givens[k] for k in givens.keys() if k in children] | ||
temp_givens = [givens[k] for k in givens if k in children] | ||
try: | ||
# This may fail for autotransformed RVs, which don't | ||
# have the random method | ||
givens[next_.name] = (next_, _draw_value(next_, | ||
point=point, | ||
givens=temp_givens)) | ||
givens=temp_givens, | ||
size=size)) | ||
stored.add(next_.name) | ||
except theano.gof.fg.MissingInputError: | ||
# The node failed, so we must add the node's parents to | ||
|
@@ -295,10 +296,27 @@ def draw_values(params, point=None): | |
if node is not None and | ||
node.name not in stored and | ||
node not in params]) | ||
values = [] | ||
for param in params: | ||
values.append(_draw_value(param, point=point, givens=givens.values())) | ||
return values | ||
|
||
# Funny dance to resolve missing nodes | ||
params = dict(enumerate(params)) # some nodes are not hashable | ||
evaluated = {} | ||
to_eval = set() | ||
missing_inputs = set(params) | ||
while to_eval or missing_inputs: | ||
if to_eval == missing_inputs: | ||
raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval])) | ||
to_eval = set(missing_inputs) | ||
missing_inputs = set() | ||
for param in to_eval: | ||
try: # might evaluate in a bad order, | ||
evaluated[param] = _draw_value(params[param], point=point, givens=givens.values(), size=size) | ||
if any(params[param] in j for j in named_nodes_children.values()): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oops I actually commented on an older commit so it shows as outdated, sorry for the mess. First off, it looks very nice, however I think this line is confusing, You're trying to see if the node There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is much nicer, thank you! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (also just checked out travis, and your suggestion will also fix failing tests) |
||
givens[params[param].name] = (params[param], evaluated[param]) | ||
except theano.gof.fg.MissingInputError: | ||
missing_inputs.add(param) | ||
|
||
|
||
return [evaluated[j] for j in params] # set the order back | ||
|
||
|
||
@memoize | ||
|
@@ -326,7 +344,7 @@ def _compile_theano_function(param, vars, givens=None): | |
allow_input_downcast=True) | ||
|
||
|
||
def _draw_value(param, point=None, givens=None): | ||
def _draw_value(param, point=None, givens=None, size=None): | ||
"""Draw a random value from a distribution or return a constant. | ||
|
||
Parameters | ||
|
@@ -355,14 +373,19 @@ def _draw_value(param, point=None, givens=None): | |
if point and hasattr(param, 'model') and param.name in point: | ||
return point[param.name] | ||
elif hasattr(param, 'random') and param.random is not None: | ||
return param.random(point=point, size=None) | ||
return param.random(point=point, size=size) | ||
elif hasattr(param, 'distribution') and hasattr(param.distribution, 'random') and param.distribution.random is not None: | ||
return param.distribution.random(point=point, size=size) | ||
else: | ||
if givens: | ||
variables, values = list(zip(*givens)) | ||
else: | ||
variables = values = [] | ||
func = _compile_theano_function(param, variables) | ||
return func(*values) | ||
if size is None: | ||
return func(*values) | ||
else: | ||
return np.array([func(*value) for value in zip(*values)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am relying here on |
||
else: | ||
raise ValueError('Unexpected type in draw_value: %s' % type(param)) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,11 +11,12 @@ | |
|
||
from .backends.base import BaseTrace, MultiTrace | ||
from .backends.ndarray import NDArray | ||
from .distributions import draw_values | ||
from .model import modelcontext, Point, all_continuous | ||
from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, | ||
BinaryGibbsMetropolis, CategoricalGibbsMetropolis, | ||
Slice, CompoundStep, arraystep) | ||
from .util import update_start_vals | ||
from .util import update_start_vals, is_transformed_name, get_untransformed_name, get_default_varnames | ||
from .vartypes import discrete_types | ||
from pymc3.step_methods.hmc import quadpotential | ||
from pymc3 import plots | ||
|
@@ -25,7 +26,7 @@ | |
import sys | ||
sys.setrecursionlimit(10000) | ||
|
||
__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts'] | ||
__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts', 'sample_generative'] | ||
|
||
STEP_METHODS = (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, | ||
BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis) | ||
|
@@ -1206,6 +1207,50 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None, | |
return {k: np.asarray(v) for k, v in ppc.items()} | ||
|
||
|
||
def sample_generative(samples=500, model=None, vars=None, random_seed=None): | ||
"""Generate samples from the generative model. | ||
|
||
Parameters | ||
---------- | ||
samples : int | ||
Number of samples from the prior to generate. Defaults to 500. | ||
model : Model (optional if in `with` context) | ||
vars : iterable | ||
Variables for which to compute the posterior predictive samples. | ||
Defaults to `model.named_vars`. | ||
random_seed : int | ||
Seed for the random number generator. | ||
|
||
Returns | ||
------- | ||
dict | ||
Dictionary with the variables as keys. The values are arrays of prior samples. | ||
""" | ||
model = modelcontext(model) | ||
|
||
if vars is None: | ||
vars = set(model.named_vars.keys()) | ||
|
||
if random_seed is not None: | ||
np.random.seed(random_seed) | ||
|
||
names = get_default_varnames(model.named_vars, include_transformed=False) | ||
# draw_values fails with auto-transformed variables. transform them later! | ||
values = draw_values([model[name] for name in names], size=samples) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wow this is really efficient! However, is it sure that the values draw in the children of a graph is depending on the samples from their parent? In the previous implementation, we always sample by evaluating a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A simple example: X = theano.shared(np.arange(3))
with pm.Model() as m:
ind = pm.Categorical('i', np.ones(3)/3)
x = pm.Deterministic('X', X[ind])
prior=pm.sample_generative(10)
prior
{'X': array([0, 0, 2, 1, 2, 2, 1, 0, 0, 1]),
'i': array([1, 0, 0, 2, 2, 0, 2, 1, 0, 0])}
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a super helpful example! Let me take a look at it -- there's some work already to avoid some edge cases, and I would have thought this got caught. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Caught the bug (will add tests for all this, too). Your example runs as desired now! I make sure I evaluate the params by making a dictionary of index integers to nodes (avoids non-hashability of |
||
|
||
data = {k: v for k, v in zip(names, values)} | ||
|
||
prior = {} | ||
for var_name in vars: | ||
if var_name in data: | ||
prior[var_name] = data[var_name] | ||
elif is_transformed_name(var_name): | ||
untransformed = get_untransformed_name(var_name) | ||
if untransformed in data: | ||
prior[var_name] = model[untransformed].transformation.forward(data[untransformed]).eval() | ||
return prior | ||
|
||
|
||
def init_nuts(init='auto', chains=1, n_init=500000, model=None, | ||
random_seed=None, progressbar=True, **kwargs): | ||
"""Set up the mass matrix initialization for NUTS. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -302,3 +302,17 @@ def test_exec_nuts_init(method): | |
assert len(start) == 2 | ||
assert isinstance(start[0], dict) | ||
assert 'a' in start[0] and 'b_log__' in start[0] | ||
|
||
|
||
def test_sample_generative(): | ||
observed = np.random.normal(10, 1, size=200) | ||
with pm.Model(): | ||
# Use a prior that's way off to show we're actually sampling from it | ||
mu = pm.Normal('mu', mu=-10, sd=1) | ||
positive_mu = pm.Deterministic('positive_mu', np.abs(mu)) | ||
pm.Normal('x_obs', mu=mu, sd=1, observed=observed) | ||
prior = pm.sample_generative() | ||
|
||
assert (prior['mu'] < 0).all() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pm.floatX here? |
||
assert (prior['positive_mu'] > 0).all() | ||
assert (prior['x_obs'] < 0).all() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. New line? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would move this further up as it's a major feature. Also, I think we should add author names to who contributed to feature / bugfix.