From fd133990565c848f9cd149bf4726cb2c4e483e72 Mon Sep 17 00:00:00 2001 From: Niko Huurre Date: Wed, 6 Jul 2022 10:37:52 +0300 Subject: [PATCH] compute accept_stat with Metropolis-adjusted weights --- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 588afdc87eb..363948d04ef 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -114,7 +114,8 @@ class base_nuts : public base_hmc { double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); int n_leapfrog = 0; - double log_sum_accept_stat = -std::numeric_limits::infinity(); + double accept_stat = 0; // Default accept stat to zero + double p0 = 1; // probability of selecting the initial state // Build a trajectory until the no-u-turn // criterion is no longer satisfied @@ -164,9 +165,17 @@ class base_nuts : public base_hmc { ++(this->depth_); if (log_sum_weight_subtree > log_sum_weight) { + p0 = 0.0; + accept_stat + = std::exp(log_sum_accept_stat_subtree - log_sum_weight_subtree); z_sample = z_propose; } else { double accept_prob = std::exp(log_sum_weight_subtree - log_sum_weight); + double accept_stat_subtree + = std::exp(log_sum_accept_stat_subtree - log_sum_weight_subtree); + p0 = (1 - accept_prob) * p0; + accept_stat = (1 - accept_prob) * accept_stat + + accept_prob * accept_stat_subtree; if (this->rand_uniform_() < accept_prob) z_sample = z_propose; } @@ -174,9 +183,6 @@ class base_nuts : public base_hmc { log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - log_sum_accept_stat - = math::log_sum_exp(log_sum_accept_stat, log_sum_accept_stat_subtree); - // Break when no-u-turn criterion is no longer satisfied rho = rho_bck + rho_fwd; @@ -200,13 +206,9 @@ class base_nuts : public base_hmc { this->n_leapfrog_ = n_leapfrog; - double accept_stat = 0; // Default accept stat to zero - - // Update accept stat if any subtrees were accepted - if (log_sum_weight > 0) { - // Remove contribution from initial state which is always a perfec accept - log_sum_weight = math::log_diff_exp(log_sum_weight, 0); - accept_stat = std::exp(log_sum_accept_stat - log_sum_weight); + if (p0 < 1) { + // Remove contribution from initial state + accept_stat = accept_stat / (1 - p0); } this->z_.ps_point::operator=(z_sample);