diff --git a/anomalib/models/cflow/torch_model.py b/anomalib/models/cflow/torch_model.py index 2d22d15631..23d3c5d766 100644 --- a/anomalib/models/cflow/torch_model.py +++ b/anomalib/models/cflow/torch_model.py @@ -18,7 +18,6 @@ import einops import torch -import torchvision from torch import nn from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator @@ -44,13 +43,13 @@ def __init__( ): super().__init__() - self.backbone = getattr(torchvision.models, backbone) + self.backbone = backbone self.fiber_batch_size = fiber_batch_size self.condition_vector: int = condition_vector self.dec_arch = decoder self.pool_layers = layers - self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.pool_layers) + self.encoder = FeatureExtractor(backbone=self.backbone, layers=self.pool_layers, pre_trained=pre_trained) self.pool_dims = self.encoder.out_dims self.decoders = nn.ModuleList( [ diff --git a/anomalib/models/components/feature_extractors/feature_extractor.py b/anomalib/models/components/feature_extractors/feature_extractor.py index f616d23280..024e363506 100644 --- a/anomalib/models/components/feature_extractors/feature_extractor.py +++ b/anomalib/models/components/feature_extractors/feature_extractor.py @@ -17,8 +17,10 @@ # See the License for the specific language governing permissions # and limitations under the License. 
-from typing import Callable, Dict, Iterable +import warnings +from typing import Dict, List +import timm import torch from torch import Tensor, nn @@ -32,10 +34,9 @@ class FeatureExtractor(nn.Module): Example: >>> import torch - >>> import torchvision >>> from anomalib.core.model.feature_extractor import FeatureExtractor - >>> model = FeatureExtractor(model=torchvision.models.resnet18(), layers=['layer1', 'layer2', 'layer3']) + >>> model = FeatureExtractor(backbone="resnet18", layers=['layer1', 'layer2', 'layer3']) >>> input = torch.rand((32, 3, 256, 256)) >>> features = model(input) @@ -45,42 +46,46 @@ class FeatureExtractor(nn.Module): [torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])] """ - def __init__(self, backbone: nn.Module, layers: Iterable[str]): + def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True): super().__init__() self.backbone = backbone self.layers = layers + self.idx = self._map_layer_to_idx() + self.feature_extractor = timm.create_model( + backbone, + pretrained=pre_trained, + features_only=True, + exportable=True, + out_indices=self.idx, + ) + self.out_dims = self.feature_extractor.feature_info.channels() self._features = {layer: torch.empty(0) for layer in self.layers} - self.out_dims = [] - for layer_id in layers: - layer = dict([*self.backbone.named_modules()])[layer_id] - layer.register_forward_hook(self.get_features(layer_id)) - # get output dimension of features if available - layer_modules = [*layer.modules()] - for idx in reversed(range(len(layer_modules))): - if hasattr(layer_modules[idx], "out_channels"): - self.out_dims.append(layer_modules[idx].out_channels) - break - - def get_features(self, layer_id: str) -> Callable: - """Get layer features. + def _map_layer_to_idx(self, offset: int = 3) -> List[int]: + """Maps set of layer names to indices of model. 
Args: - layer_id (str): Layer ID + offset (int): `timm` ignores the first few layers when indexing; update the offset if needed. Returns: - Layer features + List of indices of the requested layers within the model. """ - - def hook(_, __, output): - """Hook to extract features via a forward-pass. - - Args: - output: Feature map collected after the forward-pass. - """ - self._features[layer_id] = output - - return hook + idx = [] + features = timm.create_model( + self.backbone, + pretrained=False, + features_only=False, + exportable=True, + ) + for i in self.layers: + try: + idx.append(list(dict(features.named_children()).keys()).index(i) - offset) + except ValueError: + warnings.warn(f"Layer {i} not found in model {self.backbone}") + # NOTE(review): removing from self.layers while iterating over it skips the next entry; iterate over a copy + self.layers.remove(i) + + return idx def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]: """Forward-pass input tensor into the CNN. @@ -91,6 +96,5 @@ def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]: Returns: Feature map extracted from the CNN """ - self._features = {layer: torch.empty(0) for layer in self.layers} - _ = self.backbone(input_tensor) - return self._features + features = dict(zip(self.layers, self.feature_extractor(input_tensor))) + return features diff --git a/anomalib/models/dfkde/config.yaml b/anomalib/models/dfkde/config.yaml index 83a8ffadb7..5c384e8bb1 100644 --- a/anomalib/models/dfkde/config.yaml +++ b/anomalib/models/dfkde/config.yaml @@ -25,7 +25,7 @@ model: threshold_offset: 12 normalization_method: min_max # options: [null, min_max, cdf] layers: - - avgpool + - layer4 metrics: image: diff --git a/anomalib/models/dfkde/torch_model.py b/anomalib/models/dfkde/torch_model.py index 18b6a2b694..4d6a5dfcd2 100644 --- a/anomalib/models/dfkde/torch_model.py +++ b/anomalib/models/dfkde/torch_model.py @@ -19,7 +19,7 @@ from typing import List, Optional, Tuple import torch -import torchvision +import torch.nn.functional as F from torch import Tensor, nn from 
anomalib.models.components import PCA, FeatureExtractor, GaussianKDE @@ -59,8 +59,8 @@ def __init__( self.threshold_steepness = threshold_steepness self.threshold_offset = threshold_offset - _backbone = getattr(torchvision.models, backbone) - self.feature_extractor = FeatureExtractor(backbone=_backbone(pretrained=pre_trained), layers=layers).eval() + _backbone = backbone + self.feature_extractor = FeatureExtractor(backbone=_backbone, pre_trained=pre_trained, layers=layers).eval() self.pca_model = PCA(n_components=self.n_components) self.kde_model = GaussianKDE() @@ -79,6 +79,10 @@ def get_features(self, batch: Tensor) -> Tensor: """ self.feature_extractor.eval() layer_outputs = self.feature_extractor(batch) + for layer in layer_outputs: + batch_size = len(layer_outputs[layer]) + layer_outputs[layer] = F.adaptive_avg_pool2d(input=layer_outputs[layer], output_size=(1, 1)) + layer_outputs[layer] = layer_outputs[layer].view(batch_size, -1) layer_outputs = torch.cat(list(layer_outputs.values())).detach() return layer_outputs diff --git a/anomalib/models/dfm/torch_model.py b/anomalib/models/dfm/torch_model.py index 89890fc301..f22151dbae 100644 --- a/anomalib/models/dfm/torch_model.py +++ b/anomalib/models/dfm/torch_model.py @@ -18,7 +18,6 @@ import torch import torch.nn.functional as F -import torchvision from torch import Tensor, nn from anomalib.models.components import PCA, DynamicBufferModule, FeatureExtractor @@ -103,13 +102,15 @@ def __init__( score_type: str = "fre", ): super().__init__() - self.backbone = getattr(torchvision.models, backbone) + self.backbone = backbone self.pooling_kernel_size = pooling_kernel_size self.n_components = n_comps self.pca_model = PCA(n_components=self.n_components) self.gaussian_model = SingleClassGaussian() self.score_type = score_type - self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=[layer]).eval() + self.feature_extractor = FeatureExtractor( + backbone=self.backbone, 
pre_trained=pre_trained, layers=[layer] + ).eval() def fit(self, dataset: Tensor) -> None: """Fit a pca transformation and a Gaussian model to dataset. diff --git a/anomalib/models/padim/torch_model.py b/anomalib/models/padim/torch_model.py index fb00f0057b..cf918fdf41 100644 --- a/anomalib/models/padim/torch_model.py +++ b/anomalib/models/padim/torch_model.py @@ -19,7 +19,6 @@ import torch import torch.nn.functional as F -import torchvision from torch import Tensor, nn from anomalib.models.components import FeatureExtractor, MultiVariateGaussian @@ -52,9 +51,9 @@ def __init__( super().__init__() self.tiler: Optional[Tiler] = None - self.backbone = getattr(torchvision.models, backbone) + self.backbone = backbone self.layers = layers - self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.layers) + self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=layers, pre_trained=pre_trained) self.dims = DIMS[backbone] # pylint: disable=not-callable # Since idx is randomly selected, save it with model to get same results @@ -109,7 +108,6 @@ def forward(self, input_tensor: Tensor) -> Tensor: output = self.anomaly_map_generator( embedding=embeddings, mean=self.gaussian.mean, inv_covariance=self.gaussian.inv_covariance ) - return output def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor: diff --git a/anomalib/models/patchcore/torch_model.py b/anomalib/models/patchcore/torch_model.py index b0a9129a93..c7fda1184e 100644 --- a/anomalib/models/patchcore/torch_model.py +++ b/anomalib/models/patchcore/torch_model.py @@ -18,7 +18,6 @@ import torch import torch.nn.functional as F -import torchvision from torch import Tensor, nn from anomalib.models.components import ( @@ -44,12 +43,12 @@ def __init__( super().__init__() self.tiler: Optional[Tiler] = None - self.backbone = getattr(torchvision.models, backbone) + self.backbone = backbone self.layers = layers self.input_size = input_size self.num_neighbors 
= num_neighbors - self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=pre_trained), layers=self.layers) + self.feature_extractor = FeatureExtractor(backbone=self.backbone, pre_trained=pre_trained, layers=self.layers) self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1) self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size) diff --git a/anomalib/models/reverse_distillation/torch_model.py b/anomalib/models/reverse_distillation/torch_model.py index 2177b46bd0..f9cbbe71dd 100644 --- a/anomalib/models/reverse_distillation/torch_model.py +++ b/anomalib/models/reverse_distillation/torch_model.py @@ -16,7 +16,6 @@ from typing import List, Optional, Tuple, Union -import torchvision from torch import Tensor, nn from anomalib.models.components import FeatureExtractor @@ -50,9 +49,8 @@ def __init__( super().__init__() self.tiler: Optional[Tiler] = None - encoder_backbone = getattr(torchvision.models, backbone) - # TODO replace with TIMM feature extractor - self.encoder = FeatureExtractor(backbone=encoder_backbone(pretrained=pre_trained), layers=layers) + encoder_backbone = backbone + self.encoder = FeatureExtractor(backbone=encoder_backbone, pre_trained=pre_trained, layers=layers) self.bottleneck = get_bottleneck_layer(backbone) self.decoder = get_decoder(backbone) diff --git a/anomalib/models/stfpm/torch_model.py b/anomalib/models/stfpm/torch_model.py index ca3485812f..414a495079 100644 --- a/anomalib/models/stfpm/torch_model.py +++ b/anomalib/models/stfpm/torch_model.py @@ -16,7 +16,6 @@ from typing import Dict, List, Optional, Tuple -import torchvision from torch import Tensor, nn from anomalib.models.components import FeatureExtractor @@ -42,9 +41,9 @@ def __init__( super().__init__() self.tiler: Optional[Tiler] = None - self.backbone = getattr(torchvision.models, backbone) - self.teacher_model = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=layers) - self.student_model = 
FeatureExtractor(backbone=self.backbone(pretrained=False), layers=layers) + self.backbone = backbone + self.teacher_model = FeatureExtractor(backbone=self.backbone, pre_trained=True, layers=layers) + self.student_model = FeatureExtractor(backbone=self.backbone, pre_trained=False, layers=layers) # teacher model is fixed for parameters in self.teacher_model.parameters(): diff --git a/configs/model/dfkde.yaml b/configs/model/dfkde.yaml index 29f5da4159..3d7fa26f49 100644 --- a/configs/model/dfkde.yaml +++ b/configs/model/dfkde.yaml @@ -25,7 +25,7 @@ model: backbone: resnet18 pre_trained: true layers: - - avgpool + - layer4 max_training_points: 40000 pre_processing: scale n_components: 16 diff --git a/tests/pre_merge/models/test_feature_extractor.py b/tests/pre_merge/models/test_feature_extractor.py new file mode 100644 index 0000000000..7355e8fb2f --- /dev/null +++ b/tests/pre_merge/models/test_feature_extractor.py @@ -0,0 +1,35 @@ +import pytest +import torch + +from anomalib.models.components.feature_extractors import FeatureExtractor + + +class TestFeatureExtractor: + @pytest.mark.parametrize( + "backbone", + ["resnet18", "wide_resnet50_2"], + ) + @pytest.mark.parametrize( + "pretrained", + [True, False], + ) + def test_feature_extraction(self, backbone, pretrained): + layers = ["layer1", "layer2", "layer3"] + model = FeatureExtractor(backbone=backbone, layers=layers, pre_trained=pretrained) + test_input = torch.rand((32, 3, 256, 256)) + features = model(test_input) + + if backbone == "resnet18": + assert features["layer1"].shape == torch.Size((32, 64, 64, 64)) + assert features["layer2"].shape == torch.Size((32, 128, 32, 32)) + assert features["layer3"].shape == torch.Size((32, 256, 16, 16)) + assert model.out_dims == [64, 128, 256] + assert model.idx == [1, 2, 3] + elif backbone == "wide_resnet50_2": + assert features["layer1"].shape == torch.Size((32, 256, 64, 64)) + assert features["layer2"].shape == torch.Size((32, 512, 32, 32)) + assert 
features["layer3"].shape == torch.Size((32, 1024, 16, 16)) + assert model.out_dims == [256, 512, 1024] + assert model.idx == [1, 2, 3] + else: + pass