Skip to content

Commit

Permalink
n_features an arg and deduce other dimensions dinamically if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
jpcbertoldo committed Nov 1, 2022
1 parent 2f0a87c commit db29ba9
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
5 changes: 4 additions & 1 deletion anomalib/models/padim/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from omegaconf import DictConfig, ListConfig
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
input_size: Tuple[int, int],
backbone: str,
pre_trained: bool = True,
n_features: Optional[int] = None,
):
super().__init__()

Expand All @@ -48,6 +49,7 @@ def __init__(
backbone=backbone,
pre_trained=pre_trained,
layers=layers,
n_features=n_features,
).eval()

self.stats: List[Tensor] = []
Expand Down Expand Up @@ -119,6 +121,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
layers=hparams.model.layers,
backbone=hparams.model.backbone,
pre_trained=hparams.model.pre_trained,
n_features=hparams.model.n_features if "n_features" in hparams.model else None,
)
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(hparams)
68 changes: 59 additions & 9 deletions anomalib/models/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,39 @@
from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler

DIMS = {
"resnet18": {"orig_dims": 448, "reduced_dims": 100, "emb_scale": 4},
"wide_resnet50_2": {"orig_dims": 1792, "reduced_dims": 550, "emb_scale": 4},
_DIMS = {
"resnet18": {"orig_dims": 448, "emb_scale": 4},
"wide_resnet50_2": {"orig_dims": 1792, "emb_scale": 4},
}

# defaults from the paper
_N_FEATURES_DEFAULTS = {
"resnet18": 100,
"wide_resnet50_2": 550,
}


def _deduce_dims(
feature_extractor: FeatureExtractor, input_size: Tuple[int, int], layers: List[str]
) -> Tuple[int, int]:
"""Run a dry run to deduce the dimensions of the extracted features.
Returns:
Tuple[int, int]: Dimensions of the extracted features: (n_dims_original, n_patches)
"""

dryrun_input = torch.empty(1, 3, *input_size)
dryrun_features = feature_extractor(dryrun_input)

# the first layer in `layers` is the largest spatial size
dryrun_emb_first_layer = dryrun_features[layers[0]]
n_patches = torch.tensor(dryrun_emb_first_layer.shape[-2:]).prod().int().item()

# the original embedding size is the sum of the channels of all layers
n_features_original = sum(dryrun_features[layer].shape[1] for layer in layers)

return n_features_original, n_patches


class PadimModel(nn.Module):
"""Padim Module.
Expand All @@ -36,28 +64,50 @@ def __init__(
layers: List[str],
backbone: str = "resnet18",
pre_trained: bool = True,
n_features: Optional[int] = None,
):
super().__init__()
self.tiler: Optional[Tiler] = None

self.backbone = backbone
self.layers = layers
self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained)
self.dims = DIMS[backbone]

if backbone in _DIMS:
backbone_dims = _DIMS[backbone]
self.n_features_original = backbone_dims["orig_dims"]
emb_scale = backbone_dims["emb_scale"]
patches_dims = torch.tensor(input_size) / emb_scale
self.n_patches = patches_dims.ceil().prod().int().item()

else:
self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)

if n_features is None:

if self.backbone in _N_FEATURES_DEFAULTS:
n_features = _N_FEATURES_DEFAULTS[self.backbone]

else:
raise ValueError(
f"{self.__class__.__name__}.n_features must be specified for backbone {self.backbone}. "
f"Default values are available for: {sorted(_N_FEATURES_DEFAULTS.keys())}"
)
assert (
n_features <= self.n_features_original
), f"n_features ({n_features}) must be <= n_features_original ({self.n_features_original})"
self.n_features = n_features
# pylint: disable=not-callable
# Since idx is randomly selected, save it with model to get same results
self.register_buffer(
"idx",
torch.tensor(sample(range(0, DIMS[backbone]["orig_dims"]), DIMS[backbone]["reduced_dims"])),
torch.tensor(sample(range(0, self.n_features_original), self.n_features)),
)
self.idx: Tensor
self.loss = None
self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)

n_features = DIMS[backbone]["reduced_dims"]
patches_dims = torch.tensor(input_size) / DIMS[backbone]["emb_scale"]
n_patches = patches_dims.ceil().prod().int().item()
self.gaussian = MultiVariateGaussian(n_features, n_patches)
self.gaussian = MultiVariateGaussian(self.n_features, self.n_patches)

def forward(self, input_tensor: Tensor) -> Tensor:
"""Forward-pass image-batch (N, C, H, W) into model to extract features.
Expand Down

0 comments on commit db29ba9

Please sign in to comment.