Skip to content

Commit

Permalink
Merge pull request #621 from tillahoffmann/timeout
Browse files Browse the repository at this point in the history
Add timeout parameter.
  • Loading branch information
WardBrian authored Sep 21, 2022
2 parents 1219274 + a71a36d commit c2bab85
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 5 deletions.
57 changes: 52 additions & 5 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
import threading
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

import ujson as json
Expand Down Expand Up @@ -568,6 +569,7 @@ def optimize(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanMLE:
"""
Run the specified CmdStan optimize algorithm to produce a
Expand Down Expand Up @@ -667,6 +669,8 @@ def optimize(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"
:param timeout: Duration at which optimization times out in seconds.
:return: CmdStanMLE object
"""
optimize_args = OptimizeArgs(
Expand Down Expand Up @@ -698,7 +702,13 @@ def optimize(
)
dummy_chain_id = 0
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
self._run_cmdstan(
runset,
dummy_chain_id,
show_console=show_console,
timeout=timeout,
)
runset.raise_for_timeouts()

if not runset._check_retcodes():
msg = "Error during optimization! Command '{}' failed: {}".format(
Expand Down Expand Up @@ -744,6 +754,7 @@ def sample(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
*,
force_one_process_per_chain: Optional[bool] = None,
) -> CmdStanMCMC:
Expand Down Expand Up @@ -941,6 +952,8 @@ def sample(
model was compiled with STAN_THREADS=True, and utilize the
parallel chain functionality if those conditions are met.
:param timeout: Duration at which sampling times out in seconds.
:return: CmdStanMCMC object
"""
if fixed_param is None:
Expand Down Expand Up @@ -1116,6 +1129,7 @@ def sample(
show_progress=show_progress,
show_console=show_console,
progress_hook=progress_hook,
timeout=timeout,
)
if show_progress and progress_hook is not None:
progress_hook("Done", -1) # -1 == all chains finished
Expand All @@ -1131,6 +1145,8 @@ def sample(
sys.stdout.write('\n')
get_logger().info('CmdStan done processing.')

runset.raise_for_timeouts()

get_logger().debug('runset\n%s', repr(runset))

# hack needed to parse CSV files if model has no params
Expand Down Expand Up @@ -1186,6 +1202,7 @@ def generate_quantities(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanGQ:
"""
Run CmdStan's generate_quantities method which runs the generated
Expand Down Expand Up @@ -1244,6 +1261,8 @@ def generate_quantities(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"
:param timeout: Duration at which generation times out in seconds.
:return: CmdStanGQ object
"""
if isinstance(mcmc_sample, CmdStanMCMC):
Expand Down Expand Up @@ -1306,8 +1325,10 @@ def generate_quantities(
runset,
i,
show_console=show_console,
timeout=timeout,
)

runset.raise_for_timeouts()
errors = runset.get_err_msgs()
if errors:
msg = (
Expand Down Expand Up @@ -1343,6 +1364,7 @@ def variational(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanVB:
"""
Run CmdStan's variational inference algorithm to approximate
Expand Down Expand Up @@ -1435,6 +1457,9 @@ def variational(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"
:param timeout: Duration at which variational Bayesian inference times
out in seconds.
:return: CmdStanVB object
"""
variational_args = VariationalArgs(
Expand Down Expand Up @@ -1468,7 +1493,13 @@ def variational(

dummy_chain_id = 0
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
self._run_cmdstan(
runset,
dummy_chain_id,
show_console=show_console,
timeout=timeout,
)
runset.raise_for_timeouts()

# treat failure to converge as failure
transcript_file = runset.stdout_files[dummy_chain_id]
Expand Down Expand Up @@ -1504,9 +1535,8 @@ def variational(
'current value is {}.'.format(grad_samples)
)
else:
msg = (
'Variational algorithm failed.\n '
'Console output:\n{}'.format(contents)
msg = 'Error during variational inference: {}'.format(
runset.get_err_msgs()
)
raise RuntimeError(msg)
# pylint: disable=invalid-name
Expand All @@ -1520,6 +1550,7 @@ def _run_cmdstan(
show_progress: bool = False,
show_console: bool = False,
progress_hook: Optional[Callable[[str, int], None]] = None,
timeout: Optional[float] = None,
) -> None:
"""
Helper function which encapsulates call to CmdStan.
Expand Down Expand Up @@ -1556,6 +1587,20 @@ def _run_cmdstan(
env=os.environ,
universal_newlines=True,
)
if timeout:

def _timer_target() -> None:
# Abort if the process has already terminated.
if proc.poll() is not None:
return
proc.terminate()
runset._set_timeout_flag(idx, True)

timer = threading.Timer(timeout, _timer_target)
timer.daemon = True
timer.start()
else:
timer = None
while proc.poll() is None:
if proc.stdout is not None:
line = proc.stdout.readline()
Expand All @@ -1569,6 +1614,8 @@ def _run_cmdstan(
stdout, _ = proc.communicate()
retcode = proc.returncode
runset._set_retcode(idx, retcode)
if timer:
timer.cancel()

if stdout:
fd_out.write(stdout)
Expand Down
12 changes: 12 additions & 0 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
else:
self._num_procs = 1
self._retcodes = [-1 for _ in range(self._num_procs)]
self._timeout_flags = [False for _ in range(self._num_procs)]
if chain_ids is None:
chain_ids = [i + 1 for i in range(chains)]
self._chain_ids = chain_ids
Expand Down Expand Up @@ -230,6 +231,10 @@ def _set_retcode(self, idx: int, val: int) -> None:
"""Set retcode at process[idx] to val."""
self._retcodes[idx] = val

def _set_timeout_flag(self, idx: int, val: bool) -> None:
"""Set timeout_flag at process[idx] to val."""
self._timeout_flags[idx] = val

def get_err_msgs(self) -> str:
"""Checks console messages for each CmdStan run."""
msgs = []
Expand Down Expand Up @@ -294,3 +299,10 @@ def save_csvfiles(self, dir: Optional[str] = None) -> None:
raise ValueError(
'Cannot save to file: {}'.format(to_path)
) from e

def raise_for_timeouts(self) -> None:
if any(self._timeout_flags):
raise TimeoutError(
f"{sum(self._timeout_flags)} of {self.num_procs} processes "
"timed out"
)
21 changes: 21 additions & 0 deletions test/data/timeout.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
data {
// Indicator for endless looping.
int loop;
}

transformed data {
// Maybe loop forever so the model times out.
real y = 1;
while(loop && y) {
y += 1;
}
}

parameters {
real x;
}

model {
// A nice model so we can get a fit for the `generated_quantities` call.
x ~ normal(0, 1);
}
9 changes: 9 additions & 0 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,15 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
fit = timeout_model.sample(data={'loop': 0}, chains=1, iter_sampling=10)
with self.assertRaises(TimeoutError):
timeout_model.generate_quantities(
timeout=0.1, mcmc_sample=fit, data={'loop': 1}
)


if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,12 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
with self.assertRaises(TimeoutError):
timeout_model.optimize(data={'loop': 1}, timeout=0.1)


if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,12 @@ def test_diagnostics(self):
self.assertEqual(fit.max_treedepths, None)
self.assertEqual(fit.divergences, None)

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
with self.assertRaises(TimeoutError):
timeout_model.sample(timeout=0.1, chains=1, data={'loop': 1})


if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions test/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
with self.assertRaises(TimeoutError):
timeout_model.variational(timeout=0.1, data={'loop': 1})


if __name__ == '__main__':
unittest.main()

0 comments on commit c2bab85

Please sign in to comment.