-
-
Notifications
You must be signed in to change notification settings - Fork 193
/
Copy pathkfold.R
484 lines (469 loc) · 19.2 KB
/
kfold.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
#' K-Fold Cross-Validation
#'
#' Perform exact K-fold cross-validation by refitting the model \eqn{K}
#' times each leaving out one-\eqn{K}th of the original data.
#' Folds can be run in parallel using the \pkg{future} package.
#'
#' @aliases kfold
#'
#' @inheritParams loo.brmsfit
#' @param K The number of subsets of equal (if possible) size
#' into which the data will be partitioned for performing
#' \eqn{K}-fold cross-validation. The model is refit \code{K} times, each time
#' leaving out one of the \code{K} subsets. If \code{K} is equal to the total
#' number of observations in the data then \eqn{K}-fold cross-validation is
#' equivalent to exact leave-one-out cross-validation.
#' @param Ksub Optional number of subsets (of those subsets defined by \code{K})
#' to be evaluated. If \code{NULL} (the default), \eqn{K}-fold cross-validation
#' will be performed on all subsets. If \code{Ksub} is a single integer,
#' \code{Ksub} subsets (out of all \code{K} subsets) will be randomly chosen.
#' If \code{Ksub} consists of multiple integers or a one-dimensional array
#' (created via \code{as.array}) potentially of length one, the corresponding
#' subsets will be used. This argument is primarily useful if evaluation of
#' all subsets is infeasible for some reason.
#' @param folds Determines how the subsets are constructed.
#' Possible values are \code{NULL} (the default), \code{"stratified"},
#' \code{"grouped"}, or \code{"loo"}. May also be a vector of length
#' equal to the number of observations in the data. Alters the way
#' \code{group} is handled. More information is provided in the 'Details'
#' section.
#' @param group Optional name of a grouping variable or factor in the model.
#' What exactly is done with this variable depends on argument \code{folds}.
#' More information is provided in the 'Details' section.
#' @param joint Indicates which observations' log likelihoods shall be
#' considered jointly in the ELPD computation. If \code{"obs"} or \code{FALSE}
#' (the default), each observation is considered separately. This enables
#' comparability of \code{kfold} with \code{loo}. If \code{"fold"} or
#' \code{TRUE}, the joint log likelihoods per fold are used. If
#' \code{"group"}, the joint log likelihoods per group within folds are used
#' (only available if argument \code{group} is specified).
#' @param save_fits If \code{TRUE}, a component \code{fits} is added to
#' the returned object to store the cross-validated \code{brmsfit}
#' objects and the indices of the omitted observations for each fold.
#' Defaults to \code{FALSE}.
#' @param recompile Logical, indicating whether the Stan model should be
#' recompiled. This may be necessary if you are running \code{kfold} on
#' another machine than the one used to fit the model.
#' @param future_args A list of further arguments passed to
#' \code{\link[future:future]{future}} for additional control over parallel
#' execution if activated.
#' @param ... Further arguments passed to \code{\link{brm}}.
#'
#' @return \code{kfold} returns an object that has a similar structure as the
#' objects returned by the \code{loo} and \code{waic} methods and
#' can be used with the same post-processing functions.
#'
#' @details The \code{kfold} function performs exact \eqn{K}-fold
#' cross-validation. First the data are partitioned into \eqn{K} folds
#' (i.e. subsets) of equal (or as close to equal as possible) size by default.
#' Then the model is refit \eqn{K} times, each time leaving out one of the
#' \code{K} subsets. If \eqn{K} is equal to the total number of observations
#' in the data then \eqn{K}-fold cross-validation is equivalent to exact
#' leave-one-out cross-validation (to which \code{loo} is an efficient
#' approximation). The objects returned by \code{kfold} can be compared
#' via \code{loo_compare}, just like objects returned by \code{loo}.
#'
#' The subsets can be constructed in multiple different ways:
#' \itemize{
#' \item If both \code{folds} and \code{group} are \code{NULL}, the subsets
#' are randomly chosen so that they have equal (or as close to equal as
#' possible) size.
#' \item If \code{folds} is \code{NULL} but \code{group} is specified, the
#' data is split up into subsets, each time omitting all observations of one
#' of the factor levels, while ignoring argument \code{K}.
#' \item If \code{folds = "stratified"} the subsets are stratified after
#' \code{group} using \code{\link[loo:kfold-helpers]{loo::kfold_split_stratified}}.
#' \item If \code{folds = "grouped"} the subsets are split by
#' \code{group} using \code{\link[loo:kfold-helpers]{loo::kfold_split_grouped}}.
#' \item If \code{folds = "loo"} exact leave-one-out cross-validation
#' will be performed and \code{K} will be ignored. Further, if \code{group}
#' is specified, all observations corresponding to the factor level of the
#' currently predicted single value are omitted. Thus, in this case, the
#' predicted values are only a subset of the omitted ones.
#' \item If \code{folds} is a numeric vector, it must contain one element per
#' observation in the data. Each element of the vector is an integer in
#' \code{1:K} indicating to which of the \code{K} folds the corresponding
#' observation belongs. There are some convenience functions available in
#' the \pkg{loo} package that create integer vectors to use for this purpose
#' (see the Examples section below and also the
#' \link[loo:kfold-helpers]{kfold-helpers} page).
#' }
#'
#' When running \code{kfold} on a \code{brmsfit} created with the
#' \pkg{cmdstanr} backend in a different \R session, several recompilations
#' will be triggered because by default, \pkg{cmdstanr} writes the model
#' executable to a temporary directory. To avoid that, set option
#' \code{"cmdstanr_write_stan_file_dir"} to a nontemporary path of your choice
#' before creating the original \code{brmsfit} (see section 'Examples' below).
#'
#' @examples
#' \dontrun{
#' fit1 <- brm(count ~ zAge + zBase * Trt + (1|patient) + (1|obs),
#' data = epilepsy, family = poisson())
#' # throws warning about some pareto k estimates being too high
#' (loo1 <- loo(fit1))
#' # perform 10-fold cross validation
#' (kfold1 <- kfold(fit1, chains = 1))
#'
#' # use joint likelihoods per fold for ELPD evaluation
#' kfold(fit1, chains = 1, joint = "fold")
#'
#' # use the future package for parallelization of models
#' # that is to fit models belonging to different folds in parallel
#' library(future)
#' plan(multisession, workers = 4)
#' kfold(fit1, chains = 1)
#' plan(sequential)
#'
#' ## to avoid recompilations when running kfold() on a 'cmdstanr'-backend fit
#' ## in a fresh R session, set option 'cmdstanr_write_stan_file_dir' before
#' ## creating the initial 'brmsfit'
#' ## CAUTION: the following code creates some files in the current working
#' ## directory: two 'model_<hash>.stan' files, one 'model_<hash>(.exe)'
#' ## executable, and one 'fit_cmdstanr_<some_number>.rds' file
#' set.seed(7)
#' fname <- paste0("fit_cmdstanr_", sample.int(.Machine$integer.max, 1))
#' options(cmdstanr_write_stan_file_dir = getwd())
#' fit_cmdstanr <- brm(rate ~ conc + state, data = Puromycin,
#' backend = "cmdstanr", file = fname)
#'
#' # now restart the R session and run the following (after attaching 'brms')
#' set.seed(7)
#' fname <- paste0("fit_cmdstanr_", sample.int(.Machine$integer.max, 1))
#' fit_cmdstanr <- brm(rate ~ conc + state,
#' data = Puromycin,
#' backend = "cmdstanr",
#' file = fname)
#' kfold_cmdstanr <- kfold(fit_cmdstanr, K = 2)
#' }
#'
#' @seealso \code{\link{loo}}, \code{\link{reloo}}
#'
#' @importFrom loo kfold
#' @export kfold
#' @export
kfold.brmsfit <- function(x, ..., K = 10, Ksub = NULL, folds = NULL,
                          group = NULL, joint = FALSE, compare = TRUE,
                          resp = NULL, model_names = NULL, save_fits = FALSE,
                          recompile = NULL, future_args = list()) {
  args <- split_dots(x, ..., model_names = model_names)
  if (!"use_stored" %in% names(args)) {
    # 'use_stored' is only TRUE if 'args' holds nothing but 'models' and the
    # caller did not explicitly pass any argument that changes how folds are
    # built or evaluated; presumably compute_loolist uses this flag to decide
    # whether previously stored kfold results can be reused (verify there)
    further_arg_names <- c(
      "K", "Ksub", "folds", "group", "joint", "resp", "save_fits"
    )
    args$use_stored <- all(names(args) %in% "models") &&
      !any(further_arg_names %in% names(match.call()))
  }
  c(args) <- nlist(
    criterion = "kfold", K, Ksub, folds, group, joint,
    compare, resp, save_fits, recompile, future_args
  )
  do_call(compute_loolist, args)
}
# helper function to perform exact k-fold cross-validation
# @inheritParams kfold.brmsfit
# @param newdata optional data frame used instead of 'x$data'
# @param newdata2 optional list used instead of 'x$data2'
#   (validated via 'validate_data2')
# @param model_name ignored but included to avoid being passed to '...'
# @param ... further arguments; those whose names match arguments of
#   'log_lik' are passed to 'log_lik', all others are passed to 'update'
# @return a list of class c("kfold", "loo") with elements 'estimates' and
#   'pointwise' and attributes 'K', 'Ksub', 'group', 'folds', 'fold_type',
#   and 'joint'; if 'save_fits' is TRUE, also 'fits', 'data', and 'data2'
.kfold <- function(x, K, Ksub, folds, group, joint, save_fits,
                   newdata, resp, model_name, recompile = NULL,
                   future_args = list(), newdata2 = NULL, ...) {
  stopifnot(is.brmsfit(x), is.list(future_args))
  if (is.brmsfit_multiple(x)) {
    warn_brmsfit_multiple(x)
    class(x) <- "brmsfit"
  }
  if (is.null(newdata)) {
    newdata <- x$data
  } else {
    newdata <- as.data.frame(newdata)
  }
  if (is.null(newdata2)) {
    newdata2 <- x$data2
  } else {
    bterms <- brmsterms(x$formula)
    newdata2 <- validate_data2(newdata2, bterms)
  }
  N <- nrow(newdata)
  joint <- validate_joint(joint)
  # validate argument 'group'
  gvar <- NULL
  if (!is.null(group)) {
    valid_groups <- get_cat_vars(x)
    if (length(group) != 1L || !group %in% valid_groups) {
      stop2("Group '", group, "' is not a valid grouping factor. ",
            "Valid groups are: \n", collapse_comma(valid_groups))
    }
    gvar <- factor(get(group, newdata))
  }
  if (joint == "group" && is.null(group)) {
    # without 'group' there is no 'gvar' to combine log likelihoods over,
    # which would silently produce empty per-fold log likelihood matrices
    stop2("Argument 'group' is required for 'joint = \"group\"'.")
  }
  # validate argument 'folds'
  if (is.null(folds)) {
    if (is.null(group)) {
      fold_type <- "random"
      folds <- loo::kfold_split_random(K, N)
    } else {
      # one fold per level of the grouping factor
      fold_type <- "group"
      folds <- as.numeric(gvar)
      K <- length(levels(gvar))
      message("Setting 'K' to the number of levels of '", group, "' (", K, ")")
    }
  } else if (is.character(folds) && length(folds) == 1L) {
    opts <- c("loo", "stratified", "grouped")
    fold_type <- match.arg(folds, opts)
    req_group_opts <- c("stratified", "grouped")
    if (fold_type %in% req_group_opts && is.null(group)) {
      stop2("Argument 'group' is required for fold type '", fold_type, "'.")
    }
    if (fold_type == "loo") {
      folds <- seq_len(N)
      K <- N
      message("Setting 'K' to the number of observations (", K, ")")
    } else if (fold_type == "stratified") {
      folds <- loo::kfold_split_stratified(K, gvar)
    } else if (fold_type == "grouped") {
      folds <- loo::kfold_split_grouped(K, gvar)
    }
  } else {
    # user-supplied fold assignment; relabel to consecutive integers 1:K
    fold_type <- "custom"
    folds <- as.numeric(factor(folds))
    if (length(folds) != N) {
      stop2("If 'folds' is a vector, it must be of length N.")
    }
    K <- max(folds)
    message("Setting 'K' to the number of folds (", K, ")")
  }
  # validate argument 'Ksub'
  if (is.null(Ksub)) {
    Ksub <- seq_len(K)
  } else {
    # see issue #441 for reasons to check for arrays
    is_array_Ksub <- is.array(Ksub)
    Ksub <- as.integer(Ksub)
    # values that cannot be coerced to integers become NA; reject them here
    # instead of failing with an uninformative 'missing value' error
    if (anyNA(Ksub) || any(Ksub <= 0 | Ksub > K)) {
      stop2("'Ksub' must contain positive integers not larger than 'K'.")
    }
    if (length(Ksub) == 1L && !is_array_Ksub) {
      # a single (non-array) integer is the number of folds to sample
      Ksub <- sample(seq_len(K), Ksub)
    } else {
      Ksub <- unique(Ksub)
    }
    Ksub <- sort(Ksub)
  }
  # ensure that the model can be run in the current R session
  x <- recompile_model(x, recompile = recompile)
  # split dots for use in log_lik and update
  dots <- list(...)
  ll_arg_names <- arg_names("log_lik")
  ll_args <- dots[intersect(names(dots), ll_arg_names)]
  ll_args$allow_new_levels <- TRUE
  ll_args$sample_new_levels <-
    first_not_null(ll_args$sample_new_levels, "gaussian")
  ll_args$resp <- resp
  ll_args$combine <- TRUE
  up_args <- dots[setdiff(names(dots), ll_arg_names)]
  up_args$object <- x
  up_args$refresh <- 0
  # function to be run inside future::future
  .kfold_k <- function(k) {
    message("Fitting model ", k, " out of ", K)
    if (fold_type == "loo" && !is.null(group)) {
      # leave out every observation sharing the grouping level of the
      # single predicted observation 'k' ('folds' is just seq_len(N) here,
      # so it cannot be used to find those observations)
      omitted <- union(k, which(gvar == gvar[k]))
      predicted <- k
    } else {
      omitted <- predicted <- which(folds == k)
    }
    newdata_omitted <- newdata[-omitted, , drop = FALSE]
    up_args$newdata <- newdata_omitted
    up_args$data2 <- subset_data2(newdata2, -omitted)
    fit <- SW(do_call(update, up_args))
    ll_args$object <- fit
    ll_args$newdata <- newdata[predicted, , drop = FALSE]
    ll_args$newdata2 <- subset_data2(newdata2, predicted)
    lppds <- do_call(log_lik, ll_args)
    if (joint == "fold") {
      # compute the joint log score over all observations within a fold
      lppds <- rowSums(lppds)
      joint_obs <- 1
    } else if (joint == "group") {
      gvar_k <- gvar[predicted]
      unique_gvar_k <- unique(gvar_k)
      ngroups <- length(unique_gvar_k)
      lppds_marg <- matrix(nrow = nrow(lppds), ncol = ngroups)
      joint_obs <- rep(NA, length(predicted))
      for (j in seq_len(ngroups)) {
        sel_obs <- gvar_k == unique_gvar_k[j]
        lppds_marg[, j] <- rowSums(lppds[, sel_obs, drop = FALSE])
        # tells which observations' elpds were considered jointly
        joint_obs[sel_obs] <- j
      }
      lppds <- lppds_marg
    } else {
      joint_obs <- seq_along(predicted)
    }
    out <- nlist(lppds, omitted, predicted, joint_obs)
    if (save_fits) {
      out$fit <- fit
    }
    return(out)
  }
  # TODO: separate parallel and non-parallel code to enable better printing?
  future_args$X <- Ksub
  future_args$FUN <- .kfold_k
  future_args$future.seed <- TRUE
  res <- do_call("future_lapply", future_args, pkg = "future.apply")
  lppds <- pred_obs_list <- vector("list", length(Ksub))
  if (save_fits) {
    fits <- array(list(), dim = c(length(Ksub), 3))
    dimnames(fits) <- list(NULL, c("fit", "omitted", "predicted"))
  }
  for (i in seq_along(Ksub)) {
    if (save_fits) {
      fits[i, ] <- res[[i]][c("fit", "omitted", "predicted")]
    }
    pred_obs_list[[i]] <- res[[i]]$predicted
    lppds[[i]] <- res[[i]]$lppds
  }
  lppds <- do_call(cbind, lppds)
  elpds <- apply(lppds, 2, log_mean_exp)
  pred_obs <- unlist(pred_obs_list)
  if (joint == "obs") {
    # bring back elpds into the original observation order
    elpds <- elpds[order(pred_obs)]
  }
  # compute effective number of parameters
  ll_args$object <- x
  ll_args$newdata <- newdata
  ll_args$newdata2 <- newdata2
  pred_obs_sorted <- sort(pred_obs)
  if (length(Ksub) < K) {
    # select the correct subset of predicted observations in the original order
    ll_args$newdata <- ll_args$newdata[pred_obs_sorted, , drop = FALSE]
    ll_args$newdata2 <- subset_data2(ll_args$newdata2, pred_obs_sorted)
  }
  ll_full <- do_call(log_lik, ll_args)
  if (joint == "fold") {
    # compute the joint log score over all observations within a fold
    ll_full_marg <- matrix(nrow = nrow(ll_full), ncol = length(Ksub))
    for (i in seq_along(Ksub)) {
      sel_obs <- match(pred_obs_list[[i]], pred_obs_sorted)
      ll_full_marg[, i] <- rowSums(ll_full[, sel_obs, drop = FALSE])
    }
    ll_full <- ll_full_marg
  } else if (joint == "group") {
    # compute the joint log score over all observations per group within a fold
    ll_full_marg <- vector("list", length(Ksub))
    for (i in seq_along(Ksub)) {
      sel_obs <- match(pred_obs_list[[i]], pred_obs_sorted)
      joint_obs <- res[[i]]$joint_obs
      unique_joint_obs <- unique(joint_obs)
      njoint <- length(unique_joint_obs)
      ll_full_marg[[i]] <- matrix(nrow = nrow(ll_full), ncol = njoint)
      for (j in seq_len(njoint)) {
        sel_obs_j <- sel_obs[joint_obs == unique_joint_obs[j]]
        ll_full_marg[[i]][, j] <- rowSums(ll_full[, sel_obs_j, drop = FALSE])
      }
    }
    ll_full <- do_call(cbind, ll_full_marg)
  }
  lpds <- apply(ll_full, 2, log_mean_exp)
  ps <- lpds - elpds
  # put everything together in a loo object
  pointwise <- cbind(elpd_kfold = elpds, p_kfold = ps, kfoldic = -2 * elpds)
  est <- colSums(pointwise)
  se_est <- sqrt(nrow(pointwise) * apply(pointwise, 2, var))
  estimates <- cbind(Estimate = est, SE = se_est)
  rownames(estimates) <- colnames(pointwise)
  out <- nlist(estimates, pointwise)
  atts <- nlist(K, Ksub, group, folds, fold_type, joint)
  attributes(out)[names(atts)] <- atts
  if (save_fits) {
    out$fits <- fits
    out$data <- newdata
    out$data2 <- newdata2
  }
  structure(out, class = c("kfold", "loo"))
}
#' Predictions from K-Fold Cross-Validation
#'
#' Compute and evaluate out-of-fold predictions after performing K-fold
#' cross-validation via \code{\link{kfold}}. Each observation is predicted
#' by the model that was fitted without the fold it belongs to.
#'
#' @param x Object of class \code{'kfold'} computed by \code{\link{kfold}}.
#' For \code{kfold_predict} to work, the fitted model objects need to have
#' been stored via argument \code{save_fits} of \code{\link{kfold}}.
#' @param method Method used to obtain predictions. Can be set to
#' \code{"posterior_predict"} (the default), \code{"posterior_epred"},
#' or \code{"posterior_linpred"}. For more details, see the documentation
#' of the respective functions.
#' @inheritParams posterior_predict.brmsfit
#'
#' @return A \code{list} with two slots named \code{'y'} and \code{'yrep'}.
#' Slot \code{y} contains the vector of observed responses.
#' Slot \code{yrep} contains the matrix of predicted responses,
#' with rows being posterior draws and columns being observations.
#'
#' @seealso \code{\link{kfold}}
#'
#' @examples
#' \dontrun{
#' fit <- brm(count ~ zBase * Trt + (1|patient),
#' data = epilepsy, family = poisson())
#'
#' # perform k-fold cross validation
#' (kf <- kfold(fit, save_fits = TRUE, chains = 1))
#'
#' # define a loss function
#' rmse <- function(y, yrep) {
#' yrep_mean <- colMeans(yrep)
#' sqrt(mean((yrep_mean - y)^2))
#' }
#'
#' # predict responses and evaluate the loss
#' kfp <- kfold_predict(kf)
#' rmse(y = kfp$y, yrep = kfp$yrep)
#' }
#'
#' @export
kfold_predict <- function(x, method = "posterior_predict", resp = NULL, ...) {
  if (!inherits(x, "kfold")) {
    stop2("'x' must be a 'kfold' object.")
  }
  required_slots <- c("fits", "data")
  if (!all(required_slots %in% names(x))) {
    stop2(
      "Slots 'fits' and 'data' are required. ",
      "Please run kfold with 'save_fits = TRUE'."
    )
  }
  pred_fun <- get(validate_pp_method(method), mode = "function")
  fits <- x$fits
  first_fit <- fits[[1, "fit"]]
  resp <- validate_resp(resp, first_fit, multiple = FALSE)
  # columns are labelled by the (sorted) row indices of predicted observations
  obs_ids <- as.character(sort(unlist(fits[, "predicted"])))
  n_obs <- length(obs_ids)
  y <- setNames(rep(NA, n_obs), obs_ids)
  yrep <- matrix(
    NA, nrow = ndraws(first_fit), ncol = n_obs,
    dimnames = list(NULL, obs_ids)
  )
  # fill in the predictions of each fold's held-out observations
  for (k in seq_len(nrow(fits))) {
    rows_k <- fits[[k, "predicted"]]
    ids_k <- as.character(rows_k)
    data_k <- x$data[rows_k, , drop = FALSE]
    fit_k <- fits[[k, "fit"]]
    y[ids_k] <- get_y(fit_k, resp, newdata = data_k, ...)
    yrep[, ids_k] <- pred_fun(
      fit_k, newdata = data_k, resp = resp,
      allow_new_levels = TRUE, summary = FALSE, ...
    )
  }
  list(y = y, yrep = yrep)
}
# validate argument 'joint' in kfold
# @param joint a single logical or character value
# @return one of "obs", "fold", or "group" (partial matching allowed);
#   logical TRUE maps to "fold" and FALSE to "obs"
validate_joint <- function(joint) {
  # reject non-scalar input before any coercion is attempted
  if (length(joint) != 1L) {
    stop2("Argument 'joint' must be of length 1.")
  }
  # logical values are accepted for backwards compatibility with brms < 2.20.18
  if (is.logical(joint)) {
    joint <- str_if(as_one_logical(joint), "fold", "obs")
  }
  joint <- as_one_character(joint)
  choices <- c("obs", "fold", "group")
  match.arg(joint, choices)
}