Skip to content

Commit

Permalink
add .use_rvars argument to summarise_draws
Browse files Browse the repository at this point in the history
  • Loading branch information
n-kall committed Mar 7, 2024
1 parent 165fb4f commit 930e7a8
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions R/summarise_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ summarise_draws.default <- function(.x, ...) {
summarise_draws.draws <- function(
.x, ..., .args = list(),
.num_args = getOption("posterior.num_args", list()),
.cores = 1
.cores = 1,
.use_rvars = FALSE
) {

if (ndraws(.x) == 0L) {
Expand Down Expand Up @@ -154,10 +155,10 @@ summarise_draws.draws <- function(
} else {
# default functions
funs <- list(
mean = base::mean,
median = stats::median,
sd = stats::sd,
mad = stats::mad,
mean = mean,
median = median,
sd = sd,
mad = mad,
quantile = quantile2,
rhat = rhat,
ess_bulk = ess_bulk,
Expand All @@ -180,7 +181,7 @@ summarise_draws.draws <- function(
}

if (.cores == 1) {
out <- summarise_draws_helper(.x, funs, .args)
out <- summarise_draws_helper(.x, funs, .args, .use_rvars = .use_rvars)
} else {
.x <- .x[, , variables_x]
n_vars <- length(variables_x)
Expand Down Expand Up @@ -218,15 +219,17 @@ summarise_draws.draws <- function(
X = chunk_list,
fun = summarise_draws_helper,
funs = funs,
.args = .args
.args = .args,
.use_rvars = .use_rvars
)
} else {
summary_list <- parallel::mclapply(
X = chunk_list,
FUN = summarise_draws_helper,
mc.cores = .cores,
funs = funs,
.args = .args
.args = .args,
.use_rvars = .use_rvars
)
}
out <- do.call("rbind", summary_list)
Expand Down Expand Up @@ -327,8 +330,12 @@ empty_draws_summary <- function(dimensions = "variable") {
}


create_summary_list <- function(x, v, funs, .args) {
create_summary_list <- function(x, v, funs, .args, .use_rvars = FALSE) {
draws <- drop_dims_or_classes(x[, , v], dims = 3, reset_class = FALSE)
if (.use_rvars) {
lw <- weights(x, log = TRUE)
draws <- rvar(draws, with_chains = TRUE, log_weights = lw)
}
v_summary <- named_list(names(funs))
for (m in names(funs)) {
args <- c(list(draws), .args[[m]])
Expand All @@ -337,10 +344,10 @@ create_summary_list <- function(x, v, funs, .args) {
v_summary
}

summarise_draws_helper <- function(x, funs, .args) {
summarise_draws_helper <- function(x, funs, .args, .use_rvars = FALSE) {
variables_x <- variables(x)
# get length and output names, calculated on the first variable
out_1 <- create_summary_list(x, variables_x[1], funs, .args)
out_1 <- create_summary_list(x, variables_x[1], funs, .args, .use_rvars = .use_rvars)
the_names <- vector(mode = "list", length = length(funs))
for (i in seq_along(out_1)){
if (rlang::is_named(out_1[[i]])) {
Expand All @@ -363,7 +370,7 @@ summarise_draws_helper <- function(x, funs, .args) {
# Do the computation for all remaining variables
if (length(variables_x) > 1L) {
for (v_ind in 2:length(variables_x)) {
out_v <- create_summary_list(x, variables_x[v_ind], funs, .args)
out_v <- create_summary_list(x, variables_x[v_ind], funs, .args, .use_rvars = .use_rvars)
out[v_ind, ] <- unlist(out_v)
}
}
Expand Down

0 comments on commit 930e7a8

Please sign in to comment.