-
Notifications
You must be signed in to change notification settings - Fork 709
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial commit * deleted old results classes * switched to torchmetrics framework for evaluation * removed obsolete metrics module * updated tests to accommodate torchmetrics * update name of early stopping metric * re-write patchcore anomaly map generator in torch * remove unused import and update docstrings * ignoring pylint issue caused by bug in torch * fix mypy and pylint issues * update docstrings * opt_f1 -> optimal_f1 * Tensor -> torch.Tensor * add parameter descriptions * enable dfm in model tests * address comments * removed obsolete requirements file * kwds -> kwargs * indexing of trainer results Co-authored-by: Samet <[email protected]>
- Loading branch information
1 parent
eb6bb1a
commit e48835f
Showing
14 changed files
with
165 additions
and
276 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Custom anomaly evaluation metrics.""" | ||
from .auroc import AUROC | ||
from .optimal_f1 import OptimalF1 | ||
|
||
__all__ = ["AUROC", "OptimalF1"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Implementation of AUROC metric based on TorchMetrics.""" | ||
from torch import Tensor | ||
from torchmetrics import ROC | ||
from torchmetrics.functional import auc | ||
|
||
|
||
class AUROC(ROC): | ||
"""Area under the ROC curve.""" | ||
|
||
def compute(self) -> Tensor: | ||
"""First compute ROC curve, then compute area under the curve. | ||
Returns: | ||
Value of the AUROC metric | ||
""" | ||
fpr, tpr, _thresholds = super().compute() | ||
return auc(fpr, tpr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
"""Implementation of Optimal F1 score based on TorchMetrics.""" | ||
import torch | ||
from torchmetrics import Metric, PrecisionRecallCurve | ||
|
||
|
||
class OptimalF1(Metric): | ||
"""Optimal F1 Metric. | ||
Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the | ||
predicted anomaly scores. | ||
""" | ||
|
||
def __init__(self, num_classes: int, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self.precision_recall_curve = PrecisionRecallCurve(num_classes=num_classes, compute_on_step=False) | ||
|
||
self.threshold: torch.Tensor | ||
|
||
# pylint: disable=arguments-differ | ||
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore | ||
"""Update the precision-recall curve metric.""" | ||
self.precision_recall_curve.update(preds, target) | ||
|
||
def compute(self) -> torch.Tensor: | ||
"""Compute the value of the optimal F1 score. | ||
Compute the F1 scores while varying the threshold. Store the optimal | ||
threshold as attribute and return the maximum value of the F1 score. | ||
Returns: | ||
Value of the F1 score at the optimal threshold. | ||
""" | ||
precision: torch.Tensor | ||
recall: torch.Tensor | ||
thresholds: torch.Tensor | ||
|
||
precision, recall, thresholds = self.precision_recall_curve.compute() | ||
f1_score = (2 * precision * recall) / (precision + recall + 1e-10) | ||
self.threshold = thresholds[torch.argmax(f1_score)] | ||
optimal_f1_score = torch.max(f1_score) | ||
return optimal_f1_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.