Skip to content

Commit

Permalink
implement lambda_control parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
gufengzhou committed Oct 11, 2021
1 parent ac719ba commit ef1b7aa
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
18 changes: 10 additions & 8 deletions R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -410,14 +410,16 @@ check_filedir <- function(plot_folder) {
return(plot_folder)
}

check_calibconstr <- function(calibration_constraint, iterations, trials) {
total_iters <- iterations * trials
if (calibration_constraint <0.01 | calibration_constraint > 0.1) {
calibration_constraint <- 0.1
message("calibration_constraint must be >=0.01 and <=0.1. Using default value 0.1")
} else if (total_iters * calibration_constraint < 500) {
warning("Calibration constraint set to be top ", calibration_constraint*100, "% calibrated models.",
" Only ", round(total_iters*calibration_constraint,0), " models left for pareto-optimal selection")
check_calibconstr <- function(calibration_constraint, iterations, trials, calibration_input) {
if (!is.null(calibration_input)) {
total_iters <- iterations * trials
if (calibration_constraint <0.01 | calibration_constraint > 0.1) {
calibration_constraint <- 0.1
message("calibration_constraint must be >=0.01 and <=0.1. Using default value 0.1")
} else if (total_iters * calibration_constraint < 500) {
warning("Calibration constraint set to be top ", calibration_constraint*100, "% calibrated models.",
" Only ", round(total_iters*calibration_constraint,0), " models left for pareto-optimal selection")
}
}
return(calibration_constraint)
}
17 changes: 12 additions & 5 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#' @param calibration_constraint Numeric. Default to 0.1 and allows 0.01-0.1. When
#' calibrating, 0.1 means top 10% calibrated models are used for pareto-optimal
#' selection. Lower \code{calibration_constraint} increases calibration accuracy.
#' @param lambda_control Numeric. From 0-1. Tunes ridge lambda between
#' lambda.min and lambda.1se
#' @param refresh Boolean. Set to \code{TRUE} when used in \code{robyn_refresh()}
#' @param ui Boolean. Save additional outputs for UI usage. List outcome.
#' @examples
Expand All @@ -41,6 +43,7 @@ robyn_run <- function(InputCollect,
pareto_fronts = 1,
plot_pareto = TRUE,
calibration_constraint = 0.1,
lambda_control = 1,
refresh = FALSE,
dt_hyper_fixed = NULL,
ui = FALSE) {
Expand Down Expand Up @@ -72,7 +75,7 @@ robyn_run <- function(InputCollect,
message("Rolling window moving forward: ", InputCollect$refresh_steps, " ", InputCollect$intervalType)
}

calibration_constraint <- check_calibconstr(calibration_constraint, InputCollect$iterations, InputCollect$trials)
calibration_constraint <- check_calibconstr(calibration_constraint, InputCollect$iterations, InputCollect$trials, InputCollect$calibration_input)

#####################################
#### Run robyn_mmm on set_trials
Expand Down Expand Up @@ -100,11 +103,11 @@ robyn_run <- function(InputCollect,
model_output_collect <- list()
model_output_collect[[1]] <- robyn_mmm(
hyper_collect = hyperparameters_fixed,
InputCollect = InputCollect
InputCollect = InputCollect,
# ,iterations = iterations
# ,cores = cores
# ,optimizer_name = InputCollect$nevergrad_algo
, lambda_fixed = dt_hyper_fixed$lambda
lambda_fixed = dt_hyper_fixed$lambda
)

model_output_collect[[1]]$trial <- 1
Expand Down Expand Up @@ -146,6 +149,7 @@ robyn_run <- function(InputCollect,
model_output <- robyn_mmm(
hyper_collect = InputCollect$hyperparameters,
InputCollect = InputCollect,
lambda_control = lambda_control,
refresh = refresh
)

Expand Down Expand Up @@ -800,6 +804,7 @@ robyn_mmm <- function(hyper_collect,
InputCollect,
iterations = InputCollect$iterations,
lambda.n = 100,
lambda_control = 1,
lambda_fixed = NULL,
refresh = FALSE) {
if (reticulate::py_module_available("nevergrad")) {
Expand Down Expand Up @@ -1121,13 +1126,15 @@ robyn_mmm <- function(hyper_collect,
) # plot(cvmod) coef(cvmod)
# head(predict(cvmod, newx=x_train, s="lambda.1se"))

lambda_range <- c(cvmod$lambda.min, cvmod$lambda.1se)
lambda <- lambda_range[1] + (lambda_range[2]-lambda_range[1]) * lambda_control

#####################################
#### refit ridge regression with selected lambda from x-validation

## if no lift calibration, refit using best lambda
if (hyper_fixed == FALSE) {
mod_out <- model_refit(x_train, y_train, lambda = cvmod$lambda.1se, lower.limits, upper.limits)
lambda <- cvmod$lambda.1se
mod_out <- model_refit(x_train, y_train, lambda = lambda, lower.limits, upper.limits)
} else {
mod_out <- model_refit(x_train, y_train, lambda = lambda_fixed[i], lower.limits, upper.limits)
lambda <- lambda_fixed[i]
Expand Down
4 changes: 4 additions & 0 deletions R/man/robyn_mmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions R/man/robyn_run.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ef1b7aa

Please sign in to comment.