Skip to content

Commit

Permalink
fix PyMC3 variable is not replaced if provided in more_replacements (…
Browse files Browse the repository at this point in the history
…VI) (#2891)

* fixes #2890

* float32 y

* update release notes

* use floatX
  • Loading branch information
ferrine authored and Junpeng Lao committed Mar 10, 2018
1 parent b385791 commit 999661c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 0 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

- `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use log_i0 to compute the logp.
- The bandwidth for KDE plots is computed using a modified version of Scott's rule. The new version uses entropy instead of standard deviation. This works better for multimodal distributions. Functions using KDE plots has a new argument `bw` controlling the bandwidth.
- fix PyMC3 variable is not replaced if provided in more_replacements (#2890)

### Deprecations

Expand Down
14 changes: 14 additions & 0 deletions pymc3/tests/test_variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,20 @@ def test_sample_replacements(binomial_model_inference):
assert sampled.shape[0] == 101


def test_var_replacement():
X_mean = pm.floatX(np.linspace(0, 10, 10))
y = pm.floatX(np.random.normal(X_mean*4, .05))
with pm.Model():
inp = pm.Normal('X', X_mean, shape=X_mean.shape)
coef = pm.Normal('b', 4.)
mean = inp * coef
pm.Normal('y', mean, .1, observed=y)
advi = pm.fit(100)
assert advi.sample_node(mean).eval().shape == (10, )
x_new = pm.floatX(np.linspace(0, 10, 11))
assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11, )


def test_empirical_from_trace(another_simple_model):
with another_simple_model:
step = pm.Metropolis()
Expand Down
1 change: 1 addition & 0 deletions pymc3/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,7 @@ def sample_node(self, node, size=None,
sampled node(s) with replacements
"""
node_in = node
node = theano.clone(node, more_replacements)
if size is None:
node_out = self.symbolic_single_sample(node)
else:
Expand Down

0 comments on commit 999661c

Please sign in to comment.