diff --git a/NEWS.md b/NEWS.md index bb090c96..51ad34f0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,11 @@ Items for next release go here +### Breaking changes + +* `mcmc_pairs()` now returns a ggplot object rather than a `bayesplot_grid` + object. Code that modified the output of `mcmc_pairs()` will need to change. + (#268, @billdenney) # bayesplot 1.8.0 diff --git a/R/mcmc-distributions.R b/R/mcmc-distributions.R index 63a63c95..81d50712 100644 --- a/R/mcmc-distributions.R +++ b/R/mcmc-distributions.R @@ -323,12 +323,8 @@ mcmc_violin <- function(x, n_param <- num_params(data) graph <- ggplot(data, aes(x = ~ value)) + - geom_histogram( - set_hist_aes(freq), - fill = get_color("mid"), - color = get_color("mid_highlight"), - size = .25, - na.rm = TRUE, + geom_mcmc_hist( + freq = freq, binwidth = binwidth, breaks = breaks ) @@ -450,3 +446,4 @@ mcmc_violin <- function(x, yaxis_title(on = n_param == 1 && violin) + xaxis_title(on = n_param == 1) } + diff --git a/R/mcmc-scatterplots.R b/R/mcmc-scatterplots.R index 443e795a..466bb224 100644 --- a/R/mcmc-scatterplots.R +++ b/R/mcmc-scatterplots.R @@ -323,16 +323,23 @@ mcmc_pairs <- function(x, diag_fun <- match.arg(diag_fun) off_diag_fun <- match.arg(off_diag_fun) - plot_diagonal <- pairs_plotfun(diag_fun) - plot_off_diagonal <- pairs_plotfun(off_diag_fun) - - x <- prepare_mcmc_array(x, pars, regex_pars, transformations) - x <- drop_constants_and_duplicates(x) - - n_iter <- num_iters(x) - n_chain <- num_chains(x) - n_param <- num_params(x) - pars <- parameter_names(x) + geom_diagonal <- pairs_geomfun(diag_fun) + geom_off_diagonal <- pairs_geomfun(off_diag_fun) + + x_orig <- x + x_mcmc <- prepare_mcmc_array(x_orig, pars, regex_pars, transformations) + x_cleaned <- drop_constants_and_duplicates(x_mcmc) + + n_iter <- num_iters(x_cleaned) + n_chain <- num_chains(x_cleaned) + n_param <- num_params(x_cleaned) + pars <- parameter_names(x_cleaned) + # Convert `pars` into a factor with the same order as the pars are extracted + # so that the ggplot facets will be sorted in the same order as the pars were + # given (in case the user has a preference for the order of the rows and + # columns). Also, this will match the order of the pars from the prior + # bayesplot_group() version of this function. + pars <- factor(pars, levels=pars) if (n_chain == 1) { warn("Only one chain in 'x'. This plot is more useful with multiple chains.") @@ -344,95 +351,139 @@ mcmc_pairs <- function(x, no_np <- is.null(np) no_lp <- is.null(lp) no_max_td <- is.null(max_treedepth) + # set a default "divergent__" (for divergent transitions) and "max_td_hit__" + # (for max tree depth) values to put into the plotting matrix + divergent__ <- NA + max_td_hit__ <- NA if (!no_np) { param <- sym("Parameter") val <- sym("Value") np <- validate_nuts_data_frame(np, lp) divs <- dplyr::filter(np, UQ(param) == "divergent__") %>% pull(UQ(val)) - divergent__ <- matrix(divs, nrow = n_iter * n_chain, ncol = n_param)[, 1] + # Make divergent transitions into a factor with "NoDiv" or "Div" instead of + # 0/1 or FALSE/TRUE + divergent__ <- + factor( + as.logical(matrix(divs, nrow = n_iter * n_chain, ncol = n_param)[, 1]), + levels = c(FALSE, TRUE), + labels = c("NoDiv", "Div") + ) if (!no_max_td) { gt_max_td <- (dplyr::filter(np, UQ(param) == "treedepth__") %>% pull(UQ(val))) > max_treedepth - max_td_hit__ <- matrix(gt_max_td, nrow = n_iter * n_chain, ncol = n_param)[, 1] + # Make maximum tree depth hits into a factor with "NoHit" or "Hit" instead + # of 0/1 or FALSE/TRUE + max_td_hit__ <- + factor( + matrix(gt_max_td, nrow = n_iter * n_chain, ncol = n_param)[, 1], + levels = c(FALSE, TRUE), + labels = c("NoHit", "Hit") + ) } } - cond <- handle_condition(x, condition, np, lp) + cond <- handle_condition(x_cleaned, condition, np, lp) x <- merge_chains(cond[["x"]]) mark <- cond[["mark"]] - all_pairs <- expand.grid(pars, pars, - stringsAsFactors = FALSE, - KEEP.OUT.ATTRS = FALSE) - plots <- vector("list", length = nrow(all_pairs)) - use_default_binwidth <- is.null(diag_args[["binwidth"]]) - for (j in seq_len(nrow(all_pairs))) { - pair <- as.character(all_pairs[j,]) - - if (identical(pair[1], pair[2])) { - # Diagonal - diag_args[["x"]] <- x[, pair[1], drop = FALSE] - - # silence ggplot2's "Pick better value with `binwidth`" message - if (diag_fun == "hist" && use_default_binwidth) - diag_args[["binwidth"]] <- diff(range(diag_args[["x"]]))/30 - - plots[[j]] <- - do.call(plot_diagonal, diag_args) + - labs(subtitle = pair[1]) + - theme(axis.line.y = element_blank(), - plot.subtitle = element_text(hjust = 0.5)) - - } else { - # Off-diagonal - - # use mark if above diagonal and !mark if below the diagonal - mark2 <- if (is_lower_tri(j, n_param)) !mark else mark - x_j <- x[mark2, pair, drop = FALSE] - - if (!no_np) { - divs_j <- divergent__[mark2] - max_td_hit_j <- if (no_max_td) NULL else max_td_hit__[mark2] + # ensure that a version of x is a data.frame and not a matrix or a tibble or + # anything else + x_plot <- as.data.frame(x) + # Generate column names that are guaranteed not to interfere with other + # columns for plotting + name_base <- max(as.character(pars)) + # Use make.names() to ensure that the creation of the formula in facet_grid() + # works correctly below. + fname1 <- make.names(paste0(name_base, "1")) + fname2 <- make.names(paste0(name_base, "2")) + div_name <- make.names(paste0(name_base, "3")) + td_name <- make.names(paste0(name_base, "4")) + # It is okay to assign whatever is in divergent__ or max_td_hit__ because they + # will be NA if no_np. + x_plot[[div_name]] <- divergent__ + x_plot[[td_name]] <- max_td_hit__ + current_plot <- + ggplot(x_plot) + + facet_grid(rows=as.formula(paste(fname1, "~", fname2)), scales="free") + + labs(x=NULL, y=NULL) + # Generate all pairs of plots + for (nm1 in pars) { + for (nm2 in pars) { + current_mapping <- aes() + if (nm1 == nm2) { + range_nm1 <- range(x_plot[[nm1]]) + current_mapping$x <- as.name(nm1) + # Scale the y-value by the range of the observed data so that the height + # of the y-scales match between the diagonal and off-diagonal elements + current_mapping$y <- str2lang(sprintf("..ndensity..*%g", diff(range_nm1))) + current_plot <- + current_plot + + geom_diagonal( + # Create the facet-generating function to show all data + data=pairs_setfacet(fname1=fname1, fvalue1=nm1, fname2=fname2, fvalue2=nm2, select_rows=TRUE), + mapping=current_mapping, + # Shift the y-value up to the minimum of the observed data so that + # the origin of the y-scales match between the diagonal and + # off-diagonal elements + position=position_nudge(y=range_nm1[1]) + ) } else { - divs_j <- max_td_hit_j <- NULL - } - off_diag_args[["x"]] <- x_j - plots[[j]] <- do.call(plot_off_diagonal, off_diag_args) - - if (isTRUE(any(divs_j == 1))) { - divs_j_fac <- factor(as.logical(divs_j), - levels = c(FALSE, TRUE), - labels = c("NoDiv", "Div")) - plots[[j]] <- plots[[j]] + - geom_point( - aes_(color = divs_j_fac, size = divs_j_fac), + current_mapping$x <- as.name(nm2) + current_mapping$y <- as.name(nm1) + current_mapping_div <- current_mapping_td <- current_mapping + current_mapping_div$color <- current_mapping_div$size <- as.name(div_name) + current_mapping_td$color <- current_mapping_td$size <- as.name(td_name) + # Determine if the plot is in the lower or upper triangle + is_lower <- which(pars %in% nm1) < which(pars %in% nm2) + current_plot <- + current_plot + + geom_off_diagonal( + # Create the facet-generating function to show lower or upper data + data= + pairs_setfacet( + fname1=fname1, fvalue1=nm1, + fname2=fname2, fvalue2=nm2, + select_rows=xor(is_lower, mark) + ), + mapping=current_mapping, + show.legend=FALSE + ) + + geom_off_diagonal( + # Create the facet-generating function to show lower or upper data + # only for divergent transitions (select_rows will select zero rows + # if there are no divergent transitions or if no_np). + data= + pairs_setfacet( + fname1=fname1, fvalue1=nm1, + fname2=fname2, fvalue2=nm2, + select_rows=xor(is_lower, mark) & x_plot[[div_name]] %in% "Div" + ), + mapping=current_mapping_div, + show.legend=FALSE, shape = np_style$shape[["div"]], alpha = np_style$alpha[["div"]], na.rm = TRUE - ) - } - if (isTRUE(any(max_td_hit_j == 1))) { - max_td_hit_j_fac <- factor(max_td_hit_j, levels = c(FALSE, TRUE), - labels = c("NoHit", "Hit")) - plots[[j]] <- plots[[j]] + - geom_point( - aes_(color = max_td_hit_j_fac, size = max_td_hit_j_fac), + ) + + geom_off_diagonal( + # Create the facet-generating function to show lower or upper data + # only for max tree depth hits transitions (select_rows will select + # zero rows if there are no max tree depth hits transitions or if + # no_np). + data= + pairs_setfacet( + fname1=fname1, fvalue1=nm1, + fname2=fname2, fvalue2=nm2, + select_rows=xor(is_lower, mark) & x_plot[[td_name]] %in% "Hit" + ), + mapping=current_mapping_td, + show.legend=FALSE, shape = np_style$shape[["td"]], alpha = np_style$alpha[["td"]], na.rm = TRUE - ) + ) + + format_nuts_points(np_style) } - if (isTRUE(any(divs_j == 1)) || - isTRUE(any(max_td_hit_j == 1))) - plots[[j]] <- format_nuts_points(plots[[j]], np_style) } } - - plots <- lapply(plots, function(x) - x + xaxis_title(FALSE) + yaxis_title(FALSE)) - - bayesplot_grid(plots = plots, - legends = FALSE, - grid_args = grid_args, - save_gg_objects = save_gg_objects) + current_plot } @@ -719,13 +770,40 @@ pairs_condition <- function(chains = NULL, draws = NULL, nuts = NULL) { # internal for mcmc_pairs ------------------------------------------------- +#' Generate a function to set the faceting parameters and row selection for data +#' for `mcmc_pairs()` geoms +#' +#' @param fname1,fname2 facet column names +#' @param fvalue1,fvalue2 facet values +#' @param select_rows A logical vector of rows to keep (note, scalar `TRUE` will +#' keep all rows). +#' @return A function that takes in a data.frame and returns a data.frame with +#' columns `fname1` and `fname2` set to `fvalue1` and `fvalue2`. +pairs_setfacet <- function(fname1, fvalue1, fname2, fvalue2, select_rows) { + force(fname1) + force(fname2) + force(fvalue1) + force(fvalue2) + force(select_rows) + function(x) { + # The use of fvalue1 and fvalue2 is so that we can ensure that there is not + # a column name conflict + stopifnot(!any(c(fname1, fname2) %in% names(x))) + x[[fname1]] <- fvalue1 + x[[fname2]] <- fvalue2 + # Filter to just the rows of interest + x[select_rows, ] + } +} + #' Get plotting functions from user-specified #' `diag_fun` and `off_diag_fun` arguments #' +#' @param x User specified `diag_fun` or `off_diag_fun` argument to +#' `mcmc_pairs()` #' @noRd -#' @param x User specified `diag_fun` or `off_diag_fun` argument to `mcmc_pairs()` -pairs_plotfun <- function(x) { - fun <- paste0("mcmc_", x) +pairs_geomfun <- function(x) { + fun <- paste0("geom_mcmc_", x) utils::getFromNamespace(fun, "bayesplot") } @@ -739,56 +817,6 @@ unstack_to_matrix <- function(df, .form) { as.matrix(x) } -#' Check if off-diagonal plot is above or below the diagonal -#' -#' @noRd -#' @param j integer (index) -#' @param n Number of parameters (number of plots = `n^2`) -#' @return `TRUE` if below the diagonal, `FALSE` if above the diagonal -is_lower_tri <- function(j, n) { - idx <- array_idx_j(j, n) - lower_tri <- lower_tri_idx(n) - row_match_found(idx, lower_tri) -} - -#' Get array indices of the jth element in the plot matrix -#' -#' @noRd -#' @param j integer (index) -#' @param n number of parameters (number of plots = n^2) -#' @return rwo vector (1-row matrix) containing the array indices of the jth -#' element in the plot matrix -array_idx_j <- function(j, n) { - jj <- matrix(seq_len(n^2), nrow = n, byrow = TRUE)[j] - arrayInd(jj, .dim = c(n, n)) -} - -#' Get indices of lower triangular elements of a square matrix -#' @noRd -#' @param n number of rows (columns) in the square matrix -lower_tri_idx <- function(n) { - a <- rev(abs(sequence(seq.int(n - 1)) - n) + 1) - b <- rep.int(seq.int(n - 1), rev(seq.int(n - 1))) - cbind(row = a, col = b) -} - -#' Find which (if any) row in y is a match for x -#' @noRd -#' @param x a row vector (i.e., a matrix with 1 row) -#' @param y a matrix -#' @return either a row number in `y` or `NA` if no match -row_match_found <- function(x, y) { - stopifnot(is.matrix(x), is.matrix(y), nrow(x) == 1) - x <- as.data.frame(x) - y <- as.data.frame(y) - res <- match( - do.call(function(...) paste(..., sep=":::"), x), - do.call(function(...) paste(..., sep=":::"), y) - ) - isTRUE(!is.na(res) && length(res) == 1) -} - - #' Drop any constant or duplicate variables #' @noRd #' @param x 3-D array @@ -893,17 +921,132 @@ handle_condition <- function(x, condition=NULL, np=NULL, lp=NULL) { #' hitting max_treedepth #' #' @noRd -#' @param graph ggplot object #' @param np_args list of style arguments returned by `pairs_style_np()` -#' @return `graph`, updated -format_nuts_points <- function(graph, np_args) { - graph + - scale_color_manual( +#' @return a list of ggplot2 scales +format_nuts_points <- function(np_args) { + list( + ggplot2::scale_color_manual( values = set_names(c(NA, np_args$color[["div"]], NA, np_args$color[["td"]]), c("NoDiv", "Div", "NoHit", "Hit")) - ) + - scale_size_manual( + ), + ggplot2::scale_size_manual( values = set_names(c(0, rel(np_args$size[["div"]]), 0, rel(np_args$size[["td"]])), c("NoDiv", "Div", "NoHit", "Hit")) ) + ) +} + +# geoms for mcmc_pairs #### + +## geoms for mcmc_pairs diagonals #### + +#' Specialized geoms for mcmc plotting with default bayesplot settings +#' +#' @inheritParams geom_freqpoly +#' @param freq Used to set the mapping when mapping is not manually set +#' @param fill,color,size,binwidth,breaks,na.rm Passed to the ggplot2 geom function +#' @param ... Passed to the ggplot2 geom function +geom_mcmc_hist <- function(..., + mapping = set_hist_aes(freq), + np=NULL, + fill = NULL, + color = NULL, + size = NULL, + binwidth = NULL, + bins = NULL, + breaks = NULL, + na.rm = TRUE, + freq = TRUE) { + if (is.null(bins) & is.null(binwidth)) { + # silence ggplot2's "Pick better value with `binwidth`" message + bins <- 30 + } + ggplot2::geom_histogram( + mapping = mapping, + fill = fill %||% get_color("mid"), + color = color %||% get_color("mid_highlight"), + size = size %||% 0.25, + na.rm = na.rm, + binwidth = binwidth, + bins = bins, + breaks = breaks, + ... + ) +} + +#' @describeIn geom_mcmc_hist Density plot with settings for bayesplot by +#' default +geom_mcmc_dens <- function(..., + mapping=NULL, np=NULL, + by_chain=FALSE, + color_chains = FALSE) { + if (is.null(mapping)) { + mapping <- ggplot2::aes() + } + # Note that checking for multiple chains cannot happen within the geom + # function because it doesn't necessarily have access to the data. It will + # raise an error at the time of plotting if no `Chain` column is present in + # the data. + geom_args_extra <- list(...) + if (by_chain) { + mapping[["color"]] <- as.name("Chain") + mapping[["group"]] <- as.name("Chain") + } else { + if (!("fill" %in% names(geom_args_extra))) { + geom_args_extra[["fill"]] <- get_color("mid") + } + if (!any(c("color", "colour") %in% names(geom_args_extra))) { + geom_args_extra[["color"]] <- get_color("mid_highlight") + } + } + + geom_args <- append(list(mapping=mapping), geom_args_extra) + # Returning a list is allowed because ggplot2:::ggplot_add.list will add the + # ggproto items sequentially + list( + do.call(ggplot2::geom_density, args=geom_args), + dont_expand_x_axis() + ) +} + +## geoms for mcmc_pairs off-diagonals #### + +#' @describeIn geom_mcmc_hist Scatter plot with settings for bayesplot by +#' default +geom_mcmc_scatter <- function(..., + mapping, np=NULL, + size=2.5, shape=21, alpha=0.8, + color=get_color("dh"), fill=get_color("d")) { + geom_point( + ..., + mapping = mapping, + shape = shape, + color = color, + fill = fill, + size = size, + alpha = alpha + ) +} + +#' @describeIn geom_mcmc_hist Hex plot with settings for bayesplot by +#' default +geom_mcmc_hex <- function(..., + mapping = ggplot2::aes_(fill = ~ scales::rescale(..density..)), + np=NULL, + binwidth=NULL) { + # Returning a list is allowed because ggplot2:::ggplot_add.list will add the + # ggproto items sequentially + list( + ggplot2::geom_hex( + mapping = mapping, + binwidth = binwidth, + ... + ), + ggplot2::scale_fill_gradientn( + "Density", + colors = unlist(color_scheme_get()), + breaks = c(.1, .9), + labels = c("low", "high") + ) + ) } diff --git a/man/geom_mcmc_hist.Rd b/man/geom_mcmc_hist.Rd new file mode 100644 index 00000000..5cdcb0d9 --- /dev/null +++ b/man/geom_mcmc_hist.Rd @@ -0,0 +1,71 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mcmc-scatterplots.R +\name{geom_mcmc_hist} +\alias{geom_mcmc_hist} +\alias{geom_mcmc_dens} +\alias{geom_mcmc_scatter} +\alias{geom_mcmc_hex} +\title{Specialized geoms for mcmc plotting with default bayesplot settings} +\usage{ +geom_mcmc_hist( + ..., + mapping = set_hist_aes(freq), + np = NULL, + fill = NULL, + color = NULL, + size = NULL, + binwidth = NULL, + bins = NULL, + breaks = NULL, + na.rm = TRUE, + freq = TRUE +) + +geom_mcmc_dens( + ..., + mapping = NULL, + np = NULL, + by_chain = FALSE, + color_chains = FALSE +) + +geom_mcmc_scatter( + ..., + mapping, + np = NULL, + size = 2.5, + shape = 21, + alpha = 0.8, + color = get_color("dh"), + fill = get_color("d") +) + +geom_mcmc_hex( + ..., + mapping = ggplot2::aes_(fill = ~scales::rescale(..density..)), + np = NULL, + binwidth = NULL +) +} +\arguments{ +\item{...}{Passed to the ggplot2 geom function} + +\item{fill, color, size, binwidth, breaks, na.rm}{Passed to the ggplot2 geom function} + +\item{freq}{Used to set the mapping when mapping is not manually set} +} +\description{ +Specialized geoms for mcmc plotting with default bayesplot settings +} +\section{Functions}{ +\itemize{ +\item \code{geom_mcmc_dens}: Density plot with settings for bayesplot by +default + +\item \code{geom_mcmc_scatter}: Scatter plot with settings for bayesplot by +default + +\item \code{geom_mcmc_hex}: Hex plot with settings for bayesplot by +default +}} + diff --git a/man/geom_mcmc_pairs_histogram.Rd b/man/geom_mcmc_pairs_histogram.Rd new file mode 100644 index 00000000..c246faca --- /dev/null +++ b/man/geom_mcmc_pairs_histogram.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mcmc-geoms.R +\name{geom_mcmc_pairs_histogram} +\alias{geom_mcmc_pairs_histogram} +\title{Specialized geoms for mcmc plotting with default bayesplot settings} +\usage{ +geom_mcmc_pairs_histogram( + ..., + mapping = set_hist_aes(freq), + fill = NULL, + color = NULL, + size = NULL, + binwidth = NULL, + breaks = NULL, + na.rm = TRUE, + freq = TRUE +) +} +\arguments{ +\item{...}{Passed to the ggplot2 geom function} + +\item{fill, color, size, binwidth, breaks, na.rm}{Passed to the ggplot2 geom function} + +\item{freq}{Used to set the mapping when mapping is not manually set} +} +\description{ +Specialized geoms for mcmc plotting with default bayesplot settings +} diff --git a/man/pairs_setfacet.Rd b/man/pairs_setfacet.Rd new file mode 100644 index 00000000..c9d951e2 --- /dev/null +++ b/man/pairs_setfacet.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mcmc-scatterplots.R +\name{pairs_setfacet} +\alias{pairs_setfacet} +\title{Generate a function to set the faceting parameters and row selection for data +for \code{mcmc_pairs()} geoms} +\usage{ +pairs_setfacet(fname1, fvalue1, fname2, fvalue2, select_rows) +} +\arguments{ +\item{fname1, fname2}{facet column names} + +\item{fvalue1, fvalue2}{facet values} + +\item{select_rows}{A logical vector of rows to keep (note, scalar \code{TRUE} will +keep all rows).} +} +\value{ +A function that takes in a data.frame and returns a data.frame with +columns \code{fname1} and \code{fname2} set to \code{fvalue1} and \code{fvalue2}. +} +\description{ +Generate a function to set the faceting parameters and row selection for data +for \code{mcmc_pairs()} geoms +} diff --git a/tests/testthat/test-mcmc-scatter-and-parcoord.R b/tests/testthat/test-mcmc-scatter-and-parcoord.R index 5c236ca3..e751dee1 100644 --- a/tests/testthat/test-mcmc-scatter-and-parcoord.R +++ b/tests/testthat/test-mcmc-scatter-and-parcoord.R @@ -61,50 +61,85 @@ test_that("mcmc_hex throws error if number of parameters is not 2", { # mcmc_pairs ------------------------------------------------------------- test_that("mcmc_pairs returns a bayesplot_grid object", { g <- mcmc_pairs(arr, pars = c("(Intercept)", "sigma")) - expect_bayesplot_grid(g) + expect_s3_class(g, "gg") expect_equal(print(g), plot(g)) - expect_bayesplot_grid(mcmc_pairs(arr, pars = "sigma", regex_pars = "beta")) - expect_bayesplot_grid(mcmc_pairs(arr, regex_pars = "x:[1-3]", - transformations = "exp", - diag_fun = "dens", off_diag_fun = "hex", - diag_args = list(trim = FALSE), - off_diag_args = list(binwidth = c(0.5, 0.5)))) - - expect_bayesplot_grid(suppressWarnings(mcmc_pairs(arr1chain, regex_pars = "beta"))) - expect_bayesplot_grid(suppressWarnings(mcmc_pairs(mat, pars = c("(Intercept)", "sigma")))) - expect_bayesplot_grid(suppressWarnings(mcmc_pairs(dframe, pars = c("(Intercept)", "sigma")))) - expect_bayesplot_grid(mcmc_pairs(dframe_multiple_chains, regex_pars = "beta")) + expect_s3_class( + mcmc_pairs(arr, pars = "sigma", regex_pars = "beta"), + "gg" + ) + expect_s3_class( + mcmc_pairs(arr, regex_pars = "x:[1-3]", + transformations = "exp", + diag_fun = "dens", off_diag_fun = "hex", + diag_args = list(trim = FALSE), + off_diag_args = list(binwidth = c(0.5, 0.5))), + "gg" + ) + + expect_s3_class( + suppressWarnings(mcmc_pairs(arr1chain, regex_pars = "beta")), + "gg" + ) + expect_s3_class( + suppressWarnings(mcmc_pairs(mat, pars = c("(Intercept)", "sigma"))), + "gg" + ) + expect_s3_class( + suppressWarnings(mcmc_pairs(dframe, pars = c("(Intercept)", "sigma"))), + "gg" + ) + expect_s3_class( + mcmc_pairs(dframe_multiple_chains, regex_pars = "beta"), + "gg" + ) }) test_that("no mcmc_pairs non-NUTS 'condition's fail", { - expect_bayesplot_grid( + expect_s3_class( mcmc_pairs(arr, pars = "sigma", regex_pars = "beta", - condition = pairs_condition(chains = list(1, 2:4))) - ) - expect_bayesplot_grid( + condition = pairs_condition(chains = list(1, 2:4))), + "gg" + ) + expect_s3_class( mcmc_pairs(arr, pars = "sigma", regex_pars = "beta", - condition = pairs_condition(draws = rep(c(T,F), length.out = prod(dim(arr)[1:2])))) - ) - expect_bayesplot_grid( + condition = pairs_condition(draws = rep(c(T,F), length.out = prod(dim(arr)[1:2])))), + "gg" + ) + expect_s3_class( mcmc_pairs(arr, pars = "sigma", regex_pars = "beta", - condition = pairs_condition(draws = 1/3)) + condition = pairs_condition(draws = 1/3), + ), + "gg" ) - expect_bayesplot_grid( + expect_s3_class( mcmc_pairs(arr, pars = "sigma", regex_pars = "beta", - condition = pairs_condition(chains = c(1,3))) + condition = pairs_condition(chains = c(1,3))), + "gg" ) }) test_that("mcmc_pairs works with NUTS info", { skip_if_not_installed("rstanarm") - expect_bayesplot_grid(mcmc_pairs(post, pars = c("wt", "am", "sigma"), np = np)) - expect_bayesplot_grid(mcmc_pairs(post, pars = c("wt", "am"), - condition = pairs_condition(nuts="energy__"), np = np)) - expect_bayesplot_grid(mcmc_pairs(post, pars = c("wt", "am"), - condition = pairs_condition(nuts="divergent__"), np = np)) - expect_bayesplot_grid(mcmc_pairs(post, pars = c("wt", "am"), - condition = pairs_condition(nuts = "lp__"), lp=lp, np = np, - max_treedepth = 2)) + expect_s3_class( + mcmc_pairs(post, pars = c("wt", "am", "sigma"), np = np), + "gg" + ) + expect_s3_class( + mcmc_pairs(post, pars = c("wt", "am"), + condition = pairs_condition(nuts="energy__"), np = np), + "gg" + ) + expect_s3_class( + mcmc_pairs(post, pars = c("wt", "am"), + condition = pairs_condition(nuts="divergent__"), np = np), + "gg" + ) + expect_s3_class( + mcmc_pairs(post, pars = c("wt", "am"), + condition = pairs_condition(nuts = "lp__"), lp=lp, np = np, + max_treedepth = 2), + "gg" + ) p <- mcmc_pairs( post, @@ -116,7 +151,7 @@ test_that("mcmc_pairs works with NUTS info", { np_style = pairs_style_np(div_color = "firebrick", td_color = "dodgerblue", div_size = 2, td_size = 2), max_treedepth = with(np, max(Value[Parameter == "treedepth__"]) - 1) ) - expect_bayesplot_grid(p) + expect_s3_class(p, "gg") }) @@ -301,8 +336,6 @@ test_that("pairs_condition message if multiple args specified", { ) }) - - # mcmc_parcoord ----------------------------------------------------------- test_that("mcmc_parcoord returns a ggplot object", { expect_gg(mcmc_parcoord(arr, pars = c("(Intercept)", "sigma")))