Skip to content

Commit

Permalink
[BIG] feat: new allocator logic and outputs #600
Browse files Browse the repository at this point in the history
* Feat: all new one-pager for `robyn_allocator()` showing initial, bounded and less-bounded scenarios, using last month's worth of data by default. Relevant changes from previous versions: initial spend is now mean of date range selected, not non-zero mean anymore + deprecated "max_response_expected_spend" scenario + carryover information is now provided in the curves + inform user when budget is topped and can't be fully allocated + added mROAS / mCPA for better understanding of allocation.
* Feat: `robyn_response()` now requires date or date range for adstocking (last period by default) and accepts single or multiple values to return different use cases and scenarios. 
* Feat: new `transform_adstock()` exported wrapper function.
* Feat: added NRMSE validation on test set.
* Feat: added prophet monthly component.
* Fix: added correct solID for fixed hyperparameters (not 1_1_1).
* Recode: reduced the size of `xDecompVec` on `OutputCollect` to only pareto-front models.
* Recode: got rid of "ggcorrplot" and "rPref" package dependencies.
* Docs: added blueprint link to demo.R.

---------
Co-authored: @gufengzhou @laresbernardo
  • Loading branch information
laresbernardo authored Feb 28, 2023
1 parent 619765d commit d152ad3
Show file tree
Hide file tree
Showing 24 changed files with 1,423 additions and 815 deletions.
3 changes: 1 addition & 2 deletions R/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: Robyn
Type: Package
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
Version: 3.9.1.9000
Version: 3.10.0.9000
Authors@R: c(
person("Gufeng", "Zhou", , "[email protected]", c("aut")),
person("Leonel", "Sentana", , "[email protected]", c("aut")),
Expand All @@ -28,7 +28,6 @@ Imports:
patchwork,
prophet,
reticulate,
rPref,
stringr,
tidyr
Suggests:
Expand Down
4 changes: 2 additions & 2 deletions R/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export(robyn_train)
export(robyn_update)
export(robyn_write)
export(saturation_hill)
export(transform_adstock)
export(ts_validation)
import(ggplot2)
importFrom(doParallel,registerDoParallel)
Expand Down Expand Up @@ -91,6 +92,7 @@ importFrom(lares,clusterKmeans)
importFrom(lares,formatNum)
importFrom(lares,freqs)
importFrom(lares,glued)
importFrom(lares,num_abbr)
importFrom(lares,ohse)
importFrom(lares,removenacols)
importFrom(lares,scale_x_abbr)
Expand All @@ -114,8 +116,6 @@ importFrom(prophet,add_regressor)
importFrom(prophet,add_seasonality)
importFrom(prophet,fit.prophet)
importFrom(prophet,prophet)
importFrom(rPref,low)
importFrom(rPref,psel)
importFrom(reticulate,conda_create)
importFrom(reticulate,conda_install)
importFrom(reticulate,import)
Expand Down
696 changes: 440 additions & 256 deletions R/R/allocator.R

Large diffs are not rendered by default.

12 changes: 4 additions & 8 deletions R/R/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,14 @@ robyn_calibrate <- function(calibration_input,
## 1. Adstock
if (adstock == "geometric") {
theta <- hypParamSam[paste0(get_channels[l_chn], "_thetas")][[1]][[1]]
x_list <- adstock_geometric(x = m_calib, theta = theta)
} else if (adstock == "weibull_cdf") {
shape <- hypParamSam[paste0(get_channels[l_chn], "_shapes")][[1]][[1]]
scale <- hypParamSam[paste0(get_channels[l_chn], "_scales")][[1]][[1]]
x_list <- adstock_weibull(x = m_calib, shape = shape, scale = scale, windlen = length(m), type = "cdf")
} else if (adstock == "weibull_pdf") {
}
if (grepl("weibull", adstock)) {
shape <- hypParamSam[paste0(get_channels[l_chn], "_shapes")][[1]][[1]]
scale <- hypParamSam[paste0(get_channels[l_chn], "_scales")][[1]][[1]]
x_list <- adstock_weibull(x = m_calib, shape = shape, scale = scale, windlen = length(m), type = "pdf")
}
m_calib_total_adst <- dt_modAdstocked[calib_pos, get_channels[l_chn]][[1]]
x_list <- transform_adstock(m_calib, adstock, theta = theta, shape = shape, scale = scale)
m_calib_imme_adst <- x_list$x_decayed
m_calib_total_adst <- dt_modAdstocked[calib_pos, get_channels[l_chn]][[1]]
m_calib_hist_adst <- m_calib_total_adst - m_calib_imme_adst
# Adapt for weibull_pdf with lags
m_calib_imme_adst[m_calib_hist_adst < 0] <- m_calib_total_adst[m_calib_hist_adst < 0]
Expand Down
172 changes: 140 additions & 32 deletions R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

############# Auxiliary non-exported functions #############

opts_pnd <- c("positive", "negative", "default")
other_hyps <- c("lambda", "train_size")
hyps_name <- c("thetas", "shapes", "scales", "alphas", "gammas")
OPTS_PDN <- c("positive", "negative", "default")
HYPS_NAMES <- c("thetas", "shapes", "scales", "alphas", "gammas")
HYPS_OTHERS <- c("lambda", "train_size")
LEGACY_PARAMS <- c("cores", "iterations", "trials", "intercept_sign", "nevergrad_algo")

check_nas <- function(df) {
name <- deparse(substitute(df))
Expand Down Expand Up @@ -172,8 +173,8 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
if (is.null(prophet_signs)) {
prophet_signs <- rep("default", length(prophet_vars))
}
if (!all(prophet_signs %in% opts_pnd)) {
stop("Allowed values for 'prophet_signs' are: ", paste(opts_pnd, collapse = ", "))
if (!all(prophet_signs %in% OPTS_PDN)) {
stop("Allowed values for 'prophet_signs' are: ", paste(OPTS_PDN, collapse = ", "))
}
if (length(prophet_signs) != length(prophet_vars)) {
stop("'prophet_signs' must have same length as 'prophet_vars'")
Expand All @@ -185,8 +186,8 @@ check_prophet <- function(dt_holidays, prophet_country, prophet_vars, prophet_si
check_context <- function(dt_input, context_vars, context_signs) {
if (!is.null(context_vars)) {
if (is.null(context_signs)) context_signs <- rep("default", length(context_vars))
if (!all(context_signs %in% opts_pnd)) {
stop("Allowed values for 'context_signs' are: ", paste(opts_pnd, collapse = ", "))
if (!all(context_signs %in% OPTS_PDN)) {
stop("Allowed values for 'context_signs' are: ", paste(OPTS_PDN, collapse = ", "))
}
if (length(context_signs) != length(context_vars)) {
stop("Input 'context_signs' must have same length as 'context_vars'")
Expand Down Expand Up @@ -235,8 +236,8 @@ check_paidmedia <- function(dt_input, paid_media_vars, paid_media_signs, paid_me
if (is.null(paid_media_signs)) {
paid_media_signs <- rep("positive", mediaVarCount)
}
if (!all(paid_media_signs %in% opts_pnd)) {
stop("Allowed values for 'paid_media_signs' are: ", paste(opts_pnd, collapse = ", "))
if (!all(paid_media_signs %in% OPTS_PDN)) {
stop("Allowed values for 'paid_media_signs' are: ", paste(OPTS_PDN, collapse = ", "))
}
if (length(paid_media_signs) == 1) {
paid_media_signs <- rep(paid_media_signs, length(paid_media_vars))
Expand Down Expand Up @@ -281,8 +282,8 @@ check_organicvars <- function(dt_input, organic_vars, organic_signs) {
organic_signs <- rep("positive", length(organic_vars))
# message("'organic_signs' were not provided. Using 'positive'")
}
if (!all(organic_signs %in% opts_pnd)) {
stop("Allowed values for 'organic_signs' are: ", paste(opts_pnd, collapse = ", "))
if (!all(organic_signs %in% OPTS_PDN)) {
stop("Allowed values for 'organic_signs' are: ", paste(OPTS_PDN, collapse = ", "))
}
if (length(organic_signs) != length(organic_vars)) {
stop("Input 'organic_signs' must have same length as 'organic_vars'")
Expand Down Expand Up @@ -444,10 +445,10 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
ref_hyp_name_spend <- hyper_names(adstock, all_media = paid_media_spends)
ref_hyp_name_expo <- hyper_names(adstock, all_media = exposure_vars)
ref_hyp_name_org <- hyper_names(adstock, all_media = organic_vars)
ref_hyp_name_other <- get_hyp_names[get_hyp_names %in% other_hyps]
# Excluding lambda (first other_hyps) given its range is not customizable
ref_all_media <- sort(c(ref_hyp_name_spend, ref_hyp_name_org, other_hyps))
all_ref_names <- c(ref_hyp_name_spend, ref_hyp_name_expo, ref_hyp_name_org, other_hyps)
ref_hyp_name_other <- get_hyp_names[get_hyp_names %in% HYPS_OTHERS]
# Excluding lambda (first HYPS_OTHERS) given its range is not customizable
ref_all_media <- sort(c(ref_hyp_name_spend, ref_hyp_name_org, HYPS_OTHERS))
all_ref_names <- c(ref_hyp_name_spend, ref_hyp_name_expo, ref_hyp_name_org, HYPS_OTHERS)
all_ref_names <- all_ref_names[order(all_ref_names)]
if (!all(get_hyp_names %in% all_ref_names)) {
wrong_hyp_names <- get_hyp_names[which(!(get_hyp_names %in% all_ref_names))]
Expand Down Expand Up @@ -717,7 +718,7 @@ check_hyper_fixed <- function(InputCollect, dt_hyper_fixed, add_penalty_factor)
# Adstock hyper-parameters
hypParamSamName <- hyper_names(adstock = InputCollect$adstock, all_media = InputCollect$all_media)
# Add lambda and other hyper-parameters manually
hypParamSamName <- c(hypParamSamName, other_hyps)
hypParamSamName <- c(hypParamSamName, HYPS_OTHERS)
# Add penalty factor hyper-parameters names
if (add_penalty_factor) {
for_penalty <- names(select(InputCollect$dt_mod, -.data$ds, -.data$dep_var))
Expand Down Expand Up @@ -774,8 +775,7 @@ check_class <- function(x, object) {
}

check_allocator <- function(OutputCollect, select_model, paid_media_spends, scenario,
channel_constr_low, channel_constr_up,
expected_spend, expected_spend_days, constr_mode) {
channel_constr_low, channel_constr_up, constr_mode) {
dt_hyppar <- OutputCollect$resultHypParam[OutputCollect$resultHypParam$solID == select_model, ]
if (!(select_model %in% OutputCollect$allSolutions)) {
stop(
Expand All @@ -792,11 +792,10 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen
if (any(channel_constr_up > 5)) {
warning("Inputs 'channel_constr_up' > 5 might cause unrealistic allocation")
}
opts <- c("max_historical_response", "max_response_expected_spend")
opts <- "max_historical_response" # Deprecated: max_response_expected_spend
if (!(scenario %in% opts)) {
stop("Input 'scenario' must be one of: ", paste(opts, collapse = ", "))
}

if (length(channel_constr_low) != 1 && length(channel_constr_low) != length(paid_media_spends)) {
stop(paste(
"Input 'channel_constr_low' have to contain either only 1",
Expand All @@ -809,35 +808,144 @@ check_allocator <- function(OutputCollect, select_model, paid_media_spends, scen
"value or have same length as 'InputCollect$paid_media_spends':", length(paid_media_spends)
))
}

if ("max_response_expected_spend" %in% scenario) {
if (any(is.null(expected_spend), is.null(expected_spend_days))) {
stop("When scenario = 'max_response_expected_spend', expected_spend and expected_spend_days must be provided")
}
}
opts <- c("eq", "ineq")
if (!(constr_mode %in% opts)) {
stop("Input 'constr_mode' must be one of: ", paste(opts, collapse = ", "))
}
}

check_metric_value <- function(metric_value, media_metric) {
check_metric_type <- function(metric_name, paid_media_spends, paid_media_vars, exposure_vars, organic_vars) {
if (metric_name %in% paid_media_spends && length(metric_name) == 1) {
metric_type <- "spend"
} else if (metric_name %in% exposure_vars && length(metric_name) == 1) {
metric_type <- "exposure"
} else if (metric_name %in% organic_vars && length(metric_name) == 1) {
metric_type <- "organic"
} else {
stop(paste(
"Invalid 'metric_name' input. It must be any media variable from",
"paid_media_spends (spend), paid_media_vars (exposure),",
"or organic_vars (organic); NOT:", metric_name,
paste("\n- paid_media_spends:", v2t(paid_media_spends, quotes = FALSE)),
paste("\n- paid_media_vars:", v2t(paid_media_vars, quotes = FALSE)),
paste("\n- organic_vars:", v2t(organic_vars, quotes = FALSE))
))
}
return(metric_type)
}

check_metric_dates <- function(date_range = NULL, all_dates, dayInterval = NULL, quiet = FALSE, is_allocator = FALSE, ...) {
## default using latest 30 days / 4 weeks / 1 month for spend level
if (is.null(date_range)) {
if (is.null(dayInterval)) stop("Input 'date_range' or 'dayInterval' must be defined")
if (!is_allocator) {
date_range <- "last_1"
} else {
date_range <- paste0("last_", dplyr::case_when(
dayInterval == 1 ~ 30,
dayInterval == 7 ~ 4,
dayInterval >= 30 & dayInterval <= 31 ~ 1,
))
}
if (!quiet) message(sprintf("Automatically picked date_range = '%s'", date_range))
}
if (grepl("last|all", date_range[1])) {
## Using last_n as date_range range
if ("all" %in% date_range) date_range <- paste0("last_", length(all_dates))
get_n <- ifelse(grepl("_", date_range[1]), as.integer(gsub("last_", "", date_range)), 1)
date_range <- tail(all_dates, get_n)
date_range_loc <- which(all_dates %in% date_range)
date_range_updated <- all_dates[date_range_loc]
rg <- v2t(range(date_range_updated), sep = ":", quotes = FALSE)
} else {
## Using dates as date_range range
if (all(is.Date(as.Date(date_range, origin = "1970-01-01")))) {
date_range <- as.Date(date_range, origin = "1970-01-01")
if (length(date_range) == 1) {
## Using only 1 date
if (all(date_range %in% all_dates)) {
date_range_updated <- date_range
date_range_loc <- which(all_dates == date_range)
if (!quiet) message("Using ds '", date_range_updated, "' as the response period")
} else {
date_range_loc <- which.min(abs(date_range - all_dates))
date_range_updated <- all_dates[date_range_loc]
if (!quiet) warning("Input 'date_range' (", date_range, ") has no match. Picking closest date: ", date_range_updated)
}
} else if (length(date_range) == 2) {
## Using two dates as "from-to" date range
date_range_loc <- unlist(lapply(date_range, function(x) which.min(abs(x - all_dates))))
date_range_loc <- date_range_loc[1]:date_range_loc[2]
date_range_updated <- all_dates[date_range_loc]
if (!quiet & !all(date_range %in% date_range_updated)) {
warning(paste(
"At least one date in 'date_range' input do not match any date.",
"Picking closest dates for range:", paste(range(date_range_updated), collapse = ":")
))
}
rg <- v2t(range(date_range_updated), sep = ":", quotes = FALSE)
get_n <- length(date_range_loc)
} else {
## Manually inputting each date
date_range_updated <- date_range
if (all(date_range %in% all_dates)) {
date_range_loc <- which(all_dates %in% date_range_updated)
} else {
date_range_loc <- unlist(lapply(date_range_updated, function(x) which.min(abs(x - all_dates))))
rg <- v2t(range(date_range_updated), sep = ":", quotes = FALSE)
}
if (all(na.omit(date_range_loc - lag(date_range_loc)) == 1)) {
date_range_updated <- all_dates[date_range_loc]
if (!quiet) warning("At least one date in 'date_range' do not match ds. Picking closest date: ", date_range_updated)
} else {
stop("Input 'date_range' needs to have sequential dates")
}
}
} else {
stop("Input 'date_range' must have date format '2023-01-01' or use 'last_n'")
}
}
return(list(
date_range_updated = date_range_updated,
metric_loc = date_range_loc
))
}

check_metric_value <- function(metric_value, metric_name, all_values, metric_loc) {
get_n <- length(metric_loc)
if (any(is.nan(metric_value))) metric_value <- NULL
if (!is.null(metric_value)) {
if (!is.numeric(metric_value)) {
stop(sprintf(
"Input 'metric_value' for %s (%s) must be a numerical value\n", media_metric, toString(metric_value)
"Input 'metric_value' for %s (%s) must be a numerical value\n", metric_name, toString(metric_value)
))
}
if (sum(metric_value <= 0) > 0) {
if (any(metric_value < 0)) {
stop(sprintf(
"Input 'metric_value' for %s (%s) must be a positive value\n", media_metric, metric_value[metric_value <= 0]
"Input 'metric_value' for %s must be positive\n", metric_name
))
}
if (get_n > 1 & length(metric_value) == 1) {
metric_value_updated <- rep(metric_value / get_n, get_n)
# message(paste0("'metric_value'", metric_value, " splitting into ", get_n, " periods evenly"))
} else {
if (length(metric_value) != get_n) {
stop("robyn_response metric_value & date_range must have same length\n")
}
metric_value_updated <- metric_value
}
}
if (is.null(metric_value)) {
metric_value_updated <- all_values[metric_loc]
}
all_values_updated <- all_values
all_values_updated[metric_loc] <- metric_value_updated
return(list(
metric_value_updated = metric_value_updated,
all_values_updated = all_values_updated
))
}

LEGACY_PARAMS <- c("cores", "iterations", "trials", "intercept_sign", "nevergrad_algo")

check_legacy_input <- function(InputCollect,
cores = NULL, iterations = NULL, trials = NULL,
intercept_sign = NULL, nevergrad_algo = NULL) {
Expand Down
4 changes: 2 additions & 2 deletions R/R/exports.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ robyn_save <- function(InputCollect,
)

# Nice and tidy table format for hyper-parameters
regex <- paste(paste0("_", hyps_name), collapse = "|")
regex <- paste(paste0("_", HYPS_NAMES), collapse = "|")
hyps <- filter(OutputCollect$resultHypParam, .data$solID == select_model) %>%
select(contains(hyps_name)) %>%
select(contains(HYPS_NAMES)) %>%
tidyr::gather() %>%
tidyr::separate(.data$key,
into = c("channel", "none"),
Expand Down
3 changes: 1 addition & 2 deletions R/R/imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#' @importFrom ggridges geom_density_ridges geom_density_ridges_gradient
#' @importFrom glmnet cv.glmnet glmnet
#' @importFrom jsonlite fromJSON toJSON write_json read_json
#' @importFrom lares check_opts clusterKmeans formatNum freqs glued ohse removenacols
#' @importFrom lares check_opts clusterKmeans formatNum freqs glued num_abbr ohse removenacols
#' theme_lares `%>%` scale_x_abbr scale_x_percent scale_y_percent scale_y_abbr try_require v2t
#' @importFrom lubridate is.Date day floor_date
#' @importFrom minpack.lm nlsLM
Expand All @@ -38,7 +38,6 @@
#' @importFrom prophet add_regressor add_seasonality fit.prophet prophet
#' @importFrom reticulate tuple use_condaenv import conda_create conda_install py_module_available
#' virtualenv_create py_install use_virtualenv
#' @importFrom rPref low psel
#' @importFrom stats AIC BIC coef complete.cases dgamma dnorm end lm model.matrix na.omit
#' nls.control median qt sd predict pweibull dweibull quantile qunif reorder rnorm start setNames
#' @importFrom stringr str_count str_detect str_remove str_split str_which str_extract str_replace
Expand Down
10 changes: 5 additions & 5 deletions R/R/inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,12 @@ Adstock: {x$adstock}
hyper_names <- function(adstock, all_media) {
adstock <- check_adstock(adstock)
if (adstock == "geometric") {
local_name <- sort(apply(expand.grid(all_media, hyps_name[
grepl("thetas|alphas|gammas", hyps_name)
local_name <- sort(apply(expand.grid(all_media, HYPS_NAMES[
grepl("thetas|alphas|gammas", HYPS_NAMES)
]), 1, paste, collapse = "_"))
} else if (adstock %in% c("weibull_cdf", "weibull_pdf")) {
local_name <- sort(apply(expand.grid(all_media, hyps_name[
grepl("shapes|scales|alphas|gammas", hyps_name)
local_name <- sort(apply(expand.grid(all_media, HYPS_NAMES[
grepl("shapes|scales|alphas|gammas", HYPS_NAMES)
]), 1, paste, collapse = "_"))
}
return(local_name)
Expand Down Expand Up @@ -831,7 +831,7 @@ prophet_decomp <- function(dt_transform, dt_holidays,
)
}
mod <- fit.prophet(modelRecurrence, dt_regressors)
forecastRecurrence <- predict(mod, dt_regressors)
forecastRecurrence <- predict(mod, dt_regressors) # prophet::prophet_plot_components(modelRecurrence, forecastRecurrence)
}

these <- seq_along(unlist(recurrence[, 1]))
Expand Down
Loading

0 comments on commit d152ad3

Please sign in to comment.