From 9757434b200c246dc3e9347e4b597c538c8ffd05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Sat, 6 Apr 2024 18:57:43 +0300 Subject: [PATCH 01/11] Register custom overloads in all processes --- pymc/smc/sampling.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index d9b76f211ce..6ca18420e3c 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -31,6 +31,8 @@ from pymc.backends.arviz import dict_to_dataset, to_inference_data from pymc.backends.base import MultiTrace +from pymc.distributions.distribution import _support_point +from pymc.logprob.abstract import _logcdf, _logprob from pymc.model import Model, modelcontext from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH @@ -375,11 +377,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): # main process and our worker functions _progress = manager.dict() + # check if model contains CustomDistributions defined without dist argument + custom_methods = _find_custom_methods(params[3]) + # "manually" (de)serialize params before/after multiprocessing params = tuple(cloudpickle.dumps(p) for p in params) kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} - with ProcessPoolExecutor(max_workers=cores) as executor: + with ProcessPoolExecutor( + max_workers=cores, + initializer=_register_custom_methods, + initargs=(custom_methods,), + ) as executor: for c in range(chains): # iterate over the jobs we need to run # set visible false so we don't have a lot of bars all at once: task_id = progress.add_task( @@ -406,3 +415,25 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id) return tuple(cloudpickle.loads(r.result()) for r in futures) + + +def _find_custom_methods(model): + custom_methods = {} + for 
rv in model.free_RVs + model.observed_RVs: + cls = rv.owner.op.__class__ + if hasattr(cls, "_random_fn"): + custom_methods[cloudpickle.dumps(cls)] = ( + cloudpickle.dumps(_logprob.registry[cls]), + cloudpickle.dumps(_logcdf.registry[cls]), + cloudpickle.dumps(_support_point.registry[cls]), + ) + + return custom_methods + + +def _register_custom_methods(custom_methods): + for cls, (logprob, logcdf, support_point) in custom_methods.items(): + cls = cloudpickle.loads(cls) + _logprob.register(cls, cloudpickle.loads(logprob)) + _logcdf.register(cls, cloudpickle.loads(logcdf)) + _support_point.register(cls, cloudpickle.loads(support_point)) From 54ca772b33e9b5a450b72ab7ccb063845b99e17f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Sat, 6 Apr 2024 18:58:46 +0300 Subject: [PATCH 02/11] Test for using CustomDist in multiprocess smc sampling --- tests/smc/test_smc.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 33c8718eae8..fc6623926d9 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -133,6 +133,18 @@ def test_unobserved_categorical(self): assert np.all(np.median(trace["mu"], axis=0) == [1, 2]) + def test_parallel_custom(self): + def _logp(value, mu): + return -((value - mu) ** 2) + + def _random(mu, rng=None, size=None): + return rng.normal(loc=mu, scale=1, size=size) + + with pm.Model(): + mu = pm.CustomDist("mu", 0, logp=_logp, random=_random) + pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2]) + pm.sample_smc(draws=6, cores=2) + def test_marginal_likelihood(self): """ Verifies that the log marginal likelihood function From f96152d62268c424916dd057a47e51b91691ed23 Mon Sep 17 00:00:00 2001 From: EliasRas <89061857+EliasRas@users.noreply.github.com> Date: Mon, 27 May 2024 15:15:40 +0300 Subject: [PATCH 03/11] Use basic_RVs instead of manually collecting variables Co-authored-by: Ricardo Vieira 
<28983449+ricardoV94@users.noreply.github.com> --- pymc/smc/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 876feaf2981..63cd2459a88 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -427,7 +427,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): def _find_custom_methods(model): custom_methods = {} - for rv in model.free_RVs + model.observed_RVs: + for rv in model.basic_RVs: cls = rv.owner.op.__class__ if hasattr(cls, "_random_fn"): custom_methods[cloudpickle.dumps(cls)] = ( From f539b9fac8532322ca4396e279e3006d1242ce9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Mon, 27 May 2024 15:38:47 +0300 Subject: [PATCH 04/11] Use CustomDistRV to find variables defined using CustomDist --- pymc/smc/sampling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 63cd2459a88..e9a3c62ef4f 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -36,7 +36,7 @@ from pymc.backends.arviz import dict_to_dataset, to_inference_data from pymc.backends.base import MultiTrace -from pymc.distributions.distribution import _support_point +from pymc.distributions.distribution import CustomDistRV, _support_point from pymc.logprob.abstract import _logcdf, _logprob from pymc.model import Model, modelcontext from pymc.sampling.parallel import _cpu_count @@ -428,8 +428,9 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): def _find_custom_methods(model): custom_methods = {} for rv in model.basic_RVs: - cls = rv.owner.op.__class__ - if hasattr(cls, "_random_fn"): + rv_type = rv.owner.op + cls = rv_type.__class__ + if isinstance(rv_type, CustomDistRV): custom_methods[cloudpickle.dumps(cls)] = ( cloudpickle.dumps(_logprob.registry[cls]), cloudpickle.dumps(_logcdf.registry[cls]), From d6cfb89a091c8f9aa2edc2f8b9b245d6b7c238aa Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Tue, 28 May 2024 10:37:17 +0300 Subject: [PATCH 05/11] Handle missing overloads --- pymc/smc/sampling.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index e9a3c62ef4f..a23fa49d282 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -432,9 +432,9 @@ def _find_custom_methods(model): cls = rv_type.__class__ if isinstance(rv_type, CustomDistRV): custom_methods[cloudpickle.dumps(cls)] = ( - cloudpickle.dumps(_logprob.registry[cls]), - cloudpickle.dumps(_logcdf.registry[cls]), - cloudpickle.dumps(_support_point.registry[cls]), + cloudpickle.dumps(_logprob.registry.get(cls, None)), + cloudpickle.dumps(_logcdf.registry.get(cls, None)), + cloudpickle.dumps(_support_point.registry.get(cls, None)), ) return custom_methods @@ -443,6 +443,9 @@ def _find_custom_methods(model): def _register_custom_methods(custom_methods): for cls, (logprob, logcdf, support_point) in custom_methods.items(): cls = cloudpickle.loads(cls) - _logprob.register(cls, cloudpickle.loads(logprob)) - _logcdf.register(cls, cloudpickle.loads(logcdf)) - _support_point.register(cls, cloudpickle.loads(support_point)) + if logprob is not None: + _logprob.register(cls, cloudpickle.loads(logprob)) + if logcdf is not None: + _logcdf.register(cls, cloudpickle.loads(logcdf)) + if support_point is not None: + _support_point.register(cls, cloudpickle.loads(support_point)) From 5b460e436d108913201856a268e7328aa64f13ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Tue, 28 May 2024 10:38:02 +0300 Subject: [PATCH 06/11] Also catch variables defined using CustomDist with the dist argument --- pymc/smc/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index a23fa49d282..7b97c05179c 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -36,7 +36,7 @@ from
pymc.backends.arviz import dict_to_dataset, to_inference_data from pymc.backends.base import MultiTrace -from pymc.distributions.distribution import CustomDistRV, _support_point +from pymc.distributions.distribution import CustomDistRV, CustomSymbolicDistRV, _support_point from pymc.logprob.abstract import _logcdf, _logprob from pymc.model import Model, modelcontext from pymc.sampling.parallel import _cpu_count @@ -430,7 +430,7 @@ def _find_custom_methods(model): for rv in model.basic_RVs: rv_type = rv.owner.op cls = rv_type.__class__ - if isinstance(rv_type, CustomDistRV): + if isinstance(rv_type, CustomDistRV | CustomSymbolicDistRV): custom_methods[cloudpickle.dumps(cls)] = ( cloudpickle.dumps(_logprob.registry.get(cls, None)), cloudpickle.dumps(_logcdf.registry.get(cls, None)), From 997d730bec2e277a03c2b48e4765570403708c0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Tue, 28 May 2024 10:38:42 +0300 Subject: [PATCH 07/11] Check both kinds of CustomDist in multiprocessing test --- tests/smc/test_smc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index fc6623926d9..f49ca75d9f8 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -140,8 +140,11 @@ def _logp(value, mu): def _random(mu, rng=None, size=None): return rng.normal(loc=mu, scale=1, size=size) + def _dist(mu, size=None): + return pm.Normal.dist(mu, 1, size=size) + with pm.Model(): - mu = pm.CustomDist("mu", 0, logp=_logp, random=_random) + mu = pm.CustomDist("mu", 0, logp=_logp, dist=_dist) pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2]) pm.sample_smc(draws=6, cores=2) From c5725420dd261730df0977b46255a66a7fbac93b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Wed, 17 Jul 2024 08:31:46 +0300 Subject: [PATCH 08/11] Fix CustomDistRV imports --- pymc/smc/sampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 6b789835fdb..bf53f94278f 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -36,7 +36,8 @@ from pymc.backends.arviz import dict_to_dataset, to_inference_data from pymc.backends.base import MultiTrace -from pymc.distributions.distribution import CustomDistRV, CustomSymbolicDistRV, _support_point +from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV +from pymc.distributions.distribution import _support_point from pymc.logprob.abstract import _logcdf, _logprob from pymc.model import Model, modelcontext from pymc.sampling.parallel import _cpu_count From 0e4c01c6693d9ddaed0cef4d7e6e0c839342b05b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Wed, 17 Jul 2024 11:18:00 +0300 Subject: [PATCH 09/11] Better function name --- pymc/smc/sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index bf53f94278f..75c21e9a629 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -387,7 +387,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): _progress = manager.dict() # check if model contains CustomDistributions defined without dist argument - custom_methods = _find_custom_methods(params[3]) + custom_methods = _find_custom_dist_dispatch_methods(params[3]) # "manually" (de)serialize params before/after multiprocessing params = tuple(cloudpickle.dumps(p) for p in params) @@ -432,7 +432,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): return tuple(cloudpickle.loads(r.result()) for r in done) -def _find_custom_methods(model): +def _find_custom_dist_dispatch_methods(model): custom_methods = {} for rv in model.basic_RVs: rv_type = rv.owner.op From bb3e16f40836b070687ba408ce0b44531e843df7 Mon Sep 17 00:00:00 2001 From: EliasRas <89061857+EliasRas@users.noreply.github.com> Date: Wed, 17 Jul 2024 11:21:13 +0300 Subject: [PATCH 10/11] 
Use type instead of __class__ Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/smc/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 75c21e9a629..3e23b23b28b 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -436,7 +436,7 @@ def _find_custom_dist_dispatch_methods(model): custom_methods = {} for rv in model.basic_RVs: rv_type = rv.owner.op - cls = rv_type.__class__ + cls = type(rv_type) if isinstance(rv_type, CustomDistRV | CustomSymbolicDistRV): custom_methods[cloudpickle.dumps(cls)] = ( cloudpickle.dumps(_logprob.registry.get(cls, None)), From d2669f23af220a9b594cae440ba804482a610b18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Elias=20R=C3=A4s=C3=A4nen?= Date: Wed, 17 Jul 2024 11:28:32 +0300 Subject: [PATCH 11/11] Also register icdf overloads --- pymc/smc/sampling.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 3e23b23b28b..f1a4a315bcc 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -38,7 +38,7 @@ from pymc.backends.base import MultiTrace from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV from pymc.distributions.distribution import _support_point -from pymc.logprob.abstract import _logcdf, _logprob +from pymc.logprob.abstract import _icdf, _logcdf, _logprob from pymc.model import Model, modelcontext from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH @@ -441,6 +441,7 @@ def _find_custom_dist_dispatch_methods(model): custom_methods[cloudpickle.dumps(cls)] = ( cloudpickle.dumps(_logprob.registry.get(cls, None)), cloudpickle.dumps(_logcdf.registry.get(cls, None)), + cloudpickle.dumps(_icdf.registry.get(cls, None)), cloudpickle.dumps(_support_point.registry.get(cls, None)), ) @@ -448,11 +449,13 @@ def _find_custom_dist_dispatch_methods(model): def _register_custom_methods(custom_methods): - for cls, 
(logprob, logcdf, support_point) in custom_methods.items(): + for cls, (logprob, logcdf, icdf, support_point) in custom_methods.items(): cls = cloudpickle.loads(cls) if logprob is not None: _logprob.register(cls, cloudpickle.loads(logprob)) if logcdf is not None: _logcdf.register(cls, cloudpickle.loads(logcdf)) + if icdf is not None: + _icdf.register(cls, cloudpickle.loads(icdf)) if support_point is not None: _support_point.register(cls, cloudpickle.loads(support_point))