diff --git a/anomalib/models/components/feature_extractors/feature_extractor.py b/anomalib/models/components/feature_extractors/feature_extractor.py
index 6469b37d39..81c13e9c99 100644
--- a/anomalib/models/components/feature_extractors/feature_extractor.py
+++ b/anomalib/models/components/feature_extractors/feature_extractor.py
@@ -20,6 +20,10 @@ class FeatureExtractor(nn.Module):
     Args:
         backbone (nn.Module): The backbone to which the feature extraction hooks are attached.
         layers (Iterable[str]): List of layer names of the backbone to which the hooks are attached.
+        pre_trained (bool): Whether to use a pre-trained backbone. Defaults to True.
+        requires_grad (bool): Whether to require gradients for the backbone. Defaults to False.
+            Models like ``stfpm`` use the feature extractor model as a trainable network. In such cases gradient
+            computation is required.

     Example:
         >>> import torch
@@ -35,11 +39,12 @@ class FeatureExtractor(nn.Module):
         [torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])]
     """

-    def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True):
+    def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True, requires_grad: bool = False):
         super().__init__()
         self.backbone = backbone
         self.layers = layers
         self.idx = self._map_layer_to_idx()
+        self.requires_grad = requires_grad
         self.feature_extractor = timm.create_model(
             backbone,
             pretrained=pre_trained,
@@ -85,5 +90,11 @@ def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
         Returns:
             Feature map extracted from the CNN
         """
-        features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
+        if self.requires_grad:
+            features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
+        else:
+            self.feature_extractor.eval()
+            with torch.no_grad():
+                features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
+
         return features
diff --git a/anomalib/models/stfpm/torch_model.py b/anomalib/models/stfpm/torch_model.py
index 669786f45c..8a7da966e0 100644
--- a/anomalib/models/stfpm/torch_model.py
+++ b/anomalib/models/stfpm/torch_model.py
@@ -32,7 +32,9 @@ def __init__(

         self.backbone = backbone
         self.teacher_model = FeatureExtractor(backbone=self.backbone, pre_trained=True, layers=layers)
-        self.student_model = FeatureExtractor(backbone=self.backbone, pre_trained=False, layers=layers)
+        self.student_model = FeatureExtractor(
+            backbone=self.backbone, pre_trained=False, layers=layers, requires_grad=True
+        )

         # teacher model is fixed
         for parameters in self.teacher_model.parameters():
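
A minimal usage sketch of the new `requires_grad` flag (the backbone name and layer list follow the docstring example above; the import path is assumed from the file path in the patch and may differ if the package re-exports `FeatureExtractor`):

```python
import torch

from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor

# Frozen extractor: requires_grad defaults to False, so forward() switches the
# backbone to eval mode and runs under torch.no_grad(); the returned feature
# maps carry no autograd graph.
teacher = FeatureExtractor(backbone="resnet18", pre_trained=True, layers=["layer1", "layer2", "layer3"])

# Trainable extractor, as used for STFPM's student network: gradients flow
# through the timm backbone.
student = FeatureExtractor(
    backbone="resnet18", pre_trained=False, layers=["layer1", "layer2", "layer3"], requires_grad=True
)

batch = torch.rand(8, 3, 256, 256)
teacher_features = teacher(batch)  # Dict[str, Tensor]; requires_grad is False on every map
student_features = student(batch)  # Dict[str, Tensor]; requires_grad is True on every map
```

Beyond freezing the weights, routing the default path through `torch.no_grad()` means no intermediate activations are retained for backward on the teacher branch, which trims memory during STFPM training.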