Skip to content

Commit

Permalink
Merge 0ac09f9 into 3c7a1a9
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns authored Jan 26, 2024
2 parents 3c7a1a9 + 0ac09f9 commit eae8189
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 132 deletions.
16 changes: 9 additions & 7 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ CmdStanFit$set("public", name = "init", value = init)
#' @param seed (integer) The random seed to use when initializing the model.
#' @param verbose (logical) Whether to show verbose logging during compilation.
#' @param hessian (logical) Whether to expose the (experimental) hessian method.
#' @param force_recompile (logical) Whether to recompile cached model methods.
#'
#' @examples
#' \dontrun{
Expand All @@ -332,25 +333,26 @@ CmdStanFit$set("public", name = "init", value = init)
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
#' [hessian()]
#'
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE, force_recompile = FALSE) {
if (os_is_wsl()) {
stop("Additional model methods are not currently available with ",
"WSL CmdStan and will not be compiled",
call. = FALSE)
}
require_suggested_package("Rcpp")
if (length(private$model_methods_env_$hpp_code_) == 0) {
if (length(private$model_methods_env_$hpp_code_) == 0 && (
is.null(private$model_methods_env_$obj_file_) ||
!file.exists(private$model_methods_env_$obj_file_))) {
stop("Model methods cannot be used with a pre-compiled Stan executable, ",
"the model must be compiled again", call. = FALSE)
}
if (hessian) {
message("The hessian method relies on higher-order autodiff ",
"which is still experimental. Please report any compilation ",
"errors that you encounter")
warning("The hessian argument is deprecated and will be removed in a future release.\n",
"The hessian method is now exposed by default.")
}
message("Compiling additional model methods...")
if (is.null(private$model_methods_env_$model_ptr)) {
expose_model_methods(private$model_methods_env_, verbose, hessian)
expose_model_methods(private$model_methods_env_, verbose = verbose,
force_recompile = force_recompile)
}
initialize_model_pointer(private$model_methods_env_, self$data_file(), seed)
invisible(NULL)
Expand Down
2 changes: 2 additions & 0 deletions R/install.R
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ build_cmdstan <- function(dir,
clean_cmdstan <- function(dir = cmdstan_path(),
cores = getOption("mc.cores", 2),
quiet = FALSE) {
unlink(file.path(dir, "model_methods.o"))
unlink(file.path(dir, "model_methods.cpp"))
withr::with_path(
c(
toolchain_PATH_env_var(),
Expand Down
18 changes: 11 additions & 7 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ CmdStanModel <- R6::R6Class(
#' @param compile_model_methods (logical) Compile additional model methods
#' (`log_prob()`, `grad_log_prob()`, `constrain_variables()`,
#' `unconstrain_variables()`).
#' @param compile_hessian_method (logical) Should the (experimental) `hessian()` method be
#' be compiled with the model methods?
#' @param compile_hessian_method (logical) Deprecated and will be removed in a future release.
#' The hessian method is now compiled by default.
#' @param compile_standalone (logical) Should functions in the Stan model be
#' compiled for use in R? If `TRUE` the functions will be available via the
#' `functions` field in the compiled model object. This can also be done after
Expand Down Expand Up @@ -504,6 +504,10 @@ compile <- function(quiet = TRUE,
warning("'threads' is deprecated. Please use 'cpp_options = list(stan_threads = TRUE)' instead.")
cpp_options[["stan_threads"]] <- TRUE
}
if (isTRUE(compile_hessian_method)) {
warning("'compile_hessian_method' is deprecated. The hessian method is now compiled by default.")
compile_hessian_method <- FALSE
}

if (length(self$exe_file()) == 0) {
if (is.null(dir)) {
Expand Down Expand Up @@ -655,9 +659,10 @@ compile <- function(quiet = TRUE,
run_log <- wsl_compatible_run(
command = make_cmd(),
args = c(wsl_safe_path(tmp_exe),
cpp_options_to_compile_flags(cpp_options),
cpp_options_to_compile_flags(c(cpp_options, list("KEEP_OBJECT"="true"))),
stancflags_val),
wd = cmdstan_path(),
env = c("current", "CXXFLAGS" = "-fPIC"),
echo = !quiet || is_verbose_mode(),
echo_cmd = is_verbose_mode(),
spinner = quiet && rlang::is_interactive() && !identical(Sys.getenv("IN_PKGDOWN"), "true"),
Expand Down Expand Up @@ -708,6 +713,7 @@ compile <- function(quiet = TRUE,
file.remove(exe)
}
file.copy(tmp_exe, exe, overwrite = TRUE)
private$model_methods_env_$obj_file_ <- paste0(temp_file_no_ext, ".o")
if (os_is_wsl()) {
res <- processx::run(
command = "wsl",
Expand All @@ -726,11 +732,9 @@ compile <- function(quiet = TRUE,
private$precompile_stanc_options_ <- NULL
private$precompile_include_paths_ <- NULL

if(!dry_run) {
if (!dry_run) {
if (compile_model_methods) {
expose_model_methods(env = private$model_methods_env_,
verbose = !quiet,
hessian = compile_hessian_method)
expose_model_methods(private$model_methods_env_, verbose = !quiet)
}
}
invisible(self)
Expand Down
117 changes: 94 additions & 23 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -728,45 +728,116 @@ get_cmdstan_flags <- function(flag_name) {
paste(flags, collapse = " ")
}

rcpp_source_stan <- function(code, env, verbose = FALSE) {
with_cmdstan_flags <- function(expr, model_methods = FALSE) {
cxxflags <- get_cmdstan_flags("CXXFLAGS")
cmdstanr_includes <- system.file("include", package = "cmdstanr", mustWork = TRUE)
cmdstanr_includes <- paste0(" -I\"", cmdstanr_includes,"\"")
cmdstanr_includes <- paste0("-I", shQuote(cmdstanr_includes))

r_includes <- paste(
paste0("-I", shQuote(system.file("include", package = "Rcpp", mustWork = TRUE))),
paste0("-I", shQuote(R.home(component = "include")))
)

libs <- c("LDLIBS", "LIBSUNDIALS", "TBB_TARGETS", "LDFLAGS_TBB")
libs <- paste(sapply(libs, get_cmdstan_flags), collapse = " ")
if (.Platform$OS.type == "windows") {
if (os_is_windows()) {
libs <- paste(libs, "-fopenmp")
}
lib_paths <- c("/stan/lib/stan_math/lib/tbb/",
"/stan/lib/stan_math/lib/sundials_6.1.1/lib/")
withr::with_path(paste0(cmdstan_path(), lib_paths),
withr::with_makevars(
c(
USE_CXX14 = 1,
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
PKG_CXXFLAGS = paste0(cxxflags, cmdstanr_includes, collapse = " "),
PKG_LIBS = libs
),
Rcpp::sourceCpp(code = code, env = env, verbose = verbose)
new_makevars <- c(
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
PKG_CXXFLAGS = paste(cxxflags, cmdstanr_includes, r_includes, collapse = " "),
PKG_LIBS = libs
)
if (os_is_windows() && model_methods) {
new_makevars <- c(
new_makevars,
SHLIB_LD = paste0(rtools4x_toolchain_path(),"/gcc"),
LOCAL_CPPFLAGS = paste0("-I'",rtools4x_toolchain_path(),"/../include'"),
LOCAL_LIBS = paste0("-L'",rtools4x_toolchain_path(),"/../lib' -lstdc++"),
BINPREF = paste0(rtools4x_toolchain_path(), "/")
)
}
withr::with_path(
c(
paste0(cmdstan_path(), lib_paths),
toolchain_PATH_env_var()
),
withr::with_makevars(new_makevars, expr)
)
}

rcpp_source_stan <- function(code, env, verbose = FALSE) {
with_cmdstan_flags(Rcpp::sourceCpp(code = code, env = env, verbose = verbose))
invisible(NULL)
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))
initialize_method_functions <- function(env, so_name) {
env$model_ptr <-
function(...) { .Call("model_ptr_", ..., PACKAGE = so_name) }
env$log_prob <-
function(...) { .Call("log_prob_", ..., PACKAGE = so_name) }
env$grad_log_prob <-
function(...) { .Call("grad_log_prob_", ..., PACKAGE = so_name) }
env$hessian <-
function(...) { .Call("hessian_", ..., PACKAGE = so_name) }
env$get_num_upars <-
function(...) { .Call("get_num_upars_", ..., PACKAGE = so_name) }
env$get_param_metadata <-
function(...) { .Call("get_param_metadata_", ..., PACKAGE = so_name) }
env$unconstrain_variables <-
function(...) { .Call("unconstrain_variables_", ..., PACKAGE = so_name) }
env$constrain_variables <-
function(...) { .Call("constrain_variables_", ..., PACKAGE = so_name) }
env$unconstrained_param_names <-
function(...) { .Call("unconstrained_param_names_", ..., PACKAGE = so_name) }
env$constrained_param_names <-
function(...) { .Call("constrained_param_names_", ..., PACKAGE = so_name) }
}

if (hessian) {
code <- c("#include <stan/math/mix.hpp>",
code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
expose_model_methods <- function(env, force_recompile = FALSE, verbose = FALSE) {
precomp_methods_file <- file.path(cmdstan_path(), "model_methods.o")
if (file.exists(precomp_methods_file) && force_recompile) {
unlink(precomp_methods_file)
}
model_methods_cpp <- system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)
source_file <- paste0(strip_ext(precomp_methods_file), ".cpp")
file.copy(model_methods_cpp, source_file, overwrite = FALSE)

model_obj_file <- env$obj_file_
if (!file.exists(model_obj_file)) {
if (rlang::is_interactive()) {
message("Model object file not found, recompiling model...")
}
temp_hpp_file <- tempfile()
writeLines(env$hpp_code_, con = paste0(temp_hpp_file, ".cpp"))
model_obj_file <- paste0(temp_hpp_file, ".o")
}

if (!file.exists(precomp_methods_file) && rlang::is_interactive()) {
message("Compiling and caching additional model methods...")
}
if (rlang::is_interactive()) {
message("Linking precompiled model methods to model object file...")
}

methods_dll <- tempfile(fileext = .Platform$dynlib.ext)
with_cmdstan_flags(
processx::run(
command = file.path(R.home(component = "bin"), "R"),
args = c("CMD", "SHLIB", repair_path(model_obj_file), repair_path(precomp_methods_file),
"-o", repair_path(methods_dll)),
echo = verbose || is_verbose_mode(),
echo_cmd = is_verbose_mode(),
error_on_status = FALSE
),
model_methods = TRUE
)

code <- paste(code, collapse = "\n")
rcpp_source_stan(code, env, verbose)
env$methods_dll_info <- with_cmdstan_flags(dyn.load(methods_dll, local = TRUE, now = TRUE))
initialize_method_functions(env, strip_ext(basename(methods_dll)))
invisible(NULL)
}

Expand Down
41 changes: 0 additions & 41 deletions inst/include/hessian.cpp

This file was deleted.

Loading

0 comments on commit eae8189

Please sign in to comment.