Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bl01: refresh fixes and improvements + others #969

Merged
merged 29 commits into main from bl01
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
21672d8
fix: non-existing folder when refreshing
laresbernardo Apr 23, 2024
9a2c730
fix: missing tail / case for refresh files
laresbernardo Apr 23, 2024
dda8be9
feat: display refresh plot_folder in msg
laresbernardo Apr 23, 2024
622162b
fix: when refresh's cluster = FALSE issue
laresbernardo Apr 24, 2024
6a3bf8f
fix: when refresh's cluster = FALSE issue p2
laresbernardo Apr 24, 2024
8766dd7
test: clean Windows parallel code
laresbernardo Apr 25, 2024
eb78604
fix: changed OutputModels for OutputCollect
laresbernardo Apr 25, 2024
919bbe9
fix: CSV exports in refresh into correct dir
laresbernardo Apr 26, 2024
d252591
docs: preparing for CRAN version v3.11.0
laresbernardo Apr 29, 2024
5158e1a
fix: ts_validation on refresh #960
laresbernardo Apr 30, 2024
c0fba7d
fix: export refresh selected model, deal with non-existing json plot_…
laresbernardo Apr 30, 2024
b0e4b29
fix: recreate only InputCollect when no model available (RobynModel-m…
laresbernardo May 2, 2024
fa18311
fix: correct check for penalties hyps when recreating a model #960
laresbernardo May 3, 2024
b539c02
fix: check for penalties only when penalties hyps are present
laresbernardo May 3, 2024
7de5b66
fix: check for penalties only when penalties hyps are present p2
laresbernardo May 3, 2024
950a904
fix: be more specific on Robyn + init/rf folders for chains
laresbernardo May 3, 2024
241dee1
fix: window_end for refresh should be inherited from json, not the data
laresbernardo May 3, 2024
ddc07d5
fix: warning - row names were found from a short variable and have be…
laresbernardo May 3, 2024
e875086
fix: scale of pBarRF improved + cap when no chain found
laresbernardo May 3, 2024
78bc094
fix: scale of pBarRF improved + cap when no chain found p2
laresbernardo May 3, 2024
a4eefd3
feat: enable listInit param on refresh plots json
laresbernardo May 3, 2024
4bf3256
feat: new add_data robyn_write() param + export original model when r…
laresbernardo May 4, 2024
5ef815a
fix: avoid passing data twice
laresbernardo May 4, 2024
4692b5b
fix: improved chain logic
laresbernardo May 4, 2024
a8199a6
fix: improved chain logic p2
laresbernardo May 4, 2024
ba4515c
fix: improved chain logic p3
laresbernardo May 4, 2024
c4a0f53
v3.10.7
laresbernardo May 6, 2024
4b0d2e2
Merge branch 'main' into bl01
laresbernardo May 7, 2024
6130be8
fix: export report_decomposition.png in plot_folder + length(ids) <> …
laresbernardo May 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion R/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: Robyn
Type: Package
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
Version: 3.10.6.9006
Version: 3.10.7
Authors@R: c(
person("Gufeng", "Zhou", , "[email protected]", c("cre","aut")),
person("Bernardo", "Lares", , "[email protected]", c("aut")),
Expand Down
24 changes: 14 additions & 10 deletions R/R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,8 @@ check_adstock <- function(adstock) {

check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
paid_media_spends = NULL, organic_vars = NULL,
exposure_vars = NULL) {
exposure_vars = NULL, prophet_vars = NULL,
contextual_vars = NULL) {
if (is.null(hyperparameters)) {
message(paste(
"Input 'hyperparameters' not provided yet. To include them, run",
Expand All @@ -495,16 +496,23 @@ check_hyperparameters <- function(hyperparameters = NULL, adstock = NULL,
ref_all_media <- sort(c(ref_hyp_name_spend, ref_hyp_name_org, HYPS_OTHERS))
all_ref_names <- c(ref_hyp_name_spend, ref_hyp_name_expo, ref_hyp_name_org, HYPS_OTHERS)
all_ref_names <- all_ref_names[order(all_ref_names)]
rm_penalty <- !grepl("_penalty$", get_hyp_names)
if (!all(get_hyp_names[rm_penalty] %in% all_ref_names)) {
# Adding penalty variations to the dictionary
if (any(grepl("_penalty", paste0(get_hyp_names)))) {
ref_hyp_name_penalties <- paste0(
c(paid_media_spends, organic_vars, prophet_vars, contextual_vars), "_penalty")
all_ref_names <- c(all_ref_names, ref_hyp_name_penalties)
} else {
ref_hyp_name_penalties <- NULL
}
if (!all(get_hyp_names %in% all_ref_names)) {
wrong_hyp_names <- get_hyp_names[which(!(get_hyp_names %in% all_ref_names))]
stop(
"Input 'hyperparameters' contains following wrong names: ",
paste(wrong_hyp_names, collapse = ", ")
)
}
total <- length(get_hyp_names[rm_penalty])
total_in <- length(c(ref_hyp_name_spend, ref_hyp_name_org, ref_hyp_name_other))
total <- length(get_hyp_names)
total_in <- length(c(ref_hyp_name_spend, ref_hyp_name_org, ref_hyp_name_penalties, ref_hyp_name_other))
if (total != total_in) {
stop(sprintf(
paste(
Expand Down Expand Up @@ -828,11 +836,7 @@ check_init_msg <- function(InputCollect, cores) {
if (cores == 1) {
message(paste(base, "with no parallel computation"))
} else {
if (check_parallel()) {
message(paste(base, "on", cores, "cores"))
} else {
message(paste(base, "on 1 core (Windows fallback)"))
}
message(paste(base, "on", cores, "cores"))
}
}

Expand Down
3 changes: 2 additions & 1 deletion R/R/inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ robyn_inputs <- function(dt_input = NULL,

## Check hyperparameters
hyperparameters <- check_hyperparameters(
hyperparameters, adstock, paid_media_spends, organic_vars, exposure_vars
hyperparameters, adstock, paid_media_spends, organic_vars,
exposure_vars, prophet_vars, context_vars
)
InputCollect <- robyn_engineering(InputCollect, ...)
}
Expand Down
67 changes: 47 additions & 20 deletions R/R/json.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
#' @param InputCollect \code{robyn_inputs()} output.
#' @param select_model Character. Which model ID do you want to export
#' into the JSON file?
#' @param add_data Boolean. Include raw dataset. Useful to recreate models
#' with a single file containing all the required information (no need of CSV).
#' @param dir Character. Existing directory to export JSON file to.
#' @param pareto_df Dataframe. Save all pareto solutions to json file.
#' @param ... Additional parameters to export into a custom Extras element.
#' If you wish to include the data to recreate a model, add
#' \code{raw_data = InputCollect$dt_input}.
#' @examples
#' \dontrun{
#' InputCollectJSON <- robyn_inputs(
Expand All @@ -37,6 +37,7 @@ robyn_write <- function(InputCollect,
OutputCollect = NULL,
select_model = NULL,
dir = OutputCollect$plot_folder,
add_data = TRUE,
export = TRUE,
quiet = FALSE,
pareto_df = NULL,
Expand Down Expand Up @@ -133,8 +134,12 @@ robyn_write <- function(InputCollect,
select_model <- "inputs"
}

if (length(list(...)) > 0) {
ret[["Extras"]] <- list(...)
extras <- list(...)
if (isTRUE(add_data) & !"raw_data" %in% names(extras)) {
extras[["raw_data"]] <- InputCollect$dt_input
}
if (length(extras) > 0) {
ret[["Extras"]] <- extras
}

if (!dir.exists(dir) & export) dir.create(dir, recursive = TRUE)
Expand All @@ -143,7 +148,7 @@ robyn_write <- function(InputCollect,
class(ret) <- c("robyn_write", class(ret))
attr(ret, "json_file") <- filename
if (export) {
if (!quiet) message(sprintf(">> Exported model %s as %s", select_model, filename))
if (!quiet) message(sprintf(">> Exported %s as %s", select_model, filename))
if (!is.null(pareto_df)) {
if (!all(c("solID", "cluster") %in% names(pareto_df))) {
warning(paste(
Expand Down Expand Up @@ -190,9 +195,8 @@ print.robyn_write <- function(x, ...) {
x$ExportedModel$performance$metric, signif(x$ExportedModel$performance$performance, 4)), ""),
errors = paste(
sprintf(
"Adj.R2 (%s): %s",
ifelse(!val, "train", "test"),
ifelse(!val, signif(errors$rsq_train, 4), signif(errors$rsq_test, 4))
"Adj.R2 (train): %s",
signif(errors$rsq_train, 4)
),
"| NRMSE =", signif(errors$nrmse, 4),
"| DECOMP.RSSD =", signif(errors$decomp.rssd, 4),
Expand Down Expand Up @@ -322,21 +326,23 @@ Adstock: {a$adstock}
#' @export
robyn_recreate <- function(json_file, quiet = FALSE, ...) {
json <- robyn_read(json_file, quiet = TRUE)
message(">>> Recreating model ", json$ExportedModel$select_model)
message(">>> Recreating ", json$ExportedModel$select_model)
args <- list(...)
if (!"InputCollect" %in% names(args)) {
InputCollect <- robyn_inputs(
json_file = json_file,
quiet = quiet,
...
)
OutputCollect <- robyn_run(
InputCollect = InputCollect,
json_file = json_file,
export = FALSE,
quiet = quiet,
...
)
if (!is.null(json$ExportedModel$select_model)) {
OutputCollect <- robyn_run(
InputCollect = InputCollect,
json_file = json_file,
export = FALSE,
quiet = quiet,
...
)
} else OutputCollect <- NULL
} else {
# Use case: skip feature engineering when InputCollect is provided
InputCollect <- args[["InputCollect"]]
Expand All @@ -360,22 +366,43 @@ robyn_chain <- function(json_file) {
ids <- c(json_data$InputCollect$refreshChain, json_data$ExportedModel$select_model)
plot_folder <- json_data$ExportedModel$plot_folder
temp <- str_split(plot_folder, "/")[[1]]
chain <- temp[startsWith(temp, "Robyn_")]
chain <- temp[startsWith(temp, "Robyn_") & grepl("_init+$|_rf[0-9]+$", temp)]
if (length(chain) == 0) chain <- tail(temp[temp != ""], 1)
avlb <- NULL
if (length(ids) != length(chain)) {
temp <- list.files(plot_folder)
mods <- unique(temp[
(startsWith(temp, "RobynModel") | grepl("\\.json+$", temp)) &
grepl("^[^_]*_[^_]*_[^_]*$", temp)])
avlb <- gsub("RobynModel-|\\.json", "", mods)
if (length(ids) == length(mods)) {
chain <- rep_len(chain, length(mods))
}
}
base_dir <- gsub(sprintf("\\/%s.*", chain[1]), "", plot_folder)
chainData <- list()
for (i in rev(seq_along(chain))) {
if (i == length(chain)) {
for (i in rev(seq_along(ids))) {
if (i == length(ids)) {
json_new <- json_data
} else {
file <- paste0("RobynModel-", json_new$InputCollect$refreshSourceID, ".json")
filename <- paste(c(base_dir, chain[1:i], file), collapse = "/")
json_new <- robyn_read(filename, quiet = TRUE)
if (file.exists(filename)) {
json_new <- robyn_read(filename, quiet = TRUE)
} else {
if (ids[i] %in% avlb) {
filename <- mods[avlb == ids[i]]
json_new <- robyn_read(filename, quiet = TRUE)
} else {
message("Skipping chain. File can't be found: ", filename)
}
}
}
chainData[[json_new$ExportedModel$select_model]] <- json_new
}
chainData <- chainData[rev(seq_along(chain))]
dirs <- unlist(lapply(chainData, function(x) x$ExportedModel$plot_folder))
dirs[!dir.exists(dirs)] <- plot_folder
json_files <- paste0(dirs, "RobynModel-", names(dirs), ".json")
attr(chainData, "json_files") <- json_files
attr(chainData, "chain") <- ids # names(chainData)
Expand Down
16 changes: 11 additions & 5 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,12 @@ robyn_run <- function(InputCollect = NULL,
# Direct output & not all fixed hyperparameters, including refresh mode
output <- robyn_outputs(InputCollect, OutputModels, refresh = refresh, ...)
} else {
# Direct output & all fixed hyperparameters, thus no cluster
output <- robyn_outputs(InputCollect, OutputModels, clusters = FALSE, ...)
if (!"clusters" %in% names(list(...))) {
# Direct output & all fixed hyperparameters, thus no cluster
output <- robyn_outputs(InputCollect, OutputModels, clusters = FALSE, ...)
} else {
output <- robyn_outputs(InputCollect, OutputModels, ...)
}
}

# Created with assign from JSON file
Expand Down Expand Up @@ -950,9 +954,11 @@ robyn_mmm <- function(InputCollect,
)

# stop cluster to avoid memory leaks
stopImplicitCluster()
registerDoSEQ()
getDoParWorkers()
if (cores > 1) {
stopImplicitCluster()
registerDoSEQ()
getDoParWorkers()
}

if (!hyper_fixed) {
cat("\r", paste("\n Finished in", round(sysTimeDopar[3] / 60, 2), "mins"))
Expand Down
10 changes: 5 additions & 5 deletions R/R/pareto.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ robyn_pareto <- function(InputCollect, OutputModels,

# Prepare parallel loop
if (TRUE) {
if (check_parallel() & OutputModels$cores > 1) registerDoParallel(OutputModels$cores) else registerDoSEQ()
if (OutputModels$cores > 1) {
registerDoParallel(OutputModels$cores)
registerDoSEQ()
}
if (hyper_fixed) pareto_fronts <- 1
# Get at least 100 candidates for better clustering
if (nrow(resultHypParam) == 1) pareto_fronts <- 1
Expand Down Expand Up @@ -213,8 +216,6 @@ robyn_pareto <- function(InputCollect, OutputModels,
run_dt_resp(respN, InputCollect, OutputModels, decompSpendDistPar, resultHypParamPar, xDecompAggPar, ...)
}
stopImplicitCluster()
registerDoSEQ()
getDoParWorkers()
} else {
resp_collect <- bind_rows(lapply(seq_along(decompSpendDistPar$rn), function(respN) {
run_dt_resp(respN, InputCollect, OutputModels, decompSpendDistPar, resultHypParamPar, xDecompAggPar, ...)
Expand Down Expand Up @@ -596,8 +597,7 @@ robyn_pareto <- function(InputCollect, OutputModels,
df_caov_pct_all = df_caov_pct_all
)

# if (check_parallel()) stopImplicitCluster()
# close(pbplot)
if (OutputModels$cores > 1) stopImplicitCluster()

return(pareto_results)
}
Expand Down
Loading