Skip to content

Commit

Permalink
initial implementation of draws_dt without support for data.table sem…
Browse files Browse the repository at this point in the history
…antics
  • Loading branch information
mjskay committed Mar 8, 2024
1 parent d0deed3 commit 80daa13
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 3 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Suggests:
e1071 (>= 1.7-3),
dplyr,
tidyr,
data.table,
knitr,
ggplot2,
ggdist,
Expand Down
14 changes: 14 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

S3method("[",draws_array)
S3method("[",draws_df)
S3method("[",draws_dt)
S3method("[",draws_list)
S3method("[",draws_matrix)
S3method("[",draws_rvars)
Expand Down Expand Up @@ -57,11 +58,21 @@ S3method(as_draws_df,data.frame)
S3method(as_draws_df,default)
S3method(as_draws_df,draws_array)
S3method(as_draws_df,draws_df)
S3method(as_draws_df,draws_dt)
S3method(as_draws_df,draws_list)
S3method(as_draws_df,draws_matrix)
S3method(as_draws_df,draws_rvars)
S3method(as_draws_df,mcmc)
S3method(as_draws_df,mcmc.list)
S3method(as_draws_dt,data.frame)
S3method(as_draws_dt,default)
S3method(as_draws_dt,draws_array)
S3method(as_draws_dt,draws_df)
S3method(as_draws_dt,draws_dt)
S3method(as_draws_dt,draws_list)
S3method(as_draws_dt,draws_matrix)
S3method(as_draws_dt,draws_rvars)
S3method(as_draws_dt,mcmc)
S3method(as_draws_list,default)
S3method(as_draws_list,draws_array)
S3method(as_draws_list,draws_df)
Expand Down Expand Up @@ -404,6 +415,7 @@ export(Pr)
export(as_draws)
export(as_draws_array)
export(as_draws_df)
export(as_draws_dt)
export(as_draws_list)
export(as_draws_matrix)
export(as_draws_rvars)
Expand All @@ -426,6 +438,7 @@ export(dissent)
export(draw_ids)
export(draws_array)
export(draws_df)
export(draws_dt)
export(draws_list)
export(draws_matrix)
export(draws_of)
Expand All @@ -447,6 +460,7 @@ export(for_each_draw)
export(is_draws)
export(is_draws_array)
export(is_draws_df)
export(is_draws_dt)
export(is_draws_list)
export(is_draws_matrix)
export(is_draws_rvars)
Expand Down
2 changes: 2 additions & 0 deletions R/as_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ closest_draws_format <- function(x) {
out <- "matrix"
} else if (is_draws_array_like(x)) {
out <- "array"
} else if (is_draws_dt_like(x)) {
out <- "dt"
} else if (is_draws_df_like(x)) {
out <- "df"
} else if (is_draws_rvars_like(x)) {
Expand Down
14 changes: 11 additions & 3 deletions R/as_draws_df.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ as_draws_df.draws_df <- function(x, ...) {
x
}

#' @rdname draws_df
#' @export
as_draws_df.draws_dt <- function(x, ...) {
x <- as.data.frame(x)
class(x) <- class_draws_dt()
x
}

#' @rdname draws_df
#' @export
as_draws_df.draws_matrix <- function(x, ...) {
Expand Down Expand Up @@ -231,12 +239,12 @@ dplyr_reconstruct.draws_df <- function(data, template) {
data
}

# drop "draws_df" and "draws" classes if metadata columns were removed
# from the data frame
# drop "draws_dt", "draws_df", and "draws" classes if metadata columns were
# removed from the data frame
drop_draws_class_if_metadata_removed <- function(x, warn = TRUE) {
if (!all(reserved_df_variables() %in% names(x))) {
if (warn) warning_no_call("Dropping 'draws_df' class as required metadata was removed.")
class(x) <- setdiff(class(x), c("draws_df", "draws"))
class(x) <- setdiff(class(x), c("draws_dt", "draws_df", "draws"))
}
x
}
Expand Down
162 changes: 162 additions & 0 deletions R/as_draws_dt.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# set this to TRUE so that data.table:::cedta() will treat {posterior} as a
# data.table-aware package, then re-implement methods as necessary to use
# data.table-style indexing with `[`
.datatable.aware = FALSE

#' The `draws_dt` format
#'
#' @name draws_dt
#' @family formats
#'
#' @templateVar draws_format draws_dt
#' @templateVar base_class class(data.table::data.table())
#' @template draws_format-skeleton
#' @template args-format-nchains
#'
#' @details Objects of class `"draws_dt"` are [data.table][data.table::data.table]s.
#' They have one column per variable as well as additional metadata
#' columns `".iteration"`, `".chain"`, and `".draw"`. The difference between
#' the `".iteration"` and `".draw"` columns is that the former is relative to
#' the MCMC chain while the latter ignores the chain information and has all
#' unique values. See **Examples**.
#'
#' If a `data.table`-like object is supplied to `as_draws_dt` that contains
#' columns named `".iteration"` or `".chain"`, they will be treated as
#' iteration and chain indices, respectively. See **Examples**.
#'
#' @examples
#'
#' # the difference between iteration and draw is clearer when contrasting
#' # the head and tail of the data frame
#' print(head(x1), reserved = TRUE, max_variables = 2)
#' print(tail(x1), reserved = TRUE, max_variables = 2)
#'
#' # manually supply chain information
#' xnew <- data.table(mu = rnorm(10), .chain = rep(1:2, each = 5))
#' xnew <- as_draws_dt(xnew)
#' print(xnew)
#'
NULL


#' @rdname draws_dt
#' @export
as_draws_dt <- function(x, ...) {
UseMethod("as_draws_dt")
}

#' @rdname draws_dt
#' @export
as_draws_dt.default <- function(x, ...) {
x <- as_draws(x)
as_draws_dt(x, ...)
}

#' @rdname draws_dt
#' @export
as_draws_dt.data.frame <- function(x, ...) {
.as_draws_dt(x)
}

#' @rdname draws_dt
#' @export
as_draws_dt.draws_dt <- function(x, ...) {
x
}

#' @rdname draws_dt
#' @export
as_draws_dt.draws_df <- function(x, ...) {
class(x) <- class_draws_dt()
x
}

#' @rdname draws_dt
#' @export
as_draws_dt.draws_matrix <- function(x, ...) {
as_draws_dt(as_draws_df(x), ...)
}

#' @rdname draws_dt
#' @export
as_draws_dt.draws_array <- function(x, ...) {
as_draws_dt(as_draws_df(x), ...)
}

#' @rdname draws_dt
#' @export
as_draws_dt.draws_list <- function(x, ...) {
as_draws_dt(as_draws_df(x), ...)
}

#' @rdname draws_dt
#' @export
as_draws_dt.draws_rvars <- function(x, ...) {
as_draws_dt(as_draws_df(x), ...)
}

#' @rdname draws_dt
#' @export
as_draws_dt.mcmc <- function(x, ...) {
as_draws_dt(as_draws_matrix(x), ...)
}

#' @rdname draws_df
#' @export
as_draws_df.mcmc.list <- function(x, ...) {
as_draws_dt(as_draws_array(x), ...)
}

#' Convert any \R object into a `draws_dt` object
#' @param x An \R object.
#' @noRd
.as_draws_dt <- function(x) {
x <- .as_draws_df(x)
class(x) <- class_draws_dt()
x
}

#' @rdname draws_dt
#' @export
draws_dt <- function(..., .nchains = 1) {
as_draws_dt(draws_df(..., .nchains = .nchains))
}

class_draws_dt <- function() {
c("draws_dt", "draws_df", "draws", "data.table", "data.frame")
}

#' @rdname draws_dt
#' @export
is_draws_dt <- function(x) {
inherits(x, "draws_dt")
}

# is an object looking like a 'draws_dt' object?
is_draws_dt_like <- function(x) {
is.data.table(x)
}

#' @export
`[.draws_dt` <- function(x, ..., drop = FALSE, reserved = FALSE) {
reserved <- as_one_logical(reserved)
# data.table uses heuristics to pick if a symbol is evaluated in the calling
# context or in the data frame context; thus we have to reconstruct the
# calling expression (but without non-data.table arguments like `reserved`)
# and evaluate it in the calling context (data.table ignores `drop`)
subset_expr = substitute(data.table::`[.data.table`(x, ...))
out <- eval(subset_expr, envir = parent.frame())
if (reserved) {
reserved_vars <- all_reserved_variables(x)
reserved_vars <- setdiff(reserved_vars, names(out))
out[, reserved_vars] <- data.table::`[.data.table`(x, , reserved_vars)
} else {
out <- drop_draws_class_if_metadata_removed(out, warn = TRUE)
}
out
}

# create an empty draws_dt object
empty_draws_dt <- function(variables = character(0)) {
as_draws_dt(empty_draws_df(variables))
}

0 comments on commit 80daa13

Please sign in to comment.