From 999661c092310b1f247f14037f795a852425e9c9 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Sat, 10 Mar 2018 11:28:43 +0300 Subject: [PATCH] fix PyMC3 variable is not replaced if provided in more_replacements (VI) (#2891) * fixes #2890 * float32 y * update release notes * use floatX --- RELEASE-NOTES.md | 1 + pymc3/tests/test_variational_inference.py | 14 ++++++++++++++ pymc3/variational/opvi.py | 1 + 3 files changed, 16 insertions(+) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 574ca45f824..2da2c329a6f 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -23,6 +23,7 @@ - `VonMises` does not overflow for large values of kappa. i0 and i1 have been removed and we now use log_i0 to compute the logp. - The bandwidth for KDE plots is computed using a modified version of Scott's rule. The new version uses entropy instead of standard deviation. This works better for multimodal distributions. Functions using KDE plots has a new argument `bw` controlling the bandwidth. +- fix PyMC3 variable is not replaced if provided in more_replacements (#2890) ### Deprecations diff --git a/pymc3/tests/test_variational_inference.py b/pymc3/tests/test_variational_inference.py index e7f43ed0e07..080bf0e1ca3 100644 --- a/pymc3/tests/test_variational_inference.py +++ b/pymc3/tests/test_variational_inference.py @@ -835,6 +835,20 @@ def test_sample_replacements(binomial_model_inference): assert sampled.shape[0] == 101 +def test_var_replacement(): + X_mean = pm.floatX(np.linspace(0, 10, 10)) + y = pm.floatX(np.random.normal(X_mean*4, .05)) + with pm.Model(): + inp = pm.Normal('X', X_mean, shape=X_mean.shape) + coef = pm.Normal('b', 4.) + mean = inp * coef + pm.Normal('y', mean, .1, observed=y) + advi = pm.fit(100) + assert advi.sample_node(mean).eval().shape == (10, ) + x_new = pm.floatX(np.linspace(0, 10, 11)) + assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11, ) + + def test_empirical_from_trace(another_simple_model): with another_simple_model: step = pm.Metropolis() diff --git a/pymc3/variational/opvi.py b/pymc3/variational/opvi.py index 3481b51420a..1d2b851cc04 100644 --- a/pymc3/variational/opvi.py +++ b/pymc3/variational/opvi.py @@ -1459,6 +1459,7 @@ def sample_node(self, node, size=None, sampled node(s) with replacements """ node_in = node + node = theano.clone(node, more_replacements) if size is None: node_out = self.symbolic_single_sample(node) else: