Skip to content

Commit

Permalink
PEP8 style, daemonizing child processes (#2747)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege authored and Junpeng Lao committed Dec 11, 2017
1 parent cff9ea9 commit ee7c5bc
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,14 +654,16 @@ def __init__(self, steppers, parallelize):
# configure a child process for each stepper
pm._log.info('Attempting to parallelize chains.')
import multiprocessing
for c,stepper in enumerate(steppers):
for c, stepper in enumerate(tqdm(steppers)):
slave_end, master_end = multiprocessing.Pipe()
stepper_dumps = pickle.dumps(stepper, protocol=4)
process = multiprocessing.Process(
target=self.__class__._run_slave,
args=(c, stepper_dumps, slave_end),
name='ChainWalker{}'.format(c)
)
# we want the child process to exit if the parent is terminated
process.daemon = True
# Starting the process might fail and takes time.
# By doing it in the constructor, the sampling progress bar
# will not be confused by the process start.
Expand Down Expand Up @@ -794,7 +796,7 @@ def _prepare_iter_population(draws, chains, step, start, parallelize, tune=None,

# 1. prepare a BaseTrace for each chain
traces = [_choose_backend(None, chain, model=model) for chain in chains]
for c,strace in enumerate(traces):
for c, strace in enumerate(traces):
# initialize the trace size and variable transforms
if len(strace) > 0:
update_start_vals(start[c], strace.point(-1), model)
Expand Down Expand Up @@ -860,7 +862,7 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
updates = popstep.step(i == tune, points)

# apply the update to the points and record to the traces
for c,strace in enumerate(traces):
for c, strace in enumerate(traces):
if steppers[c].generates_stats:
points[c], states = updates[c]
if strace.supports_sampler_stats:
Expand All @@ -873,17 +875,17 @@ def _iter_population(draws, tune, popstep, steppers, traces, points):
# yield the state of all chains in parallel
yield traces
except KeyboardInterrupt:
for c,strace in enumerate(traces):
for c, strace in enumerate(traces):
strace.close()
if hasattr(steppers[c], 'report'):
steppers[c].report._finalize(strace)
raise
except BaseException:
for c,strace in enumerate(traces):
for c, strace in enumerate(traces):
strace.close()
raise
else:
for c,strace in enumerate(traces):
for c, strace in enumerate(traces):
strace.close()
if hasattr(steppers[c], 'report'):
steppers[c].report._finalize(strace)
Expand Down

0 comments on commit ee7c5bc

Please sign in to comment.