Skip to content

Commit

Permalink
Add split-chain option to rank ecdf plots
Browse files Browse the repository at this point in the history
Related to stan-dev#333
  • Loading branch information
sims1253 committed Dec 16, 2024
1 parent 2ea6f04 commit e9025a1
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 5 deletions.
32 changes: 29 additions & 3 deletions R/mcmc-traces.R
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ mcmc_rank_hist <- function(x,
#' @param plot_diff For `mcmc_rank_ecdf()`, a boolean specifying if the
#' difference between the observed rank ECDFs and the theoretical expectation
#' should be drawn instead of the unmodified rank ECDF plots.
#' @param split_chains Logical indicating whether to split each chain into two parts.
#' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
#' Defaults to `FALSE`.
#' @export
mcmc_rank_ecdf <-
function(x,
Expand All @@ -494,7 +497,8 @@ mcmc_rank_ecdf <-
facet_args = list(),
prob = 0.99,
plot_diff = FALSE,
interpolate_adj = NULL) {
interpolate_adj = NULL,
split_chains = FALSE) {
check_ignored_arguments(...,
ok_args = c("K", "pit", "prob", "plot_diff", "interpolate_adj", "M")
)
Expand All @@ -505,8 +509,28 @@ mcmc_rank_ecdf <-
transformations = transformations,
highlight = 1
)

# Split chains if requested
if (split_chains) {
data$n_chains = data$n_chains/2
data$n_iterations = data$n_iterations/2
n_samples <- length(unique(data$iteration))
midpoint <- n_samples/2

data <- data %>%
group_by(.data$chain) %>%
mutate(
chain = ifelse(
iteration <= midpoint,
paste0(.data$chain, "_1"),
paste0(.data$chain, "_2")
)
) %>%
ungroup()
}

n_iter <- unique(data$n_iterations)
n_chain <- unique(data$n_chains)
n_chain <- length(unique(data$chain))
n_param <- unique(data$n_parameters)

x <- if (is.null(K)) {
Expand Down Expand Up @@ -559,7 +583,9 @@ mcmc_rank_ecdf <-
group = .data$chain
)

scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain))
# Update legend title based on split_chains
legend_title <- if (split_chains) "Split Chains" else "Chain"
scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chain))

facet_call <- NULL
if (n_param == 1) {
Expand Down
Loading

0 comments on commit e9025a1

Please sign in to comment.