Skip to content

Commit

Permalink
compute accept_stat with Metropolis-adjusted weights
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuurre committed Jul 6, 2022
1 parent 338727c commit fd13399
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions src/stan/mcmc/hmc/nuts/base_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
double log_sum_weight = 0; // log(exp(H0 - H0))
double H0 = this->hamiltonian_.H(this->z_);
int n_leapfrog = 0;
double log_sum_accept_stat = -std::numeric_limits<double>::infinity();
double accept_stat = 0; // Default accept stat to zero
double p0 = 1; // probability of selecting the initial state

// Build a trajectory until the no-u-turn
// criterion is no longer satisfied
Expand Down Expand Up @@ -164,19 +165,24 @@ class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
++(this->depth_);

if (log_sum_weight_subtree > log_sum_weight) {
p0 = 0.0;
accept_stat
= std::exp(log_sum_accept_stat_subtree - log_sum_weight_subtree);
z_sample = z_propose;
} else {
double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight);
double accept_stat_subtree
= std::exp(log_sum_accept_stat_subtree - log_sum_weight_subtree);
p0 = (1 - accept_prob) * p0;
accept_stat = (1 - accept_prob) * accept_stat
+ accept_prob * accept_stat_subtree;
if (this->rand_uniform_() < accept_prob)
z_sample = z_propose;
}

log_sum_weight
= math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

log_sum_accept_stat
= math::log_sum_exp(log_sum_accept_stat, log_sum_accept_stat_subtree);

// Break when no-u-turn criterion is no longer satisfied
rho = rho_bck + rho_fwd;

Expand All @@ -200,13 +206,9 @@ class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {

this->n_leapfrog_ = n_leapfrog;

double accept_stat = 0; // Default accept stat to zero

// Update accept stat if any subtrees were accepted
if (log_sum_weight > 0) {
// Remove contribution from initial state which is always a perfec accept
log_sum_weight = math::log_diff_exp(log_sum_weight, 0);
accept_stat = std::exp(log_sum_accept_stat - log_sum_weight);
if (p0 < 1) {
// Remove contribution from initial state
accept_stat = accept_stat / (1 - p0);
}

this->z_.ps_point::operator=(z_sample);
Expand Down

0 comments on commit fd13399

Please sign in to comment.