From 48f590e7d0279c305c2dbf9f42c7128f2fdc546f Mon Sep 17 00:00:00 2001 From: jgabry Date: Mon, 30 Oct 2023 14:07:48 -0600 Subject: [PATCH] avoid error for 1-D unit_vector the unit_vector isn't actually used anywhere in the situation when it errors (when K isn't >1) so we just make it size 2 in that case to avoid Stan's error for 1-D unit vectors. fixes #603 --- src/stan_files/lm.stan | 3 ++- src/stan_files/polr.stan | 4 +++- tests/testthat/test_stan_lm.R | 7 +++++++ tests/testthat/test_stan_polr.R | 5 +++++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/stan_files/lm.stan b/src/stan_files/lm.stan index b2f5e10c..b929c525 100644 --- a/src/stan_files/lm.stan +++ b/src/stan_files/lm.stan @@ -55,7 +55,8 @@ transformed data { } parameters { // must not call with init="0" - array[K > 1 ? J : 0] unit_vector[K] u; // primitives for coefficients + // https://github.com/stan-dev/rstanarm/issues/603#issuecomment-1785928224 + array[K > 1 ? J : 0] unit_vector[K > 1 ? K : 2] u; // primitives for coefficients array[J * has_intercept] real z_alpha; // primitives for intercepts array[J] real 1 ? 0 : -1), upper=1> R2; // proportions of variance explained vector[J * (1 - prior_PD)] log_omega; // under/overfitting factors diff --git a/src/stan_files/polr.stan b/src/stan_files/polr.stan index 91a298df..e4268840 100644 --- a/src/stan_files/polr.stan +++ b/src/stan_files/polr.stan @@ -175,7 +175,9 @@ transformed data { } parameters { simplex[J] pi; - array[K > 1] unit_vector[K] u; + // avoid error by making unit_vector have 2 elements when K <= 1 + // https://github.com/stan-dev/rstanarm/issues/603#issuecomment-1785928224 + array[K > 1] unit_vector[K > 1 ? K : 2] u; real 1 ? 0 : -1), upper=1> R2; array[is_skewed] real alpha; } diff --git a/tests/testthat/test_stan_lm.R b/tests/testthat/test_stan_lm.R index 9d677ea3..49c95323 100644 --- a/tests/testthat/test_stan_lm.R +++ b/tests/testthat/test_stan_lm.R @@ -119,6 +119,13 @@ test_that("stan_lm doesn't break with vb algorithms", { expect_stanreg(fit2) }) +test_that("stan_lm works with 1 predictor", { + SW(fit <- stan_lm(mpg ~ wt, data = mtcars, + prior = R2(0.5, "mean"), refresh = 0, + seed = SEED)) + expect_stanreg(fit) +}) + test_that("stan_lm throws error if only intercept", { expect_error(stan_lm(mpg ~ 1, data = mtcars, prior = R2(location = 0.75)), regexp = "not suitable for estimating a mean") diff --git a/tests/testthat/test_stan_polr.R b/tests/testthat/test_stan_polr.R index 1f894793..89972bf9 100644 --- a/tests/testthat/test_stan_polr.R +++ b/tests/testthat/test_stan_polr.R @@ -56,6 +56,11 @@ test_that("stan_polr runs for esoph example", { expect_stanreg(fit2vb) }) +test_that("stan_polr runs with 1 predictor", { + esoph$x1 <- rnorm(nrow(esoph)) + expect_stanreg(stan_polr(tobgp ~ x1, data = esoph, prior = R2(0.5, "mean"))) +}) + test_that("stan_polr throws error if formula excludes intercept", { expect_error(stan_polr(tobgp ~ 0 + agegp + alcgp, data = esoph, method = "loglog", prior = R2(0.4, "median")),