
Add sample_prior function #2876

Closed · wants to merge 9 commits

Conversation

@ColCarroll (Member)

This allows sampling from a model while ignoring all observed variables. See the screenshot below for an example with a simple model.

Right now it relies on behavior that is unofficial in Python 3.6 and official in Python 3.7: dictionaries preserving insertion order. I would love a suggestion for avoiding that requirement, but I can also take a swing at having tree_dict subclass OrderedDict instead.

(screenshot: prior samples drawn from a simple model)
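For reference, a minimal sketch of the kind of call the screenshot showed (the exact signature evolved over the course of this PR, so treat this as illustrative):

import numpy as np
import pymc3 as pm

observed = np.random.normal(0, 1, size=100)

with pm.Model():
    mu = pm.Normal('mu', mu=0, sd=1)
    x = pm.Normal('x', mu=mu, sd=1, observed=observed)
    # Draws from the priors; the observed= data is ignored entirely.
    prior = pm.sample_prior(samples=500)

# `prior` maps each variable name to an array of prior draws.
print(prior['mu'].shape)  # (500,)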

@springcoil (Contributor) commented Feb 26, 2018 via email

@ColCarroll (Member, Author)

Yep! That would be one use case. Or faster prototyping (for example, seeing if the generated data looks reasonable). We wanted to use something like this last week for generating a toy data set for a gerrymandering project.

@fonnesbeck (Member)

This paper shows a good example of where you might usefully use prior sampling.

@junpenglao (Member)

Ha, this is great!!! I was thinking exactly the same thing in another issue the other day: #2856 (comment)

for _ in indices:
    point = {}
    for var_name, var in model.named_vars.items():
        val = var.distribution.random(point=point, size=size)
Review comment (Member):

It is better to use the part from smc where it samples from the prior:
https://github.com/pymc-devs/pymc3/blob/801accb5f236ab9daa89a8fcd9d09a3ba4ed0a39/pymc3/step_methods/smc.py#L186-L193
Otherwise you will get an error with bounded RVs:
AttributeError: 'TransformedDistribution' object has no attribute 'random'
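A minimal repro of that failure mode (the transformed-variable name is whatever pymc3 auto-generates, e.g. sd_log__):

import pymc3 as pm

with pm.Model() as m:
    sd = pm.HalfNormal('sd', sd=1)  # bounded, so it is auto-log-transformed

# named_vars includes the transformed free RV, whose .distribution is a
# TransformedDistribution that has no .random method:
transformed = m.free_RVs[0]
print(type(transformed.distribution).__name__)  # TransformedDistribution
transformed.distribution.random()               # raises the AttributeError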

@springcoil (Contributor) commented Feb 27, 2018 via email

@ColCarroll (Member, Author)

@junpenglao That's a good sign that we even named the functions the same! You seemed to sketch out a pretty complete method in the comment on the other issue (along with some good edge cases for the test) - I'll hopefully update later today.

@junpenglao (Member)

I was hoping someone would pick it up ;-)

@junpenglao (Member)

We also need to add this to the release notes.

@ColCarroll (Member, Author)

@junpenglao I updated to sample correctly from transformed variables. I decided against (for now) using the trick from jitter to get samples from RVs that lack a random method. In the case of a DensityDist, it actually gives a pretty informative error message now.

I am a little confused because sampling from a transformed distribution is super slow: changing Normal to HalfNormal (which uses a log transform) in the example I posted above lowers the iterations/s from ~3,800 to 12.

@ColCarroll (Member, Author)

I have tried a few things without much luck to fix the speed problem. Tomorrow I might try something similar to what sampled does and just turn the ObservedRV into a FreeRV, since actual sampling is also a lot faster than sample_prior with transformed variables.

@junpenglao (Member) commented Feb 28, 2018

I decided against (for now) using the trick from jitter to get samples from RVs that lack a random method.

Agreed - that is more for the initialization. After this PR we can replace the jitter function currently used with sample_prior (with jitter etc. to handle corner cases).

Re: slowness of bounded variables, maybe try using var.distribution.transform_used.forward_val? The .eval() part is understandably quite slow.

[Edit]: using forward_val doesn't speed things up currently, but it potentially could if we rewrite it in numpy functions.
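Concretely, the suggestion amounts to a sketch like the following (assuming pymc3's TransformedDistribution exposes the wrapped dist and transform_used; this is untested sketch code):

def draw_transformed(var, point=None, size=None):
    trans_dist = var.distribution  # a TransformedDistribution
    # Sample in the original (constrained) space with the wrapped dist...
    constrained = trans_dist.dist.random(point=point, size=size)
    # ...then map to the unconstrained space with the numpy-side forward
    # pass, avoiding a theano .eval() per draw.
    return trans_dist.transform_used.forward_val(constrained, point=point)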

prior = {var: [] for var in vars}
for _ in indices:
    point = {}
    for var_name, var in model.named_vars.items():
Review comment (Member):

named_vars also contains Deterministics and Potentials, which don't have distribution or random.
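A minimal guard for this, as a sketch (Deterministics and Potentials would then have to be computed from the sampled point separately):

for var_name, var in model.named_vars.items():
    # Deterministics and Potentials are plain tensors without a
    # .distribution, so skip anything that cannot draw its own samples.
    if not hasattr(var, 'distribution') or not hasattr(var.distribution, 'random'):
        continue
    point[var_name] = var.distribution.random(point=point, size=size)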

RELEASE-NOTES.md Outdated
- Plots of discrete distributions in the docstrings
- Add logitnormal distribution
- New function `pm.sample_prior` which generates test data from a model in the absence of data.
Review comment (Member):

I would move this further up, as it's a major feature. Also, I think we should add the names of the authors who contributed each feature / bugfix.

@@ -1207,6 +1207,66 @@ def sample_ppc_w(traces, samples=None, models=None, weights=None,
return {k: np.asarray(v) for k, v in ppc.items()}


def sample_prior(samples=500, model=None, vars=None, size=None,
                 random_seed=None, progressbar=True):
    """Generate samples from the prior of a model.
Review comment by @twiecki (Member), Mar 1, 2018:

I would add a bit more description of why this is useful and when you would use it. It's really the prior predictive we're sampling from here.

Review comment (Contributor):

If it's helpful - I think one use case for this function is to generate a unique starting point for each chain, when multiple are required, like in this case: #2856
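A sketch of that use case, assuming sample_prior returns a dict of draws and relying on pm.sample's existing support for a list of start points (one per chain):

n_chains = 4
prior = pm.sample_prior(samples=n_chains, model=model)
# One distinct starting point per chain, drawn from the prior:
start = [{name: draws[i] for name, draws in prior.items()}
         for i in range(n_chains)]
trace = pm.sample(start=start, njobs=n_chains)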

@twiecki (Member) commented Mar 1, 2018

It would be great to add a NB with some motivation and example usage for the docs.

@ColCarroll (Member, Author)

Some of this seems trickier than I thought. I've tried a few methods that are almost clever and don't work.

My current favorite approach tries to clone the whole model, but I am not able to clone an ObservedRV into a FreeRV. In particular, it looks like the previously observed variable still expects its inputs from the old model. Here's a test case that fails, as well as the implementation I am describing:

import numpy as np
import pymc3 as pm

x = np.random.normal(0, 1, size=20) + true_μ  # true_μ defined elsewhere

with pm.Model() as model:
    μ = pm.Normal('μ', mu=-10, sd=1)
    shifted_μ = pm.Deterministic('shifted_μ', μ + 10)
    x_obs = pm.Normal('x_obs', mu=μ - 100, sd=1, observed=x)
    sample_prior(draws=1000)

Here is the current implementation of sample_prior:

import theano
import pymc3 as pm
from pymc3.model import modelcontext
from pymc3.util import get_default_varnames


class PriorModel(pm.Model):
    """This is required when using the context manager because otherwise
    there are duplicate variable names from the outside pm.Model.
    """
    def __new__(cls, *args, **kwargs):
        instance = super(PriorModel, cls).__new__(cls)
        instance._parent = None
        return instance


def sample_prior(**kwargs):
    """Generate samples from the prior of a model.

    All keyword arguments are passed to `pm.sample`, which should be
    consulted for documentation.

    Returns
    -------
    trace : pymc3.backends.base.MultiTrace
        A `MultiTrace` object that contains the samples.
    """
    model = modelcontext(kwargs.get('model'))

    with PriorModel() as prior_model:
        prior_model._parent = None
        for var_name in get_default_varnames(model.named_vars, include_transformed=False):
            var = model[var_name]
            # Rewire the cloned graph to point at the new model's variables.
            replace = {model.named_vars[name]: new_var
                       for name, new_var in prior_model.named_vars.items()}
            cloned = theano.clone(var, replace=replace)
            if var in model.observed_RVs:
                prior_model.Var(var_name, var.distribution)
            elif var in model.free_RVs:
                prior_model.Var(var_name, dist=cloned.distribution)
            elif var in model.deterministics:
                pm.Deterministic(var_name, cloned)
            elif var in model.potentials:
                pm.Potential(var_name, cloned)
            else:
                raise ValueError('Unexpected input {} to prior. Please report as bug!'.format(var_name))
        return pm.sample(**kwargs)

@junpenglao (Member)

Why doesn't forward-pass random (like what you did before) work, besides the slowness?

I would like to contribute a bit more to this issue, as efficient forward random sampling is quite important for the likelihood-free methods I would like to address. Could you share your experiments?

@ColCarroll (Member, Author)

Gosh, it is easy to forget how useful outside input can be sometimes. I am going to focus on that instead of the many hours I spent trying to get something else to work :D

It looks like forward pass continues to work, and I actually fixed the speed problem in a ninja edit last week.

@twiecki would you rather have an example NB along with this PR, or merge this to master to start working more bugs out?

@twiecki (Member) commented Mar 4, 2018

@ColCarroll rather with this one :). The API shouldn't change all that much.

@junpenglao (Member)

@ColCarroll did you push the new changes?

@ColCarroll (Member, Author) commented Mar 4, 2018

Yes - the major change is this line for deterministic variables:

        val = var.eval({model.named_vars[v]: point[v] for v in pm.model.get_named_nodes(var)})

(it is complicated because passing unused variables throws an error)
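To illustrate the complication: theano refuses inputs that are not part of the evaluated graph, which is why the comprehension filters down to the named nodes that actually feed var. A toy example:

import theano.tensor as tt

a = tt.dscalar('a')
b = tt.dscalar('b')
expr = a + 1.0

print(expr.eval({a: 1.0}))   # fine: 2.0
expr.eval({a: 1.0, b: 2.0})  # raises UnusedInputError: b is not in the graph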

@junpenglao (Member)

Nice!!! LGTM

@ColCarroll (Member, Author)

Failure looks like something related to float32. I will mark as xfail for now unless someone has a suggestion!

Working on a tiny case study notebook to use as well.

@springcoil (Contributor)

LGTM. I think it's fine to mark that with xfail, since we often have errors like that. Maybe add sample_prior to one of the other notebooks; that might be easier than writing your own notes.

@shkr (Contributor) commented Mar 8, 2018

Awesome update. The fact that we can now easily check that generated data looks reasonable will make development on pymc3 a lot easier.

@junpenglao (Member)

Maybe the test failure could be fixed by specifying the dtype, similar to #2891 (comment)?


assert (prior['mu'] < 0).all()
assert (prior['positive_mu'] > 0).all()
assert (prior['x_obs'] < 0).all()
Review comment (Member):

New line?

@ColCarroll (Member, Author)

Updated this to use #2902.

Huge thanks to @lucianopaz, as that code cleans this up a lot, and it looks quite tricky! You might take a look to make sure I did not mess anything up:

- Added a `size` argument to `draw_values`.
- Made it so that it ignores `observed` variables.

The first one I think is good, the second one might not be wanted elsewhere.

Note that now we just sample all the points we want from each node as we scan through, so it is quite fast, and no longer uses a progressbar since it is not iterative.

I have confirmed that I can sample from the Efron-Morris baseball generative model, and am going to work on turning that into an actual example notebook.
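Roughly, the vectorized version reduces to the sketch below (mirroring the snippet quoted in the review that follows; `samples` is the requested number of prior draws, and the imports reflect where these helpers live in pymc3 at the time of writing):

import numpy as np
from pymc3.distributions.distribution import draw_values
from pymc3.util import get_default_varnames

names = get_default_varnames(model.named_vars, include_transformed=False)
# One vectorized pass: each node is drawn `samples` times up front,
# instead of looping draw-by-draw with a progressbar.
values = draw_values([model[name] for name in names], size=samples)
prior = {name: np.asarray(value) for name, value in zip(names, values)}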


names = get_default_varnames(model.named_vars, include_transformed=False)
# draw_values fails with auto-transformed variables. transform them later!
values = draw_values([model[name] for name in names], size=samples)
Review comment (Member):

Wow, this is really efficient! However, are we sure that the values drawn for the children of a graph depend on the samples from their parents? In the previous implementation, we always sampled by evaluating a point that contains samples from higher in the hierarchy. For example, if b ~ p(a), we sample a_tilde first and then sample b ~ p(a_tilde). Is that the case here also?

Review comment (Member):

A simple example:

X = theano.shared(np.arange(3))
with pm.Model() as m:
    ind = pm.Categorical('i', np.ones(3)/3)
    x = pm.Deterministic('X', X[ind])
    prior=pm.sample_generative(10)

prior
{'X': array([0, 0, 2, 1, 2, 2, 1, 0, 0, 1]),
 'i': array([1, 0, 0, 2, 2, 0, 2, 1, 0, 0])}

i and X should be identical.

Review comment (Member, Author):

This is a super helpful example! Let me take a look at it -- there's some work already to avoid some edge cases, and I would have thought this got caught.

Review comment (Member, Author):

Caught the bug (will add tests for all this, too). Your example runs as desired now!

I make sure I evaluate the params by making a dictionary from index integers to nodes (this avoids the non-hashability of ndarrays). After evaluating the nodes, I was accidentally using the index integer to check whether it was a child of another node. This was never true, so I never supplied that value to the rest of the graph.
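A toy version of that bug, for concreteness (names hypothetical):

params = {0: 'mu', 1: 'x'}            # index -> node, because nodes
named_nodes_children = {'mu': {'x'}}  # (ndarrays) are not hashable

idx = 1
node = params[idx]
# Buggy check: the integer index is never in any children set.
print(any(idx in c for c in named_nodes_children.values()))   # always False
# Fixed check: test membership with the node itself.
print(any(node in c for c in named_nodes_children.values()))  # True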

if size is None:
    return func(*values)
else:
    return np.array([func(*value) for value in zip(*values)])
Review comment (Contributor):

The size seems to only be imposed on params with a random method, and we hope the contents of values end up the right size. Shouldn't there be some enforcement of the size for the numbers.Number, np.ndarray, tt.TensorConstant, tt.sharedvar.SharedVariable, and tt.TensorVariable-in-point cases, so we can be sure that values will in fact have the desired output size?

Review comment (Member, Author):

I am relying here on theano catching those sorts of errors, and giving more informative errors than I could. I am running this on a few different models to make sure it gives reasonable results, but so far those sorts of inputs get broadcast in a sensible manner.

to_eval = set(missing_inputs)
missing_inputs = set()
for param in to_eval:
    try:  # might evaluate in a bad order,
        evaluated[param] = _draw_value(params[param], point=point, givens=givens.values(), size=size)
        if any(param in j for j in named_nodes_children.values()):
            givens[param.name] = (params[param], evaluated[param])
        if any(params[param] in j for j in named_nodes_children.values()):
Review comment by @lucianopaz (Contributor), Mar 26, 2018:

Oops, I actually commented on an older commit so it shows as outdated, sorry for the mess. First off, it looks very nice. However, I think this line is confusing: you're trying to see if the node params[param] is a child of some other named node. If params[param] is a named node, that information should be available in the named_nodes_parents dictionary. If params[param] is not a named node, then it would be registered in neither the named_nodes_parents nor the named_nodes_children dictionaries.
If params[param] is a named node, you should be able to replace this line with:
if named_nodes_parents[params[param]]:
If params[param] is not a named node, then I think it shouldn't be added to givens, but I may be overlooking something.

Review comment (Member, Author):

This is much nicer, thank you!

Review comment (Member, Author):

(Also, I just checked Travis, and your suggestion will also fix the failing tests.)

Review by @junpenglao (Member):

LGTM. Are you adding more tests or is this ready?

@ColCarroll (Member, Author)

Not yet - I am now looking at

import pandas as pd
import pymc3 as pm
import theano.tensor as tt

data = pd.read_table(pm.get_data('efron-morris-75-data.tsv'), sep="\t")
at_bats, hits = data[['At-Bats', 'Hits']].values.T
N = len(hits)

with pm.Model() as baseball_model:
    phi = pm.Beta('phi', alpha=1., beta=1.)

    kappa_log = pm.Exponential('kappa_log', lam=1.5)
    kappa = pm.Deterministic('kappa', tt.exp(kappa_log))

    thetas = pm.Beta('thetas', alpha=phi*kappa, beta=(1.0-phi)*kappa, shape=N)

    y = pm.Binomial('y', n=at_bats, p=thetas, shape=N, observed=hits)
    p = pm.sample_generative()

In particular, p['thetas'].mean(axis=0) * at_bats should be similar to p['y'].mean(axis=0), but I am getting strange outputs like

# thetas * at_bats mean
array([13.1510721 , 12.89432978, 12.87695855, 12.99550508, 12.86672716,
       11.95480332, 12.61286473, 12.68999391, 13.00091407, 13.2292541 ,
       11.99719624, 12.74602259, 11.71467232, 13.29653437, 13.03132627,
       12.04135973, 12.28721424, 12.4750911 ])
# y mean
array([25.346, 38.38 , 37.548, 38.654,  0.86 , 44.902, 44.716, 45.   ,
        7.176, 44.998, 44.872, 44.996, 10.498, 45.   , 39.23 , 45.   ,
       45.   , 33.808])

I would guess there is still something funny going on with passing nodes appropriately.

@junpenglao (Member) commented Mar 27, 2018

This is a difficult model to generate from. But yeah, there seems to be some problem with the last RV, y; is it possible it is related to the shape?

@junpenglao junpenglao modified the milestones: 3.4, 3.5 Apr 7, 2018
@junpenglao (Member)

Since this is currently blocked by #2909, I suggest we roll back to the original implementation with (slower) forward passing. I have a version that works fairly well and could serve as a baseline implementation: https://github.com/junpenglao/Planet_Sakaar_Data_Science/blob/master/Miscellaneous/Test_sample_prior.ipynb

@springcoil (Contributor)

I'm confused - is this working or not? :)

@ColCarroll (Member, Author)

Not working yet! But done with PyCon, so back to implementing features, I hope. The intention is to first update draw_values to address #2909 (which #2946 makes a nice step towards doing), and then I think this will work as is (once merge conflicts are resolved).

@springcoil (Contributor)

Closing this based on the newer PR
