Feature/1263 update stansummary to report rank-normalized ESS tail, ESS bulk, max abs deviation (MAD), and Rhat #1290

Open. Wants to merge 78 commits into base: develop
Commits
20d3304
stansummary computes r_hat_bulk and r_hat_tail; needs unit tests
mitzimorris Jul 11, 2024
0660ba2
Merge branch 'develop' of https://github.com/stan-dev/cmdstan into fe…
mitzimorris Aug 4, 2024
ea04611
checkpointing; compiles, unit tests failing
mitzimorris Aug 5, 2024
de84f19
checkpointing; runs, unit tests need updating
mitzimorris Aug 5, 2024
e436053
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 5, 2024
29b8836
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Aug 5, 2024
0c8ec04
code cleanup
mitzimorris Sep 24, 2024
94aa73f
code cleanup
mitzimorris Sep 24, 2024
db8ea93
deprecate print
mitzimorris Sep 24, 2024
ed2290c
Merge commit '33e27a4f7b72c7df558d982a33bfe33ce0b14211' into HEAD
yashikno Sep 24, 2024
376e779
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 24, 2024
15ff1c5
stansummary unit tests passing
mitzimorris Sep 25, 2024
a371be1
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Sep 25, 2024
0930e39
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 25, 2024
d4aea84
cleanup - use chainset everywhere
mitzimorris Sep 27, 2024
d334fc1
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Sep 27, 2024
1fcd69a
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 27, 2024
1335785
updating diagnose command
mitzimorris Sep 28, 2024
fb63902
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Sep 28, 2024
1d00b27
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 29, 2024
f295b95
summary col widths fix
mitzimorris Sep 30, 2024
aa478b9
passing unit tests
mitzimorris Sep 30, 2024
89594de
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Sep 30, 2024
5de0d39
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Sep 30, 2024
ec5f3e5
change N_Eff to ESS, update tests
mitzimorris Oct 2, 2024
cec9206
merge fix
mitzimorris Oct 2, 2024
858dcfa
Merge commit '5bb4ffc20df9fb0416a8ef11e37f8029a94b7af2' into HEAD
yashikno Oct 2, 2024
99e9ef6
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 2, 2024
41d561a
checkpointing
mitzimorris Oct 11, 2024
8246f39
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 11, 2024
d7cb4f4
all unit tests pass
mitzimorris Oct 14, 2024
8ca0721
Merge commit '68838856f67b106e94a028a8268994a5c11c6804' into HEAD
yashikno Oct 14, 2024
bc6e81d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 14, 2024
663f765
stansummary unit tests passing
mitzimorris Oct 24, 2024
923512f
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 24, 2024
3c9e378
Merge commit 'ed2500a1c7a13bd1c905d7caa14690af2172b6cb' into HEAD
yashikno Oct 24, 2024
3013d42
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 24, 2024
8448467
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 24, 2024
a46ba99
use chainset everywhere
mitzimorris Oct 24, 2024
b44a88d
unit test data files
mitzimorris Oct 24, 2024
43c81ed
remove print helper and tests
mitzimorris Oct 25, 2024
632d6ad
changes per code review
mitzimorris Oct 25, 2024
448da11
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 25, 2024
31663d6
changes per code review
mitzimorris Oct 25, 2024
5b4bb54
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 25, 2024
8af8151
more print cleanup
mitzimorris Oct 25, 2024
a53d3c4
checkpointing
mitzimorris Oct 26, 2024
2efe9ff
checkpointing
mitzimorris Oct 26, 2024
277b7f9
multi-dim container outputs in row-major order
mitzimorris Oct 26, 2024
149b03c
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 26, 2024
ab4420b
simplified reorder row_major logic
mitzimorris Oct 26, 2024
f900462
simplified reorder row_major logic
mitzimorris Oct 26, 2024
7f0cfd1
test cleanup
mitzimorris Oct 26, 2024
a92928e
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 26, 2024
774f898
simplified reorder row_major logic
mitzimorris Oct 26, 2024
43ef763
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 26, 2024
48cc4df
simplified reorder row_major logic
mitzimorris Oct 26, 2024
09ed196
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 26, 2024
69e2c13
diagnose - stricter ESS test
mitzimorris Oct 27, 2024
27bfd88
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 27, 2024
39709ef
changes per code review
mitzimorris Oct 28, 2024
1daadab
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 28, 2024
01defba
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 28, 2024
8a19f64
keep print for now
mitzimorris Oct 28, 2024
fb46a54
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 28, 2024
695114f
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 28, 2024
b7ac1bc
keep print for now
mitzimorris Oct 29, 2024
22f8176
redo logic for -i flag
mitzimorris Oct 30, 2024
04b0016
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 30, 2024
98a3abc
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 30, 2024
b9989c5
fix row_major reorder logic
mitzimorris Oct 30, 2024
593cd58
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 30, 2024
00e2cb4
Merge branch 'feature/1263-new-rhat-summary' of https://github.com/st…
mitzimorris Oct 30, 2024
4055d43
remove unused var
mitzimorris Oct 30, 2024
cb7cf3b
fix row_major reorder logic
mitzimorris Oct 30, 2024
7c52174
checkpointing
mitzimorris Oct 31, 2024
386e13d
Merge commit '84fb5439bd0c9b39db0d55f232bc9cd9bf4af9d5' into HEAD
yashikno Oct 31, 2024
3763b7d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Oct 31, 2024
2 changes: 1 addition & 1 deletion src/cmdstan/command_helper.hpp
@@ -204,7 +204,7 @@ context_vector get_vec_var_context(const std::string &file, size_t num_chains,
= std::string("\"" + file_1 + "\" and base file \"" + file + "\"");
std::stringstream msg;
msg << "Searching for \"" << file_name_err << std::endl;
msg << "Can't open either of specified files," << file_name_err
msg << "Can't open either of specified files, " << file_name_err
<< std::endl;
throw std::invalid_argument(msg.str());
} else {
110 changes: 56 additions & 54 deletions src/cmdstan/diagnose.cpp
@@ -1,12 +1,15 @@
#include <cmdstan/return_codes.hpp>
#include <cmdstan/stansummary_helper.hpp>
#include <stan/mcmc/chains.hpp>
#include <stan/mcmc/chainset.hpp>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <ios>
#include <iostream>

double RHAT_MAX = 1.05;
using cmdstan::return_codes;

double RHAT_MAX = 1.01499; // round to 1.01

void diagnose_usage() {
std::cout << "USAGE: diagnose <filename 1> [<filename 2> ... <filename N>]"
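Note on the new threshold: the comment `// round to 1.01` suggests that `RHAT_MAX = 1.01499` was picked so that it prints as 1.01 under the `std::fixed` / `std::setprecision(2)` formatting that `main` sets, while still accepting any R-hat that would itself round to 1.01 or below. The following is a minimal sketch of that behavior, not code from this PR; `format_fixed2` and `exceeds_rhat_max` are hypothetical helper names introduced only for illustration.

```cpp
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical helper: format a value the way diagnose prints numbers,
// i.e. fixed notation with two decimal places.
std::string format_fixed2(double x) {
  std::ostringstream out;
  out << std::fixed << std::setprecision(2) << x;
  return out.str();
}

// Hypothetical helper mirroring the check `split_rhat > RHAT_MAX`.
// The default threshold is the value introduced in the diff above.
bool exceeds_rhat_max(double rhat, double rhat_max = 1.01499) {
  return rhat > rhat_max;
}
```

Under this reading, an R-hat of 1.0149 passes and 1.015 is flagged, and the threshold itself is reported to the user as 1.01.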
@@ -26,7 +29,7 @@ void diagnose_usage() {
int main(int argc, const char *argv[]) {
if (argc == 1) {
diagnose_usage();
return 0;
return return_codes::OK;
}

// Parse any arguments specifying filenames
@@ -45,49 +48,47 @@ int main(int argc, const char *argv[]) {

if (!filenames.size()) {
std::cout << "No valid input files, exiting." << std::endl;
return 0;
return return_codes::NOT_OK;
}

std::cout << std::fixed << std::setprecision(2);

// Parse specified files
std::cout << "Processing csv files: " << filenames[0];
ifstream.open(filenames[0].c_str());

stan::io::stan_csv stan_csv
= stan::io::stan_csv_reader::parse(ifstream, &std::cout);
stan::mcmc::chains<> chains(stan_csv);
ifstream.close();

if (filenames.size() > 1)
std::cout << ", ";
else
std::cout << std::endl << std::endl;

for (std::vector<std::string>::size_type chain = 1; chain < filenames.size();
++chain) {
std::cout << filenames[chain];
ifstream.open(filenames[chain].c_str());
stan_csv = stan::io::stan_csv_reader::parse(ifstream, &std::cout);
chains.add(stan_csv);
ifstream.close();
if (chain < filenames.size() - 1)
std::cout << ", ";
else
std::cout << std::endl << std::endl;
std::vector<stan::io::stan_csv> csv_parsed;
for (int i = 0; i < filenames.size(); ++i) {
std::ifstream infile;
std::stringstream out;
stan::io::stan_csv sample;
infile.open(filenames[i].c_str());
try {
sample = stan::io::stan_csv_reader::parse(infile, &out);
// csv_reader warnings are errors - fail fast.
if (!out.str().empty()) {
throw std::invalid_argument(out.str());
}
csv_parsed.push_back(sample);
} catch (const std::invalid_argument &e) {
std::cout << "Cannot parse input csv file: " << filenames[i] << e.what()
<< "." << std::endl;
return return_codes::NOT_OK;
}
}

stan::mcmc::chainset chains(csv_parsed);
stan::io::stan_csv_metadata metadata = csv_parsed[0].metadata;
std::vector<std::string> param_names = csv_parsed[0].header;
size_t num_params = param_names.size();
int num_samples = chains.num_samples();
std::vector<std::string> bad_n_eff_names;
std::vector<std::string> bad_rhat_names;
bool has_errors = false;

for (int i = 0; i < chains.num_params(); ++i) {
if (chains.param_name(i) == std::string("treedepth__")) {
for (int i = 0; i < num_params; ++i) {
if (param_names[i] == std::string("treedepth__")) {
std::cout << "Checking sampler transitions treedepth." << std::endl;
int max_limit = stan_csv.metadata.max_depth;
int max_limit = metadata.max_depth;
long n_max = 0;
Eigen::VectorXd t_samples = chains.samples(i);
Eigen::MatrixXd draws = chains.samples(i);
Eigen::VectorXd t_samples
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
for (long n = 0; n < t_samples.size(); ++n) {
if (t_samples(n) >= max_limit) {
++n_max;
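Note on the parsing loop in the hunk above: the reader is handed a `std::stringstream` for its warnings, and any non-empty warning text is rethrown as `std::invalid_argument`, so a questionable CSV file stops the run instead of being silently diagnosed. The pattern can be sketched in isolation as follows. This is an illustration under stated assumptions: `read_rows` is a hypothetical stand-in for a reader that reports problems on a side stream, and it is not the `stan_csv_reader` API.

```cpp
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical reader: collects non-empty lines and reports empty ones
// on the warning stream instead of throwing.
std::vector<std::string> read_rows(std::istream &in, std::ostream *warn) {
  std::vector<std::string> rows;
  std::string line;
  while (std::getline(in, line)) {
    if (line.empty()) {
      if (warn != nullptr) *warn << "empty row skipped";
    } else {
      rows.push_back(line);
    }
  }
  return rows;
}

// Fail fast: capture warnings in a local buffer and promote any
// warning text to an exception.
std::vector<std::string> read_rows_strict(std::istream &in) {
  std::stringstream warnings;
  std::vector<std::string> rows = read_rows(in, &warnings);
  if (!warnings.str().empty()) {
    throw std::invalid_argument(warnings.str());
  }
  return rows;
}
```

Buffering the warnings locally, rather than passing `&std::cout` as the earlier code did, is what makes it possible to test whether the reader complained at all.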
@@ -109,7 +110,7 @@ int main(int argc, const char *argv[]) {
std::cout << "Treedepth satisfactory for all transitions." << std::endl
<< std::endl;
}
} else if (chains.param_name(i) == std::string("divergent__")) {
} else if (param_names[i] == std::string("divergent__")) {
std::cout << "Checking sampler transitions for divergences." << std::endl;
int n_divergent = chains.samples(i).sum();
if (n_divergent > 0) {
@@ -129,26 +130,22 @@ int main(int argc, const char *argv[]) {
std::cout << "No divergent transitions found." << std::endl
<< std::endl;
}
} else if (chains.param_name(i) == std::string("energy__")) {
} else if (param_names[i] == std::string("energy__")) {
std::cout << "Checking E-BFMI - sampler transitions HMC potential energy."
<< std::endl;
Eigen::VectorXd e_samples = chains.samples(i);
Eigen::MatrixXd draws = chains.samples(i);
Eigen::VectorXd e_samples
= Eigen::Map<Eigen::VectorXd>(draws.data(), draws.size());
double delta_e_sq_mean = 0;
double e_mean = 0;
double e_var = 0;
e_mean += e_samples(0);
e_var += e_samples(0) * (e_samples(0) - e_mean);
double e_mean = chains.mean(i);
double e_var = chains.variance(i);
for (long n = 1; n < e_samples.size(); ++n) {
double e = e_samples(n);
double delta_e_sq = (e - e_samples(n - 1)) * (e - e_samples(n - 1));
double d = delta_e_sq - delta_e_sq_mean;
delta_e_sq_mean += d / n;
d = e - e_mean;
e_mean += d / (n + 1);
e_var += d * (e - e_mean);
}

e_var /= static_cast<double>(e_samples.size() - 1);
double e_bfmi = delta_e_sq_mean / e_var;
double e_bfmi_threshold = 0.3;
if (e_bfmi < e_bfmi_threshold) {
@@ -163,14 +160,16 @@ int main(int argc, const char *argv[]) {
} else {
std::cout << "E-BFMI satisfactory." << std::endl << std::endl;
}
} else if (chains.param_name(i).find("__") == std::string::npos) {
double n_eff = chains.effective_sample_size(i);
} else if (param_names[i].find("__") == std::string::npos) {
auto [ess_bulk, ess_tail] = chains.split_rank_normalized_ess(i);
double n_eff = ess_bulk < ess_tail ? ess_bulk : ess_tail;
if (n_eff / num_samples < 0.001)
bad_n_eff_names.push_back(chains.param_name(i));
bad_n_eff_names.push_back(param_names[i]);

double split_rhat = chains.split_potential_scale_reduction(i);
auto [rhat_bulk, rhat_tail] = chains.split_rank_normalized_rhat(i);
double split_rhat = rhat_bulk > rhat_tail ? rhat_bulk : rhat_tail;
if (split_rhat > RHAT_MAX)
bad_rhat_names.push_back(chains.param_name(i));
bad_rhat_names.push_back(param_names[i]);
}
}
if (bad_n_eff_names.size() > 0) {
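Note on the ESS and R-hat hunk above: each parameter now yields a bulk and a tail value for both diagnostics, and the check keeps the pessimistic one of each pair, that is, the smaller ESS and the larger R-hat. The decision logic can be sketched on its own as below. `ParamCheck` and `check_param` are hypothetical names, and `std::min` / `std::max` are swapped in for the ternaries in the diff, so behavior may differ if either value is NaN.

```cpp
#include <algorithm>
#include <utility>

// Hypothetical result type: which of the two checks a parameter fails.
struct ParamCheck {
  bool low_ess;
  bool high_rhat;
};

// ess and rhat hold (bulk, tail) pairs. The thresholds are the ones
// that appear in the diff: ESS per draw below 0.001 and R-hat above
// RHAT_MAX.
ParamCheck check_param(std::pair<double, double> ess,
                       std::pair<double, double> rhat, int num_samples,
                       double rhat_max = 1.01499) {
  const double n_eff = std::min(ess.first, ess.second);
  const double worst_rhat = std::max(rhat.first, rhat.second);
  return {n_eff / num_samples < 0.001, worst_rhat > rhat_max};
}
```

With 4000 draws, a parameter whose tail ESS is 3 and whose tail R-hat is 1.02 fails both checks even if its bulk values look healthy.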
@@ -187,13 +186,15 @@ int main(int argc, const char *argv[]) {
<< " may be substantially lower than quoted." << std::endl
<< std::endl;
} else {
std::cout << "Effective sample size satisfactory." << std::endl
std::cout << "Rank-normalized split effective sample size satisfactory "
<< "for all parameters." << std::endl
<< std::endl;
}

if (bad_rhat_names.size() > 0) {
has_errors = true;
std::cout << "The following parameters had split R-hat greater than "
std::cout << "The following parameters had rank-normalized split R-hat "
"greater than "
<< RHAT_MAX << ":" << std::endl;
std::cout << " ";
for (size_t n = 0; n < bad_rhat_names.size() - 1; ++n)
@@ -207,13 +208,14 @@ int main(int argc, const char *argv[]) {
<< " effective parameterization." << std::endl
<< std::endl;
} else {
std::cout << "Split R-hat values satisfactory all parameters." << std::endl
std::cout << "Rank-normalized split R-hat values satisfactory "
<< "for all parameters." << std::endl
<< std::endl;
}
if (!has_errors)
std::cout << "Processing complete, no problems detected." << std::endl;
else
std::cout << "Processing complete." << std::endl;

return 0;
return return_codes::OK;
}