diff --git a/anomalib/core/callbacks/visualizer_callback.py b/anomalib/core/callbacks/visualizer_callback.py index 4a228afbac..5c15031cae 100644 --- a/anomalib/core/callbacks/visualizer_callback.py +++ b/anomalib/core/callbacks/visualizer_callback.py @@ -1,16 +1,16 @@ """Visualizer Callback.""" from pathlib import Path +from typing import Any, Optional from warnings import warn -from pytorch_lightning import Callback, LightningModule, Trainer +import pytorch_lightning as pl +from pytorch_lightning import Callback +from pytorch_lightning.utilities.types import STEP_OUTPUT from skimage.segmentation import mark_boundaries -from tqdm import tqdm from anomalib import loggers from anomalib.core.model import AnomalyModule -from anomalib.core.results import SegmentationResults from anomalib.data.transforms import Denormalize -from anomalib.utils.metrics import compute_threshold_and_f1_score from anomalib.utils.post_process import compute_mask, superimpose_anomaly_map from anomalib.utils.visualizer import Visualizer @@ -57,33 +57,37 @@ def _add_images( if "local" in module.hparams.project.log_images_to: visualizer.save(Path(module.hparams.project.path) / "images" / filename.parent.name / filename.name) - def on_test_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> None: - """Log images at the end of training. + def on_test_batch_end( + self, + _trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: Optional[STEP_OUTPUT], + _batch: Any, + _batch_idx: int, + _dataloader_idx: int, + ) -> None: + """Log images at the end of every batch. Args: - _trainer (Trainer): Pytorch lightning trainer object (unused) + _trainer (Trainer): Pytorch lightning trainer object (unused). pl_module (LightningModule): Lightning modules derived from BaseAnomalyLightning object as currently only they support logging images. + outputs (Dict[str, Any]): Outputs of the current test step. + _batch (Any): Input batch of the current test step (unused). + _batch_idx (int): Index of the current test batch (unused). + _dataloader_idx (int): Index of the dataloader that yielded the current batch (unused). 
""" - if isinstance(pl_module.results, SegmentationResults): - results = pl_module.results - else: - raise ValueError("Visualizer callback only supported for segmentation tasks.") - - if results.images is None or results.true_masks is None or results.anomaly_maps is None: - raise ValueError("Result set cannot be empty!") - - threshold, _ = compute_threshold_and_f1_score(results.true_masks, results.anomaly_maps) + assert outputs is not None - for (filename, image, true_mask, anomaly_map) in tqdm( - zip(results.filenames, results.images, results.true_masks, results.anomaly_maps), - desc="Saving Results", - total=len(results.filenames), + for (filename, image, true_mask, anomaly_map) in zip( + outputs["image_path"], outputs["image"], outputs["mask"], outputs["anomaly_maps"] ): - image = Denormalize()(image) + image = Denormalize()(image.cpu()) + true_mask = true_mask.cpu().numpy() + anomaly_map = anomaly_map.cpu().numpy() heat_map = superimpose_anomaly_map(anomaly_map, image) - pred_mask = compute_mask(anomaly_map, threshold) + pred_mask = compute_mask(anomaly_map, pl_module.threshold.item()) vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick") visualizer = Visualizer(num_rows=1, num_cols=5, figure_size=(12, 3)) @@ -92,5 +96,5 @@ def on_test_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> No visualizer.add_image(image=heat_map, title="Predicted Heat Map") visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask") visualizer.add_image(image=vis_img, title="Segmentation Result") - self._add_images(visualizer, pl_module, filename) + self._add_images(visualizer, pl_module, Path(filename)) visualizer.close() diff --git a/anomalib/core/metrics/__init__.py b/anomalib/core/metrics/__init__.py new file mode 100644 index 0000000000..56cd5e0725 --- /dev/null +++ b/anomalib/core/metrics/__init__.py @@ -0,0 +1,5 @@ +"""Custom anomaly evaluation metrics.""" +from .auroc import AUROC +from .optimal_f1 import OptimalF1 + +__all__ = ["AUROC", "OptimalF1"] diff --git a/anomalib/core/metrics/auroc.py b/anomalib/core/metrics/auroc.py new file mode 100644 index 0000000000..35165a9090 --- /dev/null +++ b/anomalib/core/metrics/auroc.py @@ -0,0 +1,17 @@ +"""Implementation of AUROC metric based on TorchMetrics.""" +from torch import Tensor +from torchmetrics import ROC +from torchmetrics.functional import auc + + +class AUROC(ROC): + """Area under the ROC curve.""" + + def compute(self) -> Tensor: + """First compute ROC curve, then compute area under the curve. + + Returns: + Value of the AUROC metric + """ + fpr, tpr, _thresholds = super().compute() + return auc(fpr, tpr) diff --git a/anomalib/core/metrics/optimal_f1.py b/anomalib/core/metrics/optimal_f1.py new file mode 100644 index 0000000000..3b47be5984 --- /dev/null +++ b/anomalib/core/metrics/optimal_f1.py @@ -0,0 +1,42 @@ +"""Implementation of Optimal F1 score based on TorchMetrics.""" +import torch +from torchmetrics import Metric, PrecisionRecallCurve + + +class OptimalF1(Metric): + """Optimal F1 Metric. + + Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the + predicted anomaly scores. 
+ """ + + def __init__(self, num_classes: int, **kwargs): + super().__init__(**kwargs) + + self.precision_recall_curve = PrecisionRecallCurve(num_classes=num_classes, compute_on_step=False) + + self.threshold: torch.Tensor + + # pylint: disable=arguments-differ + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore + """Update the precision-recall curve metric.""" + self.precision_recall_curve.update(preds, target) + + def compute(self) -> torch.Tensor: + """Compute the value of the optimal F1 score. + + Compute the F1 scores while varying the threshold. Store the optimal + threshold as attribute and return the maximum value of the F1 score. + + Returns: + Value of the F1 score at the optimal threshold. + """ + precision: torch.Tensor + recall: torch.Tensor + thresholds: torch.Tensor + + precision, recall, thresholds = self.precision_recall_curve.compute() + f1_score = (2 * precision * recall) / (precision + recall + 1e-10) + self.threshold = thresholds[torch.argmax(f1_score)] + optimal_f1_score = torch.max(f1_score) + return optimal_f1_score diff --git a/anomalib/core/model/anomaly_module.py b/anomalib/core/model/anomaly_module.py index 98f6e3492c..d3b477665f 100644 --- a/anomalib/core/model/anomaly_module.py +++ b/anomalib/core/model/anomaly_module.py @@ -21,9 +21,9 @@ from omegaconf import DictConfig, ListConfig from pytorch_lightning.callbacks.base import Callback from torch import nn +from torchmetrics import F1, MetricCollection -from anomalib.core.results import ClassificationResults, SegmentationResults -from anomalib.utils.metrics import compute_threshold_and_f1_score +from anomalib.core.metrics import AUROC, OptimalF1 class AnomalyModule(pl.LightningModule): @@ -41,18 +41,21 @@ def __init__(self, params: Union[DictConfig, ListConfig]): self.save_hyperparameters(params) self.loss: torch.Tensor self.callbacks: List[Callback] - self.register_buffer("threshold", torch.Tensor([params.model.threshold.default])) + self.register_buffer("threshold", torch.tensor(params.model.threshold.default)) # pylint: disable=not-callable self.threshold: torch.Tensor self.model: nn.Module - self.results: Union[ClassificationResults, SegmentationResults] - if params.dataset.task == "classification": - self.results = ClassificationResults() - elif params.dataset.task == "segmentation": - self.results = SegmentationResults() + # metrics + self.image_metrics = MetricCollection( + [AUROC(num_classes=1, pos_label=1, compute_on_step=False)], prefix="image_" + ) + if params.model.threshold.adaptive: + self.image_metrics.add_metrics([OptimalF1(num_classes=1)]) else: - raise NotImplementedError("Only Classification and Segmentation tasks are supported in this version.") + self.image_metrics.add_metrics([F1(num_classes=1, compute_on_step=False, threshold=self.threshold.item())]) + if self.hparams.dataset.task == "segmentation": + self.pixel_metrics = self.image_metrics.clone(prefix="pixel_") def forward(self, batch): # pylint: disable=arguments-differ """Forward-pass input tensor to the module. 
@@ -96,33 +99,33 @@ def test_step(self, batch, _): # pylint: disable=arguments-differ def validation_step_end(self, val_step_outputs): # pylint: disable=arguments-differ """Called at the end of each validation step.""" - return self._post_process(val_step_outputs) + val_step_outputs = self._post_process(val_step_outputs) + self.image_metrics(val_step_outputs["pred_scores"], val_step_outputs["label"].int()) + if self.hparams.dataset.task == "segmentation": + self.pixel_metrics(val_step_outputs["anomaly_maps"].flatten(), val_step_outputs["mask"].flatten().int()) + return val_step_outputs def test_step_end(self, test_step_outputs): # pylint: disable=arguments-differ - """Called at the end of each validation step.""" - return self._post_process(test_step_outputs) + """Called at the end of each test step.""" + return self.validation_step_end(test_step_outputs) - def validation_epoch_end(self, outputs): - """Compute image-level performance metrics. + def validation_epoch_end(self, _outputs): + """Compute threshold and performance metrics. Args: outputs: Batch of outputs from the validation step """ - self.results.store_outputs(outputs) if self.hparams.model.threshold.adaptive: - threshold, _ = compute_threshold_and_f1_score(self.results.true_labels, self.results.pred_scores) - self.threshold = torch.Tensor([threshold]) - self.results.evaluate(self.threshold.item()) + self.image_metrics.compute() + self.threshold = self.image_metrics.OptimalF1.threshold self._log_metrics() - def test_epoch_end(self, outputs): + def test_epoch_end(self, _outputs): """Compute and save anomaly scores of the test set. Args: outputs: Batch of outputs from the validation step """ - self.results.store_outputs(outputs) - self.results.evaluate(self.threshold.item()) self._log_metrics() def _post_process(self, outputs, predict_labels=False): @@ -137,5 +140,6 @@ def _post_process(self, outputs, predict_labels=False): def _log_metrics(self): """Log computed performance metrics.""" - for name, value in self.results.performance.items(): - self.log(name=name, value=value, on_epoch=True, prog_bar=True) + self.log_dict(self.image_metrics) + if self.hparams.dataset.task == "segmentation": + self.log_dict(self.pixel_metrics) diff --git a/anomalib/core/model/kde.py b/anomalib/core/model/kde.py index 0608f707f7..84c62afef4 100644 --- a/anomalib/core/model/kde.py +++ b/anomalib/core/model/kde.py @@ -49,7 +49,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: """ features = torch.matmul(features, self.bw_transform) - estimate = torch.zeros(features.shape[0]) + estimate = torch.zeros(features.shape[0]).to(features.device) for i in range(features.shape[0]): embedding = ((self.dataset - features[i]) ** 2).sum(dim=1) embedding = torch.exp(-embedding / 2) * self.norm diff --git a/anomalib/core/results/__init__.py b/anomalib/core/results/__init__.py deleted file mode 100644 index 0ccb027fdc..0000000000 --- a/anomalib/core/results/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""This module contains Result dataclass objects to store classification and segmentation results.""" - -# Copyright (C) 2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. - -from .results import ClassificationResults, SegmentationResults - -__all__ = ["ClassificationResults", "SegmentationResults"] diff --git a/anomalib/core/results/results.py b/anomalib/core/results/results.py deleted file mode 100644 index 73fbb68911..0000000000 --- a/anomalib/core/results/results.py +++ /dev/null @@ -1,110 +0,0 @@ -"""This module defines Result Sets.""" - -# Copyright (C) 2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. - -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import numpy as np -import torch -from sklearn.metrics import balanced_accuracy_score, f1_score, roc_auc_score -from torch import Tensor - - -@dataclass -class ClassificationResults: - """Dataclass to store classification-task results. - - A classification task would return a anomaly classification score, which is used to compute - the overall performance by comparing it with the true_labels (ground-truth). 
- - Args: - filenames: List[Union[str, Path]] - images: List[Union[np.ndarray, Tensor]] - true_labels: List[Union[Tensor, np.ndarray]] - anomaly_scores: List[Union[Tensor, np.ndarray]] - performance: Dict[str, Any] - - Examples: - >>> from anomalib.core.results import ClassificationResult - >>> ClassificationResult() - ClassificationResult( - filenames=[], images=[], - true_labels=[], anomaly_scores=[], - performance={} - ) - """ - - filenames: List[Union[str, Path]] = field(default_factory=list) - images: Optional[Tensor] = None - true_labels: np.ndarray = np.empty(0) - pred_scores: np.ndarray = np.empty(0) - pred_labels: np.ndarray = np.empty(0) - # TODO: Use MetricCollection: https://jira.devtools.intel.com/browse/IAAALD-170 - performance: Dict[str, Any] = field(default_factory=dict) - - def store_outputs(self, outputs: List[dict]): - """Concatenate the outputs from the individual batches and store in the result set.""" - if "image_path" in outputs[0].keys(): - self.filenames = [Path(f) for x in outputs for f in x["image_path"]] - self.images = torch.vstack([x["image"] for x in outputs]) - self.true_labels = np.hstack([output["label"].cpu() for output in outputs]) - self.pred_scores = np.hstack([output["pred_scores"].cpu() for output in outputs]) - - def evaluate(self, threshold: float): - """Compute performance metrics.""" - self.pred_labels = self.pred_scores >= threshold - self.performance["image_f1_score"] = f1_score(self.true_labels, self.pred_labels) - self.performance["balanced_accuracy_score"] = balanced_accuracy_score(self.true_labels, self.pred_labels) - self.performance["image_roc_auc"] = roc_auc_score(self.true_labels, self.pred_scores) - - -@dataclass -class SegmentationResults(ClassificationResults): - """Dataclass to store segmentation-based task results. - - An anomaly segmentation task returns anomaly maps in addition to anomaly scores, which are then used to - compute anomaly masks to compare against the true segmentation masks. 
- - Args: - anomaly_maps: List[Union[np.ndarray, Tensor]] - true_masks: List[Union[np.ndarray, Tensor]] - pred_masks: List[Union[np.ndarray, Tensor]] - - Example: - >>> from anomalib.core.results import SegmentationResult - >>> SegmentationResult() - SegmentationResult( - true_labels=[], anomaly_scores=[], performance={}, - anomaly_maps=[], true_masks=[], - pred_masks=[] - ) - """ - - anomaly_maps: np.ndarray = np.empty(0) - true_masks: np.ndarray = np.empty(0) - pred_masks: Optional[np.ndarray] = None - - def store_outputs(self, outputs: List[dict]): - """Concatenate the outputs from the individual batches and store in the result set.""" - super().store_outputs(outputs) - self.true_masks = np.vstack([output["mask"].squeeze(1).cpu() for output in outputs]) - self.anomaly_maps = np.vstack([output["anomaly_maps"].cpu() for output in outputs]) - - def evaluate(self, threshold: float): - """First compute common metrics, then compute segmentation-specific metrics.""" - super().evaluate(threshold) - self.performance["pixel_roc_auc"] = roc_auc_score(self.true_masks.flatten(), self.anomaly_maps.flatten()) diff --git a/anomalib/models/patchcore/model.py b/anomalib/models/patchcore/model.py index 9766b569e7..d5d62936dc 100644 --- a/anomalib/models/patchcore/model.py +++ b/anomalib/models/patchcore/model.py @@ -16,13 +16,11 @@ from typing import Dict, List, Optional, Tuple, Union -import cv2 -import numpy as np import torch import torch.nn.functional as F import torchvision +from kornia import gaussian_blur2d from omegaconf import ListConfig -from scipy.ndimage import gaussian_filter from torch import Tensor, nn from anomalib.core.model import AnomalyModule @@ -47,31 +45,37 @@ def __init__( self.input_size = input_size self.sigma = sigma - def compute_anomaly_map(self, score_patches: np.ndarray) -> np.ndarray: + def compute_anomaly_map(self, patch_scores: torch.Tensor) -> torch.Tensor: """Pixel Level Anomaly Heatmap. Args: - score_patches (np.ndarray): [description] + patch_scores (torch.Tensor): Patch-level anomaly scores + Returns: + torch.Tensor: Map of the pixel-level anomaly scores """ - anomaly_map = score_patches[:, 0].reshape((28, 28)) - anomaly_map = cv2.resize(anomaly_map, self.input_size) - anomaly_map = gaussian_filter(anomaly_map, sigma=self.sigma) + anomaly_map = patch_scores[:, 0].reshape((1, 1, 28, 28)) + anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1])) + + kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1 + anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(self.sigma, self.sigma)) return anomaly_map @staticmethod - def compute_anomaly_score(patch_scores: np.ndarray) -> np.ndarray: + def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor: """Compute Image-Level Anomaly Score. Args: - patch_scores (np.ndarray): [description] + patch_scores (torch.Tensor): Patch-level anomaly scores + Returns: + torch.Tensor: Image-level anomaly scores """ - confidence = patch_scores[np.argmax(patch_scores[:, 0])] - weights = 1 - (np.max(np.exp(confidence)) / np.sum(np.exp(confidence))) + confidence = patch_scores[torch.argmax(patch_scores[:, 0])] + weights = 1 - (torch.max(torch.exp(confidence)) / torch.sum(torch.exp(confidence))) score = weights * max(patch_scores[:, 0]) return score - def __call__(self, **kwds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def __call__(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Returns anomaly_map and anomaly_score. 
Expects `patch_scores` keyword to be passed explicitly @@ -84,13 +88,13 @@ def __call__(self, **kwds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: ValueError: If `patch_scores` key is not found Returns: - Tuple[np.ndarray, np.ndarray]: anomaly_map, anomaly_score + Tuple[torch.Tensor, torch.Tensor]: anomaly_map, anomaly_score """ - if "patch_scores" not in kwds: - raise ValueError(f"Expected key `patch_scores`. Found {kwds.keys()}") + if "patch_scores" not in kwargs: + raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}") - patch_scores: np.ndarray = kwds["patch_scores"].cpu().numpy() + patch_scores = kwargs["patch_scores"] anomaly_map = self.compute_anomaly_map(patch_scores) anomaly_score = self.compute_anomaly_score(patch_scores) return anomaly_map, anomaly_score @@ -302,6 +306,6 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ """ anomaly_maps, _ = self.model(batch["image"]) - batch["anomaly_maps"] = torch.Tensor(anomaly_maps).unsqueeze(0).unsqueeze(0) + batch["anomaly_maps"] = anomaly_maps return batch diff --git a/anomalib/models/stfpm/config.yaml b/anomalib/models/stfpm/config.yaml index 4f84a1735e..ed93493c34 100644 --- a/anomalib/models/stfpm/config.yaml +++ b/anomalib/models/stfpm/config.yaml @@ -31,7 +31,7 @@ model: weight_decay: 0.0001 early_stopping: patience: 3 - metric: pixel_roc_auc + metric: pixel_AUROC mode: max threshold: default: 0 diff --git a/anomalib/utils/metrics.py b/anomalib/utils/metrics.py deleted file mode 100644 index 80be27e05e..0000000000 --- a/anomalib/utils/metrics.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Metrics This module contains metric-related util functions.""" - -# Copyright (C) 2020 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. - - -from typing import Tuple, Union - -import numpy as np -from sklearn.metrics import precision_recall_curve -from torch import Tensor - - -def compute_threshold_and_f1_score( - ground_truth: Union[Tensor, np.ndarray], predictions: Union[Tensor, np.ndarray] -) -> Tuple[float, float]: - """Compute adaptive threshold, based on the f1 metric of the true labels and the predicted anomaly scores. - - Args: - ground_truth: Pixel-level or image-level ground truth labels. - predictions: Anomaly scores predicted by the model. - - Examples: - >>> import numpy as np - >>> y_true = np.array([0, 0, 1, 1]) - >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8]) - - >>> compute_adaptive_threshold(y_true, y_scores) - (0.35, 0.8) - - Returns: - Threshold value based on the best f1 score. - Value of the best f1 score. 
- """ - - precision, recall, thresholds = precision_recall_curve(ground_truth.flatten(), predictions.flatten()) - f1_score = (2 * precision * recall) / (precision + recall + 1e-10) - threshold = thresholds[np.argmax(f1_score)] - max_f1_score = np.max(f1_score) - - return threshold, max_f1_score diff --git a/tests/core/callbacks/visualizer_callback/dummy_lightning_model.py b/tests/core/callbacks/visualizer_callback/dummy_lightning_model.py index 3abc43e597..04e74cc595 100644 --- a/tests/core/callbacks/visualizer_callback/dummy_lightning_model.py +++ b/tests/core/callbacks/visualizer_callback/dummy_lightning_model.py @@ -12,7 +12,6 @@ from anomalib.core.callbacks.visualizer_callback import VisualizerCallback from anomalib.core.model import AnomalyModule -from anomalib.core.results import SegmentationResults class DummyDataset(Dataset): @@ -52,19 +51,18 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]): self.model = DummyModel() self.task = "segmentation" self.callbacks = [VisualizerCallback()] # test if this is removed - self.results.filenames = [Path("test1.jpg"), Path("test2.jpg")] - - if isinstance(self.results, SegmentationResults): - self.results.images = [torch.rand((1, 3, 100, 100))] * 2 - self.results.true_masks = np.zeros((2, 100, 100)) - self.results.anomaly_maps = np.ones((2, 100, 100)) def test_step(self, batch, _): """Only used to trigger on_test_epoch_end.""" self.log(name="loss", value=0.0, prog_bar=True) - - def test_step_end(self, test_step_outputs): - return None + outputs = dict( + image_path=[Path("test1.jpg")], + image=torch.rand((1, 3, 100, 100)), + mask=torch.zeros((1, 100, 100)), + anomaly_maps=torch.ones((1, 100, 100)), + label=torch.Tensor([0]), + ) + return outputs def validation_epoch_end(self, output): return None diff --git a/tests/core/callbacks/visualizer_callback/test_visualizer.py b/tests/core/callbacks/visualizer_callback/test_visualizer.py index 200c91ea8f..e847d414d1 100644 --- a/tests/core/callbacks/visualizer_callback/test_visualizer.py +++ b/tests/core/callbacks/visualizer_callback/test_visualizer.py @@ -37,7 +37,7 @@ def test_add_images(dataset): trainer = pl.Trainer(callbacks=model.callbacks, logger=logger, checkpoint_callback=False) trainer.test(model=model, datamodule=DummyDataModule()) # test if images are logged - if len(glob.glob(os.path.join(dir_loc, "images", "*.jpg"))) != 2: + if len(glob.glob(os.path.join(dir_loc, "images", "*.jpg"))) != 1: raise Exception("Failed to save to local path") # test if tensorboard logs are created diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 4ca8492e5b..9fd6ed1da1 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -116,14 +116,15 @@ def _setup(self, model_name, use_mvtec, dataset_path, project_path, nncf, catego def _test_metrics(self, trainer, config, model, datamodule): """Tests the model metrics but also acts as a setup.""" - trainer.test(model=model, datamodule=datamodule) + results = trainer.test(model=model, datamodule=datamodule)[0] - assert model.results.performance["image_roc_auc"] >= 0.6 + assert results["image_AUROC"] >= 0.6 if config.dataset.task == "segmentation": - assert model.results.performance["pixel_roc_auc"] >= 0.6 + assert results["pixel_AUROC"] >= 0.6 + return results - def _test_model_load(self, config, datamodule, model): + def _test_model_load(self, config, datamodule, results): loaded_model = get_model(config) # get new model callbacks = get_callbacks(config) @@ -137,19 +138,14 @@ def _test_model_load(self, config, 
datamodule, model): # create new trainer object with LoadModel callback (assumes it is present) trainer = Trainer(callbacks=callbacks, **config.trainer) # Assumes the new model has LoadModel callback and the old one had ModelCheckpoint callback - trainer.test(model=loaded_model, datamodule=datamodule) - - # Common for both classification and segmentation - is_close = np.isclose( - model.results.performance["image_roc_auc"], loaded_model.results.performance["image_roc_auc"] - ) - assert is_close, "Loaded model does not yield close performance results" - + new_results = trainer.test(model=loaded_model, datamodule=datamodule)[0] + assert np.isclose( + results["image_AUROC"], new_results["image_AUROC"] + ), "Loaded model does not yield close performance results" if config.dataset.task == "segmentation": - is_close = np.isclose( - model.results.performance["pixel_roc_auc"], loaded_model.results.performance["pixel_roc_auc"] - ) - assert is_close, "Loaded model does not yield close performance results" + assert np.isclose( + results["pixel_AUROC"], new_results["pixel_AUROC"] + ), "Loaded model does not yield close performance results" @pytest.mark.parametrize( ["model_name", "nncf"], @@ -179,7 +175,7 @@ def test_model(self, category, model_name, nncf, use_mvtec=True, path="./dataset ) # test model metrics - self._test_metrics(trainer=trainer, config=config, model=model, datamodule=datamodule) + results = self._test_metrics(trainer=trainer, config=config, model=model, datamodule=datamodule) # test model load - self._test_model_load(config=config, datamodule=datamodule, model=model) + self._test_model_load(config=config, datamodule=datamodule, results=results)
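For readers of this patch, a minimal sketch of how the new TorchMetrics-based classes introduced above (AUROC, OptimalF1, and the MetricCollection wiring in AnomalyModule) are meant to be exercised. This is illustrative only and not part of the patch: the dummy score/label tensors and the standalone (non-Lightning) usage are assumptions; inside anomalib the equivalent calls are made by AnomalyModule.validation_step_end and validation_epoch_end, and it assumes a torchmetrics version that still provides ROC, PrecisionRecallCurve and torchmetrics.functional.auc as used by the new metric classes.

# --- Illustrative usage sketch (assumption: run outside Lightning, dummy data) ---
import torch
from torchmetrics import MetricCollection

from anomalib.core.metrics import AUROC, OptimalF1

# Image-level metric collection, mirroring what AnomalyModule.__init__ now builds
# when model.threshold.adaptive is enabled.
image_metrics = MetricCollection(
    [AUROC(num_classes=1, pos_label=1, compute_on_step=False)], prefix="image_"
)
image_metrics.add_metrics([OptimalF1(num_classes=1)])

# Dummy anomaly scores and ground-truth labels standing in for validation outputs.
pred_scores = torch.tensor([0.1, 0.4, 0.35, 0.8])
labels = torch.tensor([0, 0, 1, 1])

image_metrics(pred_scores, labels)   # per-batch update, as in validation_step_end
results = image_metrics.compute()    # e.g. {"image_AUROC": ..., "image_OptimalF1": ...}

# After compute(), OptimalF1 holds the adaptive threshold it found; this is the value
# validation_epoch_end copies into the module's registered `threshold` buffer.
adaptive_threshold = image_metrics.OptimalF1.threshold
print(results, adaptive_threshold)

Because the metrics are stateful and accumulate across batches, the patch can drop the Results dataclasses and compute_threshold_and_f1_score entirely: epoch-level values come straight from compute()/log_dict, and the adaptive threshold is simply read back from the OptimalF1 metric.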