diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index cb27ab8538e..78792e375e2 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -11,8 +11,7 @@ Gaussian Process implementation for efficient inference, along with lower-level functions such as `cartesian` and `kronecker` products. - Added `Coregion` covariance function. -- Add new 'pairplot' function, for plotting scatter or hexbin matrices of sampled parameters. - Optionally it can plot divergences. +- Add new 'pairplot' function, for plotting scatter or hexbin matrices of sampled parameters. Optionally it can plot divergences. - Plots of discrete distributions in the docstrings - Add logitnormal distribution - Densityplot: add support for discrete variables @@ -21,6 +20,7 @@ - Changed the `compare` function to accept a dictionary of model-trace pairs instead of two separate lists of models and traces. - add test and support for creating multivariate mixture and mixture of mixtures - `distribution.draw_values`, now is also able to draw values from conditionally dependent RVs, such as autotransformed RVs (Refer to PR #2902). +- New function `pm.sample_prior` which generates test data from a model in the absence of data. ### Fixes diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 8139c8dd464..68e70cf0172 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -1,3 +1,4 @@ +import collections import numbers import numpy as np import theano.tensor as tt @@ -210,7 +211,7 @@ def random(self, *args, **kwargs): -def draw_values(params, point=None): +def draw_values(params, point=None, size=None): """ Draw (fix) parameter values. Handles a number of cases: @@ -254,7 +255,7 @@ def draw_values(params, point=None): # Init givens and the stack of nodes to try to `_draw_value` from givens = {} - stored = set([]) # Some nodes + stored = set() # Some nodes stack = list(leaf_nodes.values()) # A queue would be more appropriate while stack: next_ = stack.pop(0) @@ -279,13 +280,14 @@ def draw_values(params, point=None): # The named node's children givens values must also be taken # into account. children = named_nodes_children[next_] - temp_givens = [givens[k] for k in givens.keys() if k in children] + temp_givens = [givens[k] for k in givens if k in children] try: # This may fail for autotransformed RVs, which don't # have the random method givens[next_.name] = (next_, _draw_value(next_, point=point, - givens=temp_givens)) + givens=temp_givens, + size=size)) stored.add(next_.name) except theano.gof.fg.MissingInputError: # The node failed, so we must add the node's parents to @@ -295,10 +297,28 @@ def draw_values(params, point=None): if node is not None and node.name not in stored and node not in params]) - values = [] - for param in params: - values.append(_draw_value(param, point=point, givens=givens.values())) - return values + + # Funny dance to resolve missing nodes + params = dict(enumerate(params)) # some nodes are not hashable + evaluated = {} + to_eval = set() + missing_inputs = set(params) + while to_eval or missing_inputs: + if to_eval == missing_inputs: + raise ValueError('Cannot resolve inputs for {}'.format([str(params[j]) for j in to_eval])) + to_eval = set(missing_inputs) + missing_inputs = set() + for param_idx in to_eval: + param = params[param_idx] + try: # might evaluate in a bad order, + evaluated[param_idx] = _draw_value(param, point=point, givens=givens.values(), size=size) + if isinstance(param, collections.Hashable) and named_nodes_parents.get(param): + givens[param.name] = (param, evaluated[param_idx]) + except theano.gof.fg.MissingInputError: + missing_inputs.add(param_idx) + + + return [evaluated[j] for j in params] # set the order back @memoize @@ -326,7 +346,7 @@ def _compile_theano_function(param, vars, givens=None): allow_input_downcast=True) -def _draw_value(param, point=None, givens=None): +def _draw_value(param, point=None, givens=None, size=None): """Draw a random value from a distribution or return a constant. Parameters @@ -355,14 +375,19 @@ def _draw_value(param, point=None, givens=None): if point and hasattr(param, 'model') and param.name in point: return point[param.name] elif hasattr(param, 'random') and param.random is not None: - return param.random(point=point, size=None) + return param.random(point=point, size=size) + elif hasattr(param, 'distribution') and hasattr(param.distribution, 'random') and param.distribution.random is not None: + return param.distribution.random(point=point, size=size) else: if givens: variables, values = list(zip(*givens)) else: variables = values = [] func = _compile_theano_function(param, variables) - return func(*values) + if size is None: + return func(*values) + else: + return np.array([func(*value) for value in zip(*values)]) else: raise ValueError('Unexpected type in draw_value: %s' % type(param)) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 067c17a8902..38bf66593cb 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -11,11 +11,12 @@ from .backends.base import BaseTrace, MultiTrace from .backends.ndarray import NDArray +from .distributions import draw_values from .model import modelcontext, Point, all_continuous from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, BinaryGibbsMetropolis, CategoricalGibbsMetropolis, Slice, CompoundStep, arraystep) -from .util import update_start_vals +from .util import update_start_vals, is_transformed_name, get_untransformed_name, get_default_varnames from .vartypes import discrete_types from pymc3.step_methods.hmc import quadpotential from pymc3 import plots @@ -25,7 +26,7 @@ import sys sys.setrecursionlimit(10000) -__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts'] +__all__ = ['sample', 'iter_sample', 'sample_ppc', 'sample_ppc_w', 'init_nuts', 'sample_generative'] STEP_METHODS = (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis) @@ -1206,6 +1207,50 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None, return {k: np.asarray(v) for k, v in ppc.items()} +def sample_generative(samples=500, model=None, vars=None, random_seed=None): + """Generate samples from the generative model. + + Parameters + ---------- + samples : int + Number of samples from the prior to generate. Defaults to 500. + model : Model (optional if in `with` context) + vars : iterable + Variables for which to compute the posterior predictive samples. + Defaults to `model.named_vars`. + random_seed : int + Seed for the random number generator. + + Returns + ------- + dict + Dictionary with the variables as keys. The values are arrays of prior samples. + """ + model = modelcontext(model) + + if vars is None: + vars = set(model.named_vars.keys()) + + if random_seed is not None: + np.random.seed(random_seed) + + names = get_default_varnames(model.named_vars, include_transformed=False) + # draw_values fails with auto-transformed variables. transform them later! + values = draw_values([model[name] for name in names], size=samples) + + data = {k: v for k, v in zip(names, values)} + + prior = {} + for var_name in vars: + if var_name in data: + prior[var_name] = data[var_name] + elif is_transformed_name(var_name): + untransformed = get_untransformed_name(var_name) + if untransformed in data: + prior[var_name] = model[untransformed].transformation.forward(data[untransformed]).eval() + return prior + + def init_nuts(init='auto', chains=1, n_init=500000, model=None, random_seed=None, progressbar=True, **kwargs): """Set up the mass matrix initialization for NUTS. diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 30362019c0d..6b4d6392468 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -302,3 +302,17 @@ def test_exec_nuts_init(method): assert len(start) == 2 assert isinstance(start[0], dict) assert 'a' in start[0] and 'b_log__' in start[0] + + +def test_sample_generative(): + observed = np.random.normal(10, 1, size=200) + with pm.Model(): + # Use a prior that's way off to show we're actually sampling from it + mu = pm.Normal('mu', mu=-10, sd=1) + positive_mu = pm.Deterministic('positive_mu', np.abs(mu)) + pm.Normal('x_obs', mu=mu, sd=1, observed=observed) + prior = pm.sample_generative() + + assert (prior['mu'] < 0).all() + assert (prior['positive_mu'] > 0).all() + assert (prior['x_obs'] < 0).all() \ No newline at end of file