Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile the functions needed by SMC before the worker processes are started #7472

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
43 changes: 29 additions & 14 deletions pymc/smc/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
compile_pymc,
floatX,
join_nonshared_inputs,
make_shared_replacements,
)
from pymc.sampling.forward import draw
from pymc.step_methods.metropolis import MultivariateNormalProposal
Expand Down Expand Up @@ -168,7 +167,7 @@ def __init__(
raise ValueError(f"Threshold value {threshold} must be between 0 and 1")
self.threshold = threshold
self.model = model
self.rng = np.random.default_rng(seed=random_seed)
self.initialize_rng(random_seed=random_seed)

self.model = modelcontext(model)
self.variables = self.model.value_vars
Expand All @@ -186,6 +185,21 @@ def __init__(
self.resampling_indexes = None
self.weights = np.ones(self.draws) / self.draws

self.varlogp = self.model.varlogp
self.datalogp = self.model.datalogp

def initialize_rng(self, random_seed=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was it unnecessary to add this method? Didn't want to directly access SMC_KERNEL.rng since it's not initialized by just direct assignment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't follow, can you explain again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a new method SMC_KERNEL.initialize_rng which creates a new SMC_KERNEL.rng with given seed. It's just a convenience method for seeding the rng with a different seed in each worker process. Previously it wasn't necessary because the kernels were created in each process separately and rng is seeded during that. This PR creates the kernel before creating the worker processes and seeding has to be done separately.

I was wondering if adding a new method was unnecessary. The method doesn't do much after all. I could just do smc.rng = np.random.default_rng(seed=random_seed) in _sample_smc_int instead but I didn't want to interact with SMC_KERNEL.rng since I didn't see it used anywhere else outside of SMC_KERNEL.

Nitpicky? I agree.

"""
Initialize random number generator.

Parameters
----------
random_seed : int, array_like of int, RandomState or Generator, optional
Value used to initialize the random number generator.
"""

self.rng = np.random.default_rng(seed=random_seed)

def initialize_population(self) -> dict[str, np.ndarray]:
"""Create an initial population from the prior distribution"""
sys.stdout.write(" ") # see issue #5828
Expand All @@ -212,15 +226,22 @@ def initialize_population(self) -> dict[str, np.ndarray]:

return cast(dict[str, np.ndarray], dict_prior)

def _initialize_kernel(self):
"""Create variables and logp function necessary to run SMC kernel
def _initialize_kernel(self, initial_point=None):
"""
Create variables and logp function necessary to run SMC kernel

This method should not be overwritten. If needed, use `setup_kernel`
instead.

Parameters
----------
initial_point : dict, optional
Dictionary that contains initial values for model variables.
"""
# Create dictionary that stores original variables shape and size
initial_point = self.model.initial_point(random_seed=self.rng.integers(2**30))

if initial_point is None:
# Create dictionary that stores original variables shape and size
initial_point = self.model.initial_point(random_seed=self.rng.integers(2**30))
for v in self.variables:
self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)
# Create particles bijection map
Expand All @@ -237,14 +258,8 @@ def _initialize_kernel(self):
self.tempered_posterior = np.array(floatX(population))

# Initialize prior and likelihood log probabilities
shared = make_shared_replacements(initial_point, self.variables, self.model)

self.prior_logp_func = _logp_forw(
initial_point, [self.model.varlogp], self.variables, shared
)
self.likelihood_logp_func = _logp_forw(
initial_point, [self.model.datalogp], self.variables, shared
)
self.prior_logp_func = _logp_forw(initial_point, [self.varlogp], self.variables, {})
self.likelihood_logp_func = _logp_forw(initial_point, [self.datalogp], self.variables, {})

priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior]
likelihoods = [self.likelihood_logp_func(sample) for sample in self.tempered_posterior]
Expand Down
62 changes: 22 additions & 40 deletions pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ def sample_smc(
random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)

model = modelcontext(model)
smc = kernel(
draws=draws,
start=start,
model=model,
**kernel_kwargs,
)
initial_points = [
model.initial_point(random_seed=np.random.default_rng(seed=seed).integers(2**30))
for seed in random_seed
]

_log = logging.getLogger(__name__)
_log.info("Initializing SMC sampler...")
Expand All @@ -205,16 +215,9 @@ def sample_smc(
f"in {cores} job{'s' if cores > 1 else ''}"
)

params = (
draws,
kernel,
start,
model,
)

t1 = time.time()

results = run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores)
results = run_chains(chains, progressbar, smc, random_seed, initial_points, cores)

(
traces,
Expand Down Expand Up @@ -303,41 +306,21 @@ def _save_sample_stats(


def _sample_smc_int(
draws,
kernel,
start,
model,
smc,
random_seed,
initial_point,
chain,
progress_dict,
task_id,
**kernel_kwargs,
):
"""Run one SMC instance."""
in_out_pickled = isinstance(model, bytes)
in_out_pickled = isinstance(smc, bytes)
if in_out_pickled:
# function was called in multiprocessing context, deserialize first
(draws, kernel, start, model) = map(
cloudpickle.loads,
(
draws,
kernel,
start,
model,
),
)

kernel_kwargs = {key: cloudpickle.loads(value) for key, value in kernel_kwargs.items()}

smc = kernel(
draws=draws,
start=start,
model=model,
random_seed=random_seed,
**kernel_kwargs,
)
smc = cloudpickle.loads(smc)

smc._initialize_kernel()
smc.initialize_rng(random_seed)
smc._initialize_kernel(initial_point)
smc.setup_kernel()

stage = 0
Expand Down Expand Up @@ -367,7 +350,7 @@ def _sample_smc_int(
return results


def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
def run_chains(chains, progressbar, smc, random_seed, initial_points, cores):
with CustomProgress(
TextColumn("{task.description}"),
SpinnerColumn(),
Expand All @@ -383,9 +366,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
# main process and our worker functions
_progress = manager.dict()

# "manually" (de)serialize params before/after multiprocessing
params = tuple(cloudpickle.dumps(p) for p in params)
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
# "manually" (de)serialize kernel before/after multiprocessing
smc = cloudpickle.dumps(smc)

with ProcessPoolExecutor(max_workers=cores) as executor:
for c in range(chains): # iterate over the jobs we need to run
Expand All @@ -394,12 +376,12 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
futures.append(
executor.submit(
_sample_smc_int,
*params,
smc,
random_seed[c],
initial_points[c],
c,
_progress,
task_id,
**kernel_kwargs,
)
)

Expand Down
15 changes: 15 additions & 0 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,21 @@ def test_unobserved_categorical(self):

assert np.all(np.median(trace["mu"], axis=0) == [1, 2])

def test_parallel_custom(self):
def _logp(value, mu):
return -((value - mu) ** 2)

def _random(mu, rng=None, size=None):
return rng.normal(loc=mu, scale=1, size=size)

def _dist(mu, size=None):
return pm.Normal.dist(mu, 1, size=size)

with pm.Model():
mu = pm.CustomDist("mu", 0, logp=_logp, dist=_dist)
pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2])
pm.sample_smc(draws=6, cores=2)

def test_marginal_likelihood(self):
"""
Verifies that the log marginal likelihood function
Expand Down