Skip to content

Commit

Permalink
✨ CS-Flow model (#657)
Browse files Browse the repository at this point in the history
* Use FrEIA from pip + add initial cs-flow model

* Explicitly freeze feature extractor weights

* Fix pre-commit

* Fix pylint issues

* Fix pylint issues

* Rename variable

* Remove requires grad from timm

* Add csflow to tests

* Support two map modes

* Add metrics for cs-flow

* Use the same betas as in paper

* Add model description + images

* TimmFeatureExtractor->FeatureExtractor

* Convert torchfx feature extractor to nn.Module

* Migrate to different markdown linter + format files based on it

* Initial PR feedback changes

* refactor torch_model

* Address changes in lightning_model

* Fix tests

* Fix parameters

* Comments + minor refactor

* Refactor method name

* Add comments to cross-conv block

* Add loss logging to csflow

Co-authored-by: Ashwin Vaidya <[email protected]>
Co-authored-by: Samet Akcay <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2022
1 parent f5893b7 commit 997272a
Show file tree
Hide file tree
Showing 29 changed files with 1,030 additions and 585 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:
hooks:
- id: flake8
args: ["--max-line-length=120", "--ignore=E203,W503"]
exclude: "tests|anomalib/models/components/freia"
exclude: "tests"

# python linting
- repo: https://github.com/PyCQA/pylint
Expand All @@ -42,15 +42,15 @@ repos:
entry: pylint --score=no
language: system
types: [python]
exclude: "tests|docs|anomalib/models/components/freia"
exclude: "tests|docs"

# python static type checking
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v0.971"
hooks:
- id: mypy
additional_dependencies: [types-PyYAML]
exclude: "tests|anomalib/models/components/freia"
exclude: "tests"

- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
Expand All @@ -61,7 +61,7 @@ repos:
entry: pydocstyle
language: python
types: [python]
exclude: "tests|docs|anomalib/models/components/freia"
exclude: "tests|docs"

# notebooks.
- repo: https://github.com/nbQA-dev/nbQA
Expand Down
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -317,13 +317,16 @@ For more details see the [Discussion forum](https://github.com/openvinotoolkit/a
MVTec AD dataset is one of the main benchmarks for anomaly detection, and is released under the
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License [(CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/).

> Note: These metrics are collected with image size of 256 and seed `42`. This common setting is used to make model comparisons fair.

## Image-Level AUC

| Model | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| ------------- | ------------------ | :-------: | :-------: | :-------: | :-----: | :-------: | :-------: | :-----: | :-------: | :-------: | :------: | :-------: | :-------: | :-------: | :--------: | :--------: | :-------: |
| **PatchCore** | **Wide ResNet-50** | **0.980** | 0.984 | 0.959 | 1.000 | **1.000** | 0.989 | 1.000 | **0.990** | **0.982** | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | **1.000** | 0.982 |
| PatchCore | ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | **0.943** | 0.931 | 0.996 | 0.953 |
| CFlow | Wide ResNet-50 | 0.962 | 0.986 | 0.962 | **1.0** | 0.999 | **0.993** | **1.0** | 0.893 | 0.945 | **1.0** | **0.995** | 0.924 | 0.908 | 0.897 | 0.943 | **0.984** |
| CS-Flow | EfficientNet-B5 | 0.972 | 0.995 | 0.982 | **1** | 0.972 | 0.988 | **1** | 0.97 | 0.907 | 0.995 | 0.972 | 0.953 | 0.896 | 0.969 | 0.987 | 0.987 |
| PaDiM | Wide ResNet-50 | 0.950 | **0.995** | 0.942 | 1.0 | 0.974 | **0.993** | 0.999 | 0.878 | 0.927 | 0.964 | 0.989 | **0.939** | 0.845 | 0.942 | 0.976 | 0.882 |
| PaDiM | ResNet-18 | 0.891 | 0.945 | 0.857 | 0.982 | 0.950 | 0.976 | 0.994 | 0.844 | 0.901 | 0.750 | 0.961 | 0.863 | 0.759 | 0.889 | 0.920 | 0.780 |
| STFPM | Wide ResNet-50 | 0.876 | 0.957 | 0.977 | 0.981 | 0.976 | 0.939 | 0.987 | 0.878 | 0.732 | 0.995 | 0.973 | 0.652 | 0.825 | 0.5 | 0.875 | 0.899 |
Expand All @@ -334,13 +337,14 @@ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
| DFKDE | ResNet-18 | 0.762 | 0.646 | 0.577 | 0.669 | 0.965 | 0.863 | 0.951 | 0.751 | 0.698 | 0.806 | 0.729 | 0.607 | 0.694 | 0.767 | 0.839 | 0.866 |
| GANomaly | | 0.421 | 0.203 | 0.404 | 0.413 | 0.408 | 0.744 | 0.251 | 0.457 | 0.682 | 0.537 | 0.270 | 0.472 | 0.231 | 0.372 | 0.440 | 0.434 |

### Pixel-Level AUC
## Pixel-Level AUC

| Model | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| ------------- | ------------------ | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :-------: | :--------: | :--------: | :-------: |
| **PatchCore** | **Wide ResNet-50** | **0.980** | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | **0.988** | **0.988** | 0.987 | **0.989** | 0.980 | **0.989** | 0.988 | **0.981** | 0.983 |
| PatchCore | ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |
| CFlow | Wide ResNet-50 | 0.971 | 0.986 | 0.968 | 0.993 | **0.968** | 0.924 | 0.981 | 0.955 | **0.988** | **0.990** | 0.982 | **0.983** | 0.979 | 0.985 | 0.897 | 0.980 |
| CS-Flow | EfficientNet B5 | 0.845 | 0.847 | 0.746 | 0.851 | 0.775 | 0.677 | 0.853 | 0.863 | 0.882 | 0.895 | 0.932 | 0.92 | 0.779 | 0.892 | 0.96 | 0.803 |
| PaDiM | Wide ResNet-50 | 0.979 | **0.991** | 0.970 | 0.993 | 0.955 | **0.957** | **0.985** | 0.970 | **0.988** | 0.985 | 0.982 | 0.966 | 0.988 | **0.991** | 0.976 | **0.986** |
| PaDiM | ResNet-18 | 0.968 | 0.984 | 0.918 | **0.994** | 0.934 | 0.947 | 0.983 | 0.965 | 0.984 | 0.978 | 0.970 | 0.957 | 0.978 | 0.988 | 0.968 | 0.979 |
| STFPM | Wide ResNet-50 | 0.903 | 0.987 | **0.989** | 0.980 | 0.966 | 0.956 | 0.966 | 0.913 | 0.956 | 0.974 | 0.961 | 0.946 | 0.988 | 0.178 | 0.807 | 0.980 |
Expand All @@ -353,6 +357,7 @@ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License
| **PatchCore** | **Wide ResNet-50** | **0.976** | 0.971 | 0.974 | **1.000** | **1.000** | 0.967 | **1.000** | 0.968 | **0.982** | **1.000** | 0.984 | 0.940 | 0.943 | 0.938 | **1.000** | **0.979** |
| PatchCore | ResNet-18 | 0.970 | 0.949 | 0.946 | **1.000** | 0.98 | **0.992** | **1.000** | **0.978** | 0.969 | **1.000** | **0.989** | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |
| CFlow | Wide ResNet-50 | 0.944 | 0.972 | 0.932 | **1.0** | 0.988 | 0.967 | **1.0** | 0.832 | 0.939 | **1.0** | 0.979 | 0.924 | **0.971** | 0.870 | 0.818 | 0.967 |
| CS-Flow | EfficientNet B5 | 0.965 | 0.983 | 0.982 | **1** | 0.957 | 0.966 | **1** | 0.945 | 0.944 | 0.986 | 0.963 | 0.965 | 0.906 | 0.949 | 0.938 | 0.987 |
| PaDiM | Wide ResNet-50 | 0.951 | **0.989** | 0.930 | **1.0** | 0.960 | 0.983 | 0.992 | 0.856 | **0.982** | 0.937 | 0.978 | **0.946** | 0.895 | 0.952 | 0.914 | 0.947 |
| PaDiM | ResNet-18 | 0.916 | 0.930 | 0.893 | 0.984 | 0.934 | 0.952 | 0.976 | 0.858 | 0.960 | 0.836 | 0.974 | 0.932 | 0.879 | 0.923 | 0.796 | 0.915 |
| STFPM | Wide ResNet-50 | 0.926 | 0.973 | 0.973 | 0.974 | 0.965 | 0.929 | 0.976 | 0.853 | 0.920 | 0.972 | 0.974 | 0.922 | 0.884 | 0.833 | 0.815 | 0.931 |
Expand Down
3 changes: 3 additions & 0 deletions anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from anomalib.models.cflow import Cflow
from anomalib.models.components import AnomalyModule
from anomalib.models.csflow import Csflow
from anomalib.models.dfkde import Dfkde
from anomalib.models.dfm import Dfm
from anomalib.models.draem import Draem
Expand All @@ -25,6 +26,7 @@

__all__ = [
"Cflow",
"Csflow",
"Dfkde",
"Dfm",
"Draem",
Expand Down Expand Up @@ -73,6 +75,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:

model_list: List[str] = [
"cflow",
"csflow",
"dfkde",
"dfm",
"draem",
Expand Down
5 changes: 2 additions & 3 deletions anomalib/models/cflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@

import numpy as np
import torch
from FrEIA.framework import SequenceINN
from FrEIA.modules import AllInOneBlock
from torch import nn

from anomalib.models.components.freia.framework import SequenceINN
from anomalib.models.components.freia.modules import AllInOneBlock

logger = logging.getLogger(__name__)


Expand Down
5 changes: 3 additions & 2 deletions anomalib/models/components/feature_extractors/torchfx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class BackboneParams:
init_args: Dict = field(default_factory=dict)


class TorchFXFeatureExtractor:
class TorchFXFeatureExtractor(nn.Module):
"""Extract features from a CNN.
Args:
Expand Down Expand Up @@ -74,6 +74,7 @@ def __init__(
weights: Optional[Union[WeightsEnum, str]] = None,
requires_grad: bool = False,
):
super().__init__()
if isinstance(backbone, dict):
backbone = BackboneParams(**backbone)
elif not isinstance(backbone, BackboneParams): # if str or nn.Module
Expand Down Expand Up @@ -169,6 +170,6 @@ def _get_backbone_class(backbone: str) -> Callable[..., nn.Module]:

return backbone_class

def __call__(self, inputs: Tensor) -> Dict[str, Tensor]:
def forward(self, inputs: Tensor) -> Dict[str, Tensor]:
"""Extract features from the input."""
return self.feature_extractor(inputs)
7 changes: 0 additions & 7 deletions anomalib/models/components/freia/README.md

This file was deleted.

16 changes: 0 additions & 16 deletions anomalib/models/components/freia/__init__.py

This file was deleted.

9 changes: 0 additions & 9 deletions anomalib/models/components/freia/framework/__init__.py

This file was deleted.

120 changes: 0 additions & 120 deletions anomalib/models/components/freia/framework/sequence_inn.py

This file was deleted.

10 changes: 0 additions & 10 deletions anomalib/models/components/freia/modules/__init__.py

This file was deleted.

Loading

0 comments on commit 997272a

Please sign in to comment.