Register the overloads added by CustomDist in worker processes #7241

Merged · 13 commits · Dec 3, 2024
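When `pm.sample_smc` runs with `cores > 1`, the chains execute in separate worker processes, but the singledispatch overloads that `CustomDist` registers for `_logprob`, `_logcdf`, `_icdf`, and `_support_point` exist only in the parent process, so workers could fail to evaluate such models. This change collects those overloads from the model, serializes them with cloudpickle, and re-registers them in each worker through the `ProcessPoolExecutor` initializer.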
41 changes: 40 additions & 1 deletion pymc/smc/sampling.py
@@ -36,6 +36,9 @@

from pymc.backends.arviz import dict_to_dataset, to_inference_data
from pymc.backends.base import MultiTrace
from pymc.distributions.custom import CustomDistRV, CustomSymbolicDistRV
from pymc.distributions.distribution import _support_point
from pymc.logprob.abstract import _icdf, _logcdf, _logprob
from pymc.model import Model, modelcontext
from pymc.sampling.parallel import _cpu_count
from pymc.smc.kernels import IMH
@@ -383,11 +386,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
# main process and our worker functions
_progress = manager.dict()

# collect the dispatch overloads registered by any CustomDist in the model
# (params[3]) so they can be re-registered inside the worker processes
custom_methods = _find_custom_dist_dispatch_methods(params[3])

# "manually" (de)serialize params before/after multiprocessing
params = tuple(cloudpickle.dumps(p) for p in params)
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}

with ProcessPoolExecutor(max_workers=cores) as executor:
with ProcessPoolExecutor(
max_workers=cores,
initializer=_register_custom_methods,
initargs=(custom_methods,),
) as executor:
for c in range(chains): # iterate over the jobs we need to run
# set visible false so we don't have a lot of bars all at once:
task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
@@ -420,3 +430,32 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
)

return tuple(cloudpickle.loads(r.result()) for r in done)


def _find_custom_dist_dispatch_methods(model):
    # CustomDist registers its logprob/logcdf/icdf/support_point overloads only
    # in the process that builds the model; collect them here (pickled) so that
    # _register_custom_methods can re-register them inside each worker
    custom_methods = {}
for rv in model.basic_RVs:
rv_type = rv.owner.op
cls = type(rv_type)
if isinstance(rv_type, CustomDistRV | CustomSymbolicDistRV):
custom_methods[cloudpickle.dumps(cls)] = (
cloudpickle.dumps(_logprob.registry.get(cls, None)),
cloudpickle.dumps(_logcdf.registry.get(cls, None)),
cloudpickle.dumps(_icdf.registry.get(cls, None)),
cloudpickle.dumps(_support_point.registry.get(cls, None)),
)

return custom_methods


def _register_custom_methods(custom_methods):
    for cls, methods in custom_methods.items():
        cls = cloudpickle.loads(cls)
        logprob, logcdf, icdf, support_point = (cloudpickle.loads(m) for m in methods)
        # register only the overloads that exist in the parent process; the
        # entries must be unpickled before the None checks, since
        # cloudpickle.dumps(None) is a bytes object and is never None itself
        if logprob is not None:
            _logprob.register(cls, logprob)
        if logcdf is not None:
            _logcdf.register(cls, logcdf)
        if icdf is not None:
            _icdf.register(cls, icdf)
        if support_point is not None:
            _support_point.register(cls, support_point)
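
Aside from the PyMC-specific registries, the mechanism above is the standard `concurrent.futures` worker-initializer pattern. Below is a minimal, self-contained sketch of that pattern, assuming nothing beyond `cloudpickle` and the standard library; the names (`describe`, `Custom`, `_init_worker`, `_task`) are illustrative, not PyMC APIs:

# Minimal sketch (illustrative names, not PyMC code): singledispatch overloads
# registered in the parent process are pickled with cloudpickle and
# re-registered in every worker by a ProcessPoolExecutor initializer.
from concurrent.futures import ProcessPoolExecutor
from functools import singledispatch

import cloudpickle


@singledispatch
def describe(obj):
    # fallback used when no overload has been registered for type(obj)
    return "unknown"


class Custom:  # stand-in for a dynamically created CustomDistRV subclass
    pass


def _init_worker(payload):
    # runs once in each worker process, before any submitted task executes
    cls, handler = (cloudpickle.loads(p) for p in payload)
    describe.register(cls, handler)
    globals()["_worker_cls"] = cls  # keep the deserialized class reachable


def _task(_):
    return describe(globals()["_worker_cls"]())


if __name__ == "__main__":
    payload = (cloudpickle.dumps(Custom), cloudpickle.dumps(lambda obj: "custom"))
    with ProcessPoolExecutor(
        max_workers=2, initializer=_init_worker, initargs=(payload,)
    ) as pool:
        # without the initializer every call would hit the "unknown" fallback
        print(list(pool.map(_task, range(2))))  # -> ['custom', 'custom']

Shipping the overloads through `initargs` rather than relying on inherited state keeps the behavior the same under both the `fork` and `spawn` start methods, since the initializer runs once in every worker before any chain executes.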
15 changes: 15 additions & 0 deletions tests/smc/test_smc.py
@@ -133,6 +133,21 @@ def test_unobserved_categorical(self):

assert np.all(np.median(trace["mu"], axis=0) == [1, 2])

def test_parallel_custom(self):
def _logp(value, mu):
return -((value - mu) ** 2)

def _random(mu, rng=None, size=None):
return rng.normal(loc=mu, scale=1, size=size)

def _dist(mu, size=None):
return pm.Normal.dist(mu, 1, size=size)

with pm.Model():
mu = pm.CustomDist("mu", 0, logp=_logp, dist=_dist)
pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2])
pm.sample_smc(draws=6, cores=2)
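
The test appears designed to cover both CustomDist variants: `mu` (defined with `dist=`) lowers to a `CustomSymbolicDistRV`, while `y` (defined with `random=` and `logp=`) lowers to a `CustomDistRV`. Passing `cores=2` forces the chains through the `ProcessPoolExecutor` worker path, which is where the re-registration is exercised.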

def test_marginal_likelihood(self):
"""
Verifies that the log marginal likelihood function