Skip to content

Commit

Permalink
feat/fix: new intercept parameter for glmnet #722
Browse files Browse the repository at this point in the history
  • Loading branch information
laresbernardo committed May 8, 2023
1 parent 91cc279 commit ac1d864
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 12 deletions.
1 change: 0 additions & 1 deletion R/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ importFrom(foreach,getDoParWorkers)
importFrom(foreach,registerDoSEQ)
importFrom(ggridges,geom_density_ridges)
importFrom(ggridges,geom_density_ridges_gradient)
importFrom(glmnet,cv.glmnet)
importFrom(glmnet,glmnet)
importFrom(jsonlite,fromJSON)
importFrom(jsonlite,read_json)
Expand Down
2 changes: 1 addition & 1 deletion R/R/imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#' @importFrom foreach foreach %dopar% getDoParWorkers registerDoSEQ
#' @import ggplot2
#' @importFrom ggridges geom_density_ridges geom_density_ridges_gradient
#' @importFrom glmnet cv.glmnet glmnet
#' @importFrom glmnet glmnet
#' @importFrom jsonlite fromJSON toJSON write_json read_json
#' @importFrom lares check_opts clusterKmeans formatNum freqs glued num_abbr ohse removenacols
#' theme_lares `%>%` scale_x_abbr scale_x_percent scale_y_percent scale_y_abbr try_require v2t
Expand Down
26 changes: 17 additions & 9 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
#' \code{c("DE","TwoPointsDE", "OnePlusOne", "DoubleFastGADiscreteOnePlusOne",
#' "DiscreteOnePlusOne", "PortfolioDiscreteOnePlusOne", "NaiveTBPSA",
#' "cGA", "RandomSearch")}.
#' @param intercept Boolean. Should intercept(s) be fitted (default=TRUE) or
#' set to zero (FALSE).
#' @param intercept_sign Character. Choose one of "non_negative" (default) or
#' "unconstrained". By default, if intercept is negative, Robyn will drop intercept
#' and refit the model. Consider changing intercept_sign to "unconstrained" when
Expand Down Expand Up @@ -80,6 +82,7 @@ robyn_run <- function(InputCollect = NULL,
iterations = 2000,
rssd_zero_penalty = TRUE,
nevergrad_algo = "TwoPointsDE",
intercept = TRUE,
intercept_sign = "non_negative",
lambda_control = NULL,
...) {
Expand Down Expand Up @@ -151,7 +154,8 @@ robyn_run <- function(InputCollect = NULL,
OutputModels <- robyn_train(
InputCollect, hyper_collect,
cores = cores, iterations = iterations, trials = trials,
intercept_sign = intercept_sign, nevergrad_algo = nevergrad_algo,
intercept_sign = intercept_sign, intercept = intercept,
nevergrad_algo = nevergrad_algo,
dt_hyper_fixed = dt_hyper_fixed,
ts_validation = ts_validation,
add_penalty_factor = add_penalty_factor,
Expand All @@ -167,6 +171,7 @@ robyn_run <- function(InputCollect = NULL,
OutputModels$cores <- cores
OutputModels$iterations <- iterations
OutputModels$trials <- trials
OutputModels$intercept <- intercept
OutputModels$intercept_sign <- intercept_sign
OutputModels$nevergrad_algo <- nevergrad_algo
OutputModels$ts_validation <- ts_validation
Expand Down Expand Up @@ -227,6 +232,7 @@ print.robyn_models <- function(x, ...) {
{hypers}
Nevergrad Algo: {x$nevergrad_algo}
Intercept: {x$intercept}
Intercept sign: {x$intercept_sign}
Time-series validation: {x$ts_validation}
Penalty factor: {x$add_penalty_factor}
Expand Down Expand Up @@ -279,7 +285,8 @@ Pareto-front ({x$pareto_fronts}) All solutions ({nSols}): {paste(x$allSolutions,
#' @export
robyn_train <- function(InputCollect, hyper_collect,
cores, iterations, trials,
intercept_sign, nevergrad_algo,
intercept_sign, intercept,
nevergrad_algo,
dt_hyper_fixed = NULL,
ts_validation = TRUE,
add_penalty_factor = FALSE,
Expand All @@ -296,6 +303,7 @@ robyn_train <- function(InputCollect, hyper_collect,
iterations = iterations,
cores = cores,
nevergrad_algo = nevergrad_algo,
intercept = intercept,
intercept_sign = intercept_sign,
dt_hyper_fixed = dt_hyper_fixed,
ts_validation = ts_validation,
Expand Down Expand Up @@ -332,6 +340,7 @@ robyn_train <- function(InputCollect, hyper_collect,
iterations = iterations,
cores = cores,
nevergrad_algo = nevergrad_algo,
intercept = intercept,
intercept_sign = intercept_sign,
ts_validation = ts_validation,
add_penalty_factor = add_penalty_factor,
Expand Down Expand Up @@ -388,6 +397,7 @@ robyn_mmm <- function(InputCollect,
iterations,
cores,
nevergrad_algo,
intercept = TRUE,
intercept_sign,
ts_validation = TRUE,
add_penalty_factor = FALSE,
Expand All @@ -397,7 +407,7 @@ robyn_mmm <- function(InputCollect,
refresh = FALSE,
trial = 1L,
seed = 123L,
quiet = FALSE) {
quiet = FALSE, ...) {
if (reticulate::py_module_available("nevergrad")) {
ng <- reticulate::import("nevergrad", delay_load = TRUE)
if (is.integer(seed)) {
Expand Down Expand Up @@ -454,14 +464,9 @@ robyn_mmm <- function(InputCollect,
paid_media_signs <- InputCollect$paid_media_signs
prophet_signs <- InputCollect$prophet_signs
organic_signs <- InputCollect$organic_signs
all_media <- InputCollect$all_media
calibration_input <- InputCollect$calibration_input
optimizer_name <- nevergrad_algo
ts_validation <- ts_validation
add_penalty_factor <- add_penalty_factor
intercept_sign <- intercept_sign
i <- NULL # For parallel iterations (globalVar)
rssd_zero_penalty <- rssd_zero_penalty
}

################################################
Expand Down Expand Up @@ -675,6 +680,7 @@ robyn_mmm <- function(InputCollect,
lambda = lambda_scaled,
lower.limits = lower.limits,
upper.limits = upper.limits,
intercept = intercept,
intercept_sign = intercept_sign,
penalty.factor = penalty.factor,
...
Expand Down Expand Up @@ -1004,7 +1010,7 @@ model_decomp <- function(coefs, y_pred,
coefsOut$rn <- sapply(x_factor, function(x) str_replace(coefsOut$rn, paste0(x, ".*"), x))
}
rn_order <- names(xDecompOutAgg)
rn_order[rn_order=="intercept"] <- "(Intercept)"
rn_order[rn_order == "intercept"] <- "(Intercept)"
coefsOut <- coefsOut %>%
group_by(.data$rn) %>%
rename("coef" = 2) %>%
Expand Down Expand Up @@ -1034,6 +1040,7 @@ model_decomp <- function(coefs, y_pred,

model_refit <- function(x_train, y_train, x_val, y_val, x_test, y_test,
lambda, lower.limits, upper.limits,
intercept = TRUE,
intercept_sign = "non_negative",
penalty.factor = rep(1, ncol(y_train)),
...) {
Expand All @@ -1047,6 +1054,7 @@ model_refit <- function(x_train, y_train, x_val, y_val, x_test, y_test,
upper.limits = upper.limits,
type.measure = "mse",
penalty.factor = penalty.factor,
intercept = intercept,
...
) # coef(mod)

Expand Down
9 changes: 8 additions & 1 deletion R/man/robyn_mmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions R/man/robyn_run.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions R/man/robyn_train.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ac1d864

Please sign in to comment.