Skip to content

Commit

Permalink
update utils
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisbader committed Jan 3, 2025
1 parent 8e51e2c commit f3e0db2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 45 deletions.
2 changes: 2 additions & 0 deletions darts/ad/scorers/scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ def show_anomalies(
Optionally, the name of the metric function to use. Must be one of "AUC_ROC" (Area Under the
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Default: "AUC_ROC".
multivariate_plot
If True, it will separately plot each component in multivariate series.
"""
series = _check_input(series, name="series", num_series_expected=1)[0]
pred_scores = self.score(series)
Expand Down
87 changes: 42 additions & 45 deletions darts/ad/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ def show_anomalies_from_scores(
multivariate_plot
If True, it will separately plot each component in multivariate series.
"""

series = _check_input(
series,
name="series",
Expand Down Expand Up @@ -426,78 +425,80 @@ def show_anomalies_from_scores(
)

nbr_plots += len(set(window))
series_width = series.n_components
plots_per_ts = nbr_plots * series_width if multivariate_plot else nbr_plots
fig, axs = plt.subplots(
plots_per_ts,
figsize=(8, 4 + 2 * (plots_per_ts - 1)),
sharex=True,
gridspec_kw={"height_ratios": [2] + [1] * (plots_per_ts - 1)},
squeeze=False,
)

series_width = series.n_components
if pred_series is not None:
pred_series = _check_input(
pred_series,
name="pred_series",
width_expected=series.width,
width_expected=series_width,
num_series_expected=1,
check_multivariate=multivariate_plot,
)[0]

if anomalies is not None:
if anomalies is not None and multivariate_plot:
anomalies = _check_input(
anomalies,
name="anomalies",
width_expected=series.width,
width_expected=series_width,
num_series_expected=1,
check_binary=True,
check_multivariate=multivariate_plot,
)[0]

if pred_scores is not None:
if pred_scores is not None and multivariate_plot:
for pred_score in pred_scores:
pred_score = _check_input(
_ = _check_input(
pred_score,
name="pred_score",
width_expected=series.width,
width_expected=series_width,
num_series_expected=1,
check_multivariate=multivariate_plot,
)[0]

if multivariate_plot:
for i in range(series_width):
_plot_series_and_anomalies(
series=series[series.components[i]],
anomalies=anomalies[anomalies.components[i]]
if anomalies is not None
else None,
pred_series=pred_series[pred_series.components[i]]
plots_per_ts = nbr_plots * series_width if multivariate_plot else nbr_plots
fig, axs = plt.subplots(
plots_per_ts,
figsize=(8, 4 + 2 * (plots_per_ts - 1)),
sharex=True,
gridspec_kw={"height_ratios": [2] + [1] * (plots_per_ts - 1)},
squeeze=False,
)

for i in range(series_width if multivariate_plot else 1):
if multivariate_plot:
series_ = series[series.components[i]]
anomalies_ = (
anomalies[anomalies.components[i]] if anomalies is not None else None
)
pred_series_ = (
pred_series[pred_series.components[i]]
if pred_series is not None
else None,
pred_scores=pred_scores,
window=window,
names_of_scorers=names_of_scorers,
metric=metric,
axs=axs,
index_ax=i * nbr_plots,
nbr_plots=nbr_plots,
else None
)
pred_scores_ = (
[pc[pc.components[i]] for pc in pred_scores]
if pred_scores is not None
else None
)
else:
series_ = series
anomalies_ = anomalies
pred_series_ = pred_series
pred_scores_ = pred_scores

else:
_plot_series_and_anomalies(
series=series,
anomalies=anomalies,
pred_series=pred_series,
pred_scores=pred_scores,
series=series_,
anomalies=anomalies_,
pred_series=pred_series_,
pred_scores=pred_scores_,
window=window,
names_of_scorers=names_of_scorers,
metric=metric,
axs=axs,
index_ax=0,
index_ax=i * nbr_plots,
nbr_plots=nbr_plots,
)

fig.suptitle(title)


Expand Down Expand Up @@ -838,9 +839,7 @@ def _plot_series_and_anomalies(
value = round(
eval_metric_from_scores(
anomalies=anomalies,
pred_scores=pred_scores[idx][
pred_scores[idx].components[index_ax // nbr_plots]
],
pred_scores=pred_scores[idx],
window=w,
metric=metric,
),
Expand All @@ -855,9 +854,7 @@ def _plot_series_and_anomalies(
label = f"score_{str(idx)}" + [f" ({value})", ""][value is None]

_plot_series(
series=elem[1]["series_score"][
elem[1]["series_score"].components[index_ax // nbr_plots]
],
series=elem[1]["series_score"],
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name=label,
Expand Down

0 comments on commit f3e0db2

Please sign in to comment.