diff --git a/anomalib/models/padim/torch_model.py b/anomalib/models/padim/torch_model.py index 72d8c73c66..ace09f9a7d 100644 --- a/anomalib/models/padim/torch_model.py +++ b/anomalib/models/padim/torch_model.py @@ -14,11 +14,6 @@ from anomalib.models.padim.anomaly_map import AnomalyMapGenerator from anomalib.pre_processing import Tiler -_DIMS = { - "resnet18": {"orig_dims": 448, "emb_scale": 4}, - "wide_resnet50_2": {"orig_dims": 1792, "emb_scale": 4}, -} - # defaults from the paper _N_FEATURES_DEFAULTS = { "resnet18": 100, @@ -31,6 +26,9 @@ def _deduce_dims( ) -> Tuple[int, int]: """Run a dry run to deduce the dimensions of the extracted features. + Important: `layers` is assumed to be ordered and the first (layers[0]) + is assumed to be the layer with largest resolution. + Returns: Tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches) """ @@ -72,31 +70,22 @@ def __init__( self.backbone = backbone self.layers = layers self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained) + self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers) - if backbone in _DIMS: - backbone_dims = _DIMS[backbone] - self.n_features_original = backbone_dims["orig_dims"] - emb_scale = backbone_dims["emb_scale"] - patches_dims = torch.tensor(input_size) / emb_scale - self.n_patches = patches_dims.ceil().prod().int().item() - - else: - self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers) + n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone) if n_features is None: + raise ValueError( + f"n_features must be specified for backbone {self.backbone}. " + f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}" + ) - if self.backbone in _N_FEATURES_DEFAULTS: - n_features = _N_FEATURES_DEFAULTS[self.backbone] - - else: - raise ValueError( - f"{self.__class__.__name__}.n_features must be specified for backbone {self.backbone}. " - f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}" - ) assert ( - n_features <= self.n_features_original - ), f"n_features ({n_features}) must be <= n_features_original ({self.n_features_original})" + 0 < n_features <= self.n_features_original + ), f"for backbone {self.backbone}, 0 < n_features <= {self.n_features_original}, found {n_features}" + self.n_features = n_features + # pylint: disable=not-callable # Since idx is randomly selected, save it with model to get same results self.register_buffer(