Commit

only use deduced values
jpcbertoldo committed Nov 2, 2022
1 parent db29ba9 commit ffb2190
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions anomalib/models/padim/torch_model.py
@@ -14,11 +14,6 @@
 from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
 from anomalib.pre_processing import Tiler

-_DIMS = {
-    "resnet18": {"orig_dims": 448, "emb_scale": 4},
-    "wide_resnet50_2": {"orig_dims": 1792, "emb_scale": 4},
-}
-
 # defaults from the paper
 _N_FEATURES_DEFAULTS = {
     "resnet18": 100,
@@ -31,6 +26,9 @@ def _deduce_dims(
 ) -> Tuple[int, int]:
     """Run a dry run to deduce the dimensions of the extracted features.

+    Important: `layers` is assumed to be ordered and the first (layers[0])
+        is assumed to be the layer with largest resolution.
+
     Returns:
         Tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches)
     """
@@ -72,31 +70,22 @@ def __init__(
         self.backbone = backbone
         self.layers = layers
         self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
+        self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)

-        if backbone in _DIMS:
-            backbone_dims = _DIMS[backbone]
-            self.n_features_original = backbone_dims["orig_dims"]
-            emb_scale = backbone_dims["emb_scale"]
-            patches_dims = torch.tensor(input_size) / emb_scale
-            self.n_patches = patches_dims.ceil().prod().int().item()
-
-        else:
-            self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)
+        n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone)

         if n_features is None:
-            if self.backbone in _N_FEATURES_DEFAULTS:
-                n_features = _N_FEATURES_DEFAULTS[self.backbone]
-
-            else:
-                raise ValueError(
-                    f"{self.__class__.__name__}.n_features must be specified for backbone {self.backbone}. "
-                    f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}"
-                )
+            raise ValueError(
+                f"n_features must be specified for backbone {self.backbone}. "
+                f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}"
+            )

         assert (
-            n_features <= self.n_features_original
-        ), f"n_features ({n_features}) must be <= n_features_original ({self.n_features_original})"
+            0 < n_features <= self.n_features_original
+        ), f"for backbone {self.backbone}, 0 < n_features <= {self.n_features_original}, found {n_features}"

         self.n_features = n_features

         # pylint: disable=not-callable
         # Since idx is randomly selected, save it with model to get same results
         self.register_buffer(
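
For context, the deduction that replaces the hard-coded _DIMS table amounts to a dry run: push a dummy batch through the feature extractor, sum the channel counts over the selected layers to get n_features_original, and read the spatial size of the first (highest-resolution) layer to get n_patches. Below is a minimal, self-contained sketch of that shape arithmetic only; ToyExtractor and deduce_dims_sketch are illustrative stand-ins, not anomalib code, and the real _deduce_dims in this file works against anomalib's FeatureExtractor instead.

from typing import Dict, List, Tuple

import torch
from torch import Tensor, nn


class ToyExtractor(nn.Module):
    """Stand-in for a backbone feature extractor: returns a dict of feature maps."""

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Conv2d(3, 64, kernel_size=3, stride=4, padding=1)
        self.layer2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

    def forward(self, batch: Tensor) -> Dict[str, Tensor]:
        feat1 = self.layer1(batch)
        feat2 = self.layer2(feat1)
        return {"layer1": feat1, "layer2": feat2}


def deduce_dims_sketch(
    extractor: nn.Module, input_size: Tuple[int, int], layers: List[str]
) -> Tuple[int, int]:
    """Dry run: infer (n_features_original, n_patches) from actual feature shapes."""
    with torch.no_grad():
        features = extractor(torch.empty(1, 3, *input_size))
    # total number of channels across the selected layers
    n_features_original = sum(features[layer].shape[1] for layer in layers)
    # layers[0] is assumed to have the largest resolution -> it defines the patch grid
    height, width = features[layers[0]].shape[-2:]
    n_patches = height * width
    return n_features_original, n_patches


if __name__ == "__main__":
    dims = deduce_dims_sketch(ToyExtractor(), (224, 224), ["layer1", "layer2"])
    print(dims)  # (192, 3136) for the toy extractor: 64 + 128 channels, 56 * 56 patches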
