Skip to content

Commit

Permalink
fix: intercept param in outputs + ts_validation plot for convergence
Browse files Browse the repository at this point in the history
  • Loading branch information
laresbernardo authored Feb 6, 2024
2 parents 3fa5268 + b449273 commit 1677723
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
2 changes: 1 addition & 1 deletion R/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: Robyn
Type: Package
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
Version: 3.10.5.9013
Version: 3.10.5.9014
Authors@R: c(
person("Gufeng", "Zhou", , "[email protected]", c("cre","aut")),
person("Leonel", "Sentana", , "[email protected]", c("aut")),
Expand Down
1 change: 1 addition & 0 deletions R/R/outputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ robyn_outputs <- function(InputCollect, OutputModels,
cores = OutputModels$cores,
iterations = OutputModels$iterations,
trials = OutputModels$trials,
intercept = OutputModels$intercept,
intercept_sign = OutputModels$intercept_sign,
nevergrad_algo = OutputModels$nevergrad_algo,
add_penalty_factor = OutputModels$add_penalty_factor,
Expand Down
18 changes: 8 additions & 10 deletions R/R/plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,13 @@ robyn_plots <- function(
}
} # End of !hyper_fixed

# Time series and errors convergence validation
get_height <- ceiling(12 * OutputCollect$OutputModels$trials / 3)
if (isTRUE(OutputCollect$OutputModels$ts_validation)) {
ts_validation_plot <- ts_validation(OutputCollect$OutputModels, quiet = TRUE, ...)
all_plots[["ts_validation"]] <- ts_validation(OutputCollect$OutputModels, quiet = TRUE, ...)
if (export) {
ggsave(
paste0(plot_folder, "ts_validation", ".png"),
plot = ts_validation_plot, dpi = 300,
plot = all_plots[["ts_validation"]], dpi = 300,
width = 10, height = get_height, limitsize = FALSE
)
}
Expand Down Expand Up @@ -1400,20 +1401,17 @@ refresh_plots_json <- function(OutputCollectRF, json_file, export = TRUE, ...) {


####################################################################
#' Generate Plots for Time-Series Validation
#' Generate Plots for Time-Series Validation and Convergence
#'
#' Create a plot to visualize the convergence for each of the datasets
#' when time-series validation is enabled when running \code{robyn_run()}.
#' when running \code{robyn_run()}, especially useful for when using ts_validation.
#' As a reference, the closer the test and validation convergence points are,
#' the better, given the time-series wasn't overfitted.
#'
#' @rdname robyn_outputs
#' @return Invisible list with \code{ggplot} plots.
#' @export
ts_validation <- function(OutputModels, quiet = FALSE, ...) {
if (!isTRUE(OutputModels$ts_validation)) {
return(NULL)
}
resultHypParam <- bind_rows(
lapply(OutputModels[
which(names(OutputModels) %in% paste0("trial", seq(OutputModels$trials)))
Expand Down Expand Up @@ -1466,8 +1464,8 @@ ts_validation <- function(OutputModels, quiet = FALSE, ...) {
colour = .data$dataset
# group = as.character(.data$trial)
)) +
geom_point(alpha = 0.2, size = 0.9) +
geom_smooth(method = "gam", formula = y ~ s(x, bs = "cs")) +
geom_point(alpha = 0.2, size = 0.9, na.rm = TRUE) +
geom_smooth(method = "gam", formula = y ~ s(x, bs = "cs"), na.rm = TRUE) +
facet_grid(.data$trial ~ .) +
geom_hline(yintercept = 0, linetype = "dashed") +
labs(y = "NRMSE [Upper 1% Winsorized]", x = "Iteration", colour = "Dataset") +
Expand Down
2 changes: 1 addition & 1 deletion R/man/robyn_outputs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1677723

Please sign in to comment.