Skip to content

Commit

Permalink
Update zi to covariate response
Browse files Browse the repository at this point in the history
Also creation of jsdmStanFamily class to support this, and associated changes to accessory functions and documentation
  • Loading branch information
fseaton committed Aug 21, 2024
1 parent f324279 commit 5159b70
Show file tree
Hide file tree
Showing 26 changed files with 682 additions and 121 deletions.
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,4 @@ Suggests:
rmarkdown,
ggplot2
Config/testthat/edition: 3
Config/testthat/parallel: true
VignetteBuilder: knitr
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(plot,jsdmStanFit)
S3method(posterior_linpred,jsdmStanFit)
S3method(posterior_predict,jsdmStanFit)
S3method(pp_check,jsdmStanFit)
S3method(print,jsdmStanFamily)
S3method(print,jsdmStanFit)
S3method(print,jsdmprior)
S3method(print,jsdmstan_model)
Expand Down Expand Up @@ -40,6 +41,7 @@ export(nuts_params)
export(ordiplot)
export(posterior_linpred)
export(posterior_predict)
export(posterior_zipred)
export(pp_check)
export(rgampois)
export(rhat)
Expand Down
104 changes: 81 additions & 23 deletions R/jsdm_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#' grouping)
#' @param beta_param The parameterisation of the environmental covariate effects, by
#' default \code{"cor"}. See details for further information.
#' @param zi_param For the zero-inflated families, whether the zero-inflation parameter
#' is a species-specific constant (default, \code{"constant"}), or varies by
#' environmental covariates (\code{"covariate"}).
#'
#' @return A character vector of Stan code, class "jsdmstan_model"
#' @export
Expand All @@ -37,14 +40,15 @@
#'
jsdm_stancode <- function(method, family, prior = jsdm_prior(),
log_lik = TRUE, site_intercept = "none",
beta_param = "cor") {
beta_param = "cor", zi_param = "constant") {
# checks
family <- match.arg(family, c("gaussian", "bernoulli", "poisson",
"neg_binomial","binomial","zi_poisson",
"zi_neg_binomial"))
method <- match.arg(method, c("gllvm", "mglmm"))
beta_param <- match.arg(beta_param, c("cor","unstruct"))
site_intercept <- match.arg(site_intercept, c("none","grouped","ungrouped"))
zi_param <- match.arg(zi_param, c("constant","covariate"))
if (class(prior)[1] != "jsdmprior") {
stop("Prior must be given as a jsdmprior object")
}
Expand All @@ -53,15 +57,15 @@ jsdm_stancode <- function(method, family, prior = jsdm_prior(),
scode <- .modelcode(
method = method, family = family,
phylo = FALSE, prior = prior, log_lik = log_lik, site_intercept = site_intercept,
beta_param = beta_param
beta_param = beta_param, zi_param = zi_param
)
class(scode) <- c("jsdmstan_model", "character")
return(scode)
}


.modelcode <- function(method, family, phylo, prior, log_lik, site_intercept,
beta_param) {
beta_param, zi_param) {
model_functions <- "
"
data <- paste(
Expand Down Expand Up @@ -101,7 +105,13 @@ ifelse(site_intercept == "grouped",
int<lower=0> ss[Sum_nonzero]; //species index for Y_nz
int<lower=0> nn[Sum_nonzero]; //site index for Y_nz
int<lower=0> sz[Sum_zero]; //species index for Y_z
int<lower=0> nz[Sum_zero]; //site index for Y_z",""))
int<lower=0> nz[Sum_zero]; //site index for Y_z",""),
ifelse(grepl("zi_", family) & zi_param == "covariate","
int<lower=1> zi_k; //number of covariates for env effects on zi
matrix[N, zi_k] zi_X; //environmental covariate matrix for zi","")
)


transformed_data <- ifelse(method == "gllvm", "
// Ensures identifiability of the model - no rotation of factors
int<lower=1> M;
Expand Down Expand Up @@ -147,11 +157,18 @@ ifelse(site_intercept == "grouped",
"neg_binomial" = "
real<lower=0> kappa[S]; // neg_binomial parameters",
"poisson" = "",
"zi_poisson" = "
"zi_poisson" = switch(zi_param,
"constant" = "
real<lower=0,upper=1> zi[S]; // zero-inflation parameter",
"covariate" = "
matrix[zi_k,S] zi_betas; //environmental effects for zi"),
"zi_neg_binomial" = switch(zi_param,
"constant" = "
real<lower=0> kappa[S]; // neg_binomial parameters
real<lower=0,upper=1> zi[S]; // zero-inflation parameter",
"zi_neg_binomial" = "
"covariate" = "
real<lower=0> kappa[S]; // neg_binomial parameters
real<lower=0,upper=1> zi[S]; // zero-inflation parameter"
matrix[zi_k,S] zi_betas; //environmental effects for zi")
)

pars <- paste(
Expand Down Expand Up @@ -235,22 +252,32 @@ ifelse(site_intercept == "grouped",
")
model <- paste("
matrix[N,S] mu;
", ifelse(grepl("zi_",family),"
", ifelse(grepl("zi_",family),paste0("
real mu_nz[Sum_nonzero];
real mu_z[Sum_zero];
int pos;
int neg;",""),
int neg;",switch(zi_param,"constant" = "",
"covariate" = "
real zi_nz[Sum_nonzero];
real zi_z[Sum_zero];")),""),
switch(method,
"gllvm" = gllvm_model,
"mglmm" = mglmm_model
),ifelse(grepl("zi_",family),"
),ifelse(grepl("zi_",family),paste0(ifelse(zi_param == "covariate", "
matrix[N,S] zi = zi_X * zi_betas;",""),"
for(i in 1:Sum_nonzero){
mu_nz[i] = mu[nn[i],ss[i]];
mu_nz[i] = mu[nn[i],ss[i]];",
switch(zi_param, "constant" = "",
"covariate" = "
zi_nz[i] = zi[nn[i],ss[i]];"),"
}
for(i in 1:Sum_zero){
mu_z[i] = mu[nz[i],sz[i]];
mu_z[i] = mu[nz[i],sz[i]];",
switch(zi_param, "constant" = "",
"covariate" = "
zi_z[i] = zi[nz[i],sz[i]];"),"
}
",""))
"),""))
model_priors <- paste(
ifelse(site_intercept %in% c("ungrouped","grouped"), paste("
// Site-level intercept priors
Expand Down Expand Up @@ -296,17 +323,24 @@ ifelse(site_intercept == "grouped",
"bern" = "",
"poisson" = "",
"binomial" = "",
"zi_poisson" = paste("
"zi_poisson" = switch(zi_param,"constant" = paste("
//zero-inflation parameter
zi ~ ", prior[["zi"]], ";
"),
"zi_neg_binomial" = paste("
"), "covariate" = paste("
//zero-inflation parameter
to_vector(zi_betas) ~ ", prior[["zi_betas"]], ";
")),
"zi_neg_binomial" = switch(zi_param, "constant" = paste("
//zero-inflation parameter
zi ~ ", prior[["zi"]], ";
kappa ~ ", prior[["kappa"]], ";
"), "covariate" = paste("
//zero-inflation parameter
to_vector(zi_betas) ~ ", prior[["zi_betas"]], ";
kappa ~ ", prior[["kappa"]], ";
")
)
)
))
model_pt2 <- if(!grepl("zi_", family)){ paste(
"
for(i in 1:N) Y[i,] ~ ",
Expand All @@ -323,13 +357,19 @@ ifelse(site_intercept == "grouped",
for(s in 1:S){
target
+= N_zero[s]
* log_sum_exp(log(zi[s]),
* log_sum_exp(",
switch(zi_param,"constant" = "log(zi[s]),
log1m(zi[s])
+",
"covariate" = "bernoulli_logit_lpmf(1 | segment(zi_z, neg, N_zero[s])),
bernoulli_logit_lpmf(0 | segment(zi_z, neg, N_zero[s]))
+"),
switch(family,
"zi_poisson" = "poisson_log_lpmf(0 | segment(mu_z, neg, N_zero[s])));",
"zi_neg_binomial" = "neg_binomial_2_log_lpmf(0 | segment(mu_z, neg, N_zero[s]), kappa[s]));"),"
target += N_nonzero[s] * log1m(zi[s]);
target += N_nonzero[s] * ",switch(zi_param,
"constant" = "log1m(zi[s]);",
"covariate" = "bernoulli_logit_lpmf(0 | segment(zi_nz, pos, N_nonzero[s]));"),"
target +=",
switch(family,
"zi_poisson" = "poisson_log_lpmf(segment(Y_nz,pos,N_nonzero[s]) |
Expand Down Expand Up @@ -362,7 +402,9 @@ ifelse(site_intercept == "grouped",
}", ""), ifelse(isTRUE(log_lik), paste(
"
{
matrix[N, S] linpred;", switch(site_intercept, "ungrouped" = paste("
matrix[N, S] linpred;",ifelse(grepl("zi", family) & zi_param == "covariate","
matrix[N,S] zi = zi_X * zi_betas;",""),
switch(site_intercept, "ungrouped" = paste("
linpred = rep_matrix(a_bar + a * sigma_a, S) + (X * betas) +",
switch(method,
"gllvm" = "((Lambda_uncor * sigma_L) * LV_uncor)'",
Expand Down Expand Up @@ -393,22 +435,38 @@ ifelse(site_intercept == "grouped",
"neg_binomial" = "log_lik[i, j] = neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);",
"poisson" = "log_lik[i, j] = poisson_log_lpmf(Y[i, j] | linpred[i, j]);",
"binomial" = "log_lik[i, j] = binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);",
"zi_poisson" = "if (Y[i,j] == 0){
"zi_poisson" = switch(zi_param,"constant" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]),
bernoulli_lpmf(0 |zi[j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]));
} else {
log_lik[i, j] = bernoulli_lpmf(0 | zi[j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]);
}",
"zi_neg_binomial" = "if (Y[i,j] == 0){
"covariate" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]),
bernoulli_logit_lpmf(0 |zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]));
} else {
log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]);
}"),
"zi_neg_binomial" = switch(zi_param,"constant" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_lpmf(1 | zi[j]),
bernoulli_lpmf(0 |zi[j])
+ neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]));
} else {
log_lik[i, j] = bernoulli_lpmf(0 | zi[j])
+ neg_binomial_2_log_lpmf(Y[i,j] | linpred[i,j], kappa[j]);
}"
}",
"covariate" = "if (Y[i,j] == 0){
log_lik[i, j] = log_sum_exp(bernoulli_logit_lpmf(1 | zi[i,j]),
bernoulli_logit_lpmf(0 |zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]));
} else {
log_lik[i, j] = bernoulli_logit_lpmf(0 | zi[i,j])
+ poisson_log_lpmf(Y[i,j] | linpred[i,j]);
}")
),"
}
}
Expand Down
66 changes: 66 additions & 0 deletions R/jsdmstan-families.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#' jsdmStanFamily class
#'
#' This is the jsdmStanFamily class, which occupies a slot within any
#' jsdmStanFit object.
#'
#' @name jsdmStanFamily
#'
#' @section Elements for \code{jsdmStanFamily} objects:
#' \describe{
#' \item{\code{family}}{
#' A length one character vector describing family used to fit object. Options
#' are \code{"gaussian"}, \code{"poisson"}, \code{"bernoulli"},
#' \code{"neg_binomial"}, \code{"binomial"}, \code{"zi_poisson"},
#' \code{"zi_neg_binomial"}, or \code{"multiple"}.
#' }
#' \item{\code{params}}{
#' A character vector that includes all the names of the family-specific parameters.
#' }
#' \item{\code{params_dataresp}}{
#' A character vector that includes any named family-specific parameters that are
#' modelled in response to data.
#' }
#' \item{\code{preds}}{
#' A character vector of the measured predictors included if family parameters
#' are modelled in response to data. If family parameters are not modelled in
#' response to data this is left empty.
#' }
#' \item{\code{data_list}}{
#' A list containing the original data used to fit the model
#' (empty when save_data is set to \code{FALSE} or family parameters are not
#' modelled in response to data).
#' }
#' }
#'
jsdmStanFamily_empty <- function(){
res <- list(family = character(),
params = character(),
params_dataresp= character(),
preds = character(),
data_list = list())
class(res) <- "jsdmStanFamily"
return(res)
}

# jsdmStanFamily methods

#' Print jsdmStanFamily object
#'
#' @param x A jsdmStanFamily object
#' @param ... Other arguments, not used at this stage.
#'
#' @export
print.jsdmStanFamily <- function(x, ...){
cat(paste("Family:", x$family, "\n",
ifelse(length(x$params)>0,
paste("With parameters:",
paste0(x$params, sep = ", "),"\n"),
"")))
if(length(x$params_dataresp)>0){
cat(paste("Family-specific parameter",
paste0(x$params_dataresp,sep=", "),
"is modelled in response to", length(x$preds),
"predictors. These are named:",
paste0(x$preds, sep = ", ")))
}
}
5 changes: 3 additions & 2 deletions R/jsdmstanfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' A length one character vector describing type of jSDM
#' }
#' \item{\code{family}}{
#' A character vector describing response family
#' A jsdmStanFamily object describing characteristics of family
#' }
#' \item{\code{species}}{
#' A character vector of the species names
Expand All @@ -35,7 +35,7 @@
jsdmStanFit_empty <- function() {
res <- list(
jsdm_type = "None",
family = character(),
family = jsdmStanFamily_empty(),
species = character(),
sites = character(),
preds = character(),
Expand Down Expand Up @@ -77,6 +77,7 @@ print.jsdmStanFit <- function(x, ...) {
" Number of species: ", length(x$species), "\n",
" Number of sites: ", length(x$sites), "\n",
" Number of predictors: ", length(x$preds), "\n",
print(x$family),
"\n",
"Model run on ", length(x$fit@stan_args), " chains with ",
x$fit@stan_args[[1]]$iter, " iterations per chain (",
Expand Down
3 changes: 2 additions & 1 deletion R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
#'
#' This function uses the \pkg{loo} package to compute PSIS-LOO CV, efficient
#' approximate leave-one-out (LOO) cross-validation for Bayesian models using Pareto
#' smoothed importance sampling (PSIS).
#' smoothed importance sampling (PSIS). This requires that the model was fit using
#' \code{log_lik = TRUE}.
#'
#' @param x The jsdmStanFit model object
#' @param ... Other arguments passed to the \code{\link[loo]{loo}} function
Expand Down
Loading

0 comments on commit 5159b70

Please sign in to comment.