Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unexpected key pixel_metrics.AUPRO.fpr_limit #1055

Merged
merged 11 commits into from
Oct 24, 2023
10 changes: 10 additions & 0 deletions src/anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.post_processing import ThresholdMethod
from anomalib.utils.metrics import (
AUPRO,
AnomalibMetricCollection,
AnomalyScoreDistribution,
AnomalyScoreThreshold,
Expand Down Expand Up @@ -234,11 +235,20 @@ def _load_normalization_class(self, state_dict: OrderedDict[str, Tensor]) -> Non
else:
warn("No known normalization found in model weights.")

def _load_pixel_metrics(self, state_dict: OrderedDict[str, Tensor]) -> None:
"""Create pixel_metrics and add AUPRO."""
if not hasattr(self, "pixel_metrics") and "pixel_metrics.AUPRO.fpr_limit" in state_dict.keys():
self.pixel_metrics = AnomalibMetricCollection([], prefix="pixel_")
fpr_limit = state_dict["pixel_metrics.AUPRO.fpr_limit"].item()
self.pixel_metrics.add_metrics(AUPRO(fpr_limit=fpr_limit))
WenjingKangIntel marked this conversation as resolved.
Show resolved Hide resolved

def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = True):
"""Load state dict from checkpoint.

Ensures that normalization and thresholding attributes is properly setup before model is loaded.
"""
# Used to load missing normalization and threshold parameters
self._load_normalization_class(state_dict)
# Used to load pixel metrics before create_metric_collection
self._load_pixel_metrics(state_dict)
return super().load_state_dict(state_dict, strict=strict)
8 changes: 7 additions & 1 deletion src/anomalib/utils/callbacks/metrics_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ def setup(

if isinstance(pl_module, AnomalyModule):
pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
if hasattr(pl_module, "pixel_metrics"):
new_metrics = create_metric_collection(pixel_metric_names, "pixel_")
for name in new_metrics.keys():
if name not in pl_module.pixel_metrics.keys():
pl_module.pixel_metrics.add_metrics(new_metrics[name.split("_")[1]])
else:
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")

pl_module.image_metrics.set_threshold(pl_module.image_threshold.value)
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value)