Skip to content

Commit

Permalink
🚀 Replace albumentations with torchvision transforms (#1706)
Browse files Browse the repository at this point in the history
* replace albumentations transforms with torchvision

* define model-specific transform

* include transforms in exported model

* update inferencers

* fix data types

* remove old import

* update synthetic dataset

* formatting

* address pre-commit issues

* rename methods

* add todo

* add transform arguments to datamodule

* simplify transform retrieval

* make transform optional

* update folder dataset

* read image as float in torch inferencer

* Compose -> Transform

* fix tests

* fix folder tests

* remove image_size from config

* fix visualizer tests

* Compose -> Transform

* update license headers

* disable argument linking for image size

* add new engine method for updating transform

* default_transform -> configure_transforms

* update padim

* add image_size parameter to datamodule

* improve logic in transform setup

* simplify datamodule/dataset setup

* add unit tests for transform setup

* update tests

* remove get_transforms

* remove references to old transforms

* disable antialias when exporting to onnx/openvino

* add default transforms for patchcore and efficientad

* add model-specific transform to ai-vad

* add default transform to base model

* call setup trainer in export entrypoint

* add model-specific transforms to winclip model

* add default transform for uflow model

* fix centercrop export

* fix model and tools integration tests

* update cli export test

* typing and docstrings

* make transforms argument optional in datasets

* remove commented code

* use conditional formatting

* read anomaly source images with PIL

* fix synthetic anomaly tests

* docstring

* typing and docstrings

* remove whitespace

* use LabelName

* update and use read_image

* replace cv2.imread

* Fix tests

* update config upgrade tests

* dynamically set input size parameter

* fix minor mistakes

* include transform in checkpoint

* fix get_model tests

* fix viz callback tests

* change workflow to address minor issues

* update setup_transform tests

* pass explicit arguments

* type annotation for torch model

* use hasatr

* use getattr instead of hasattr

* update unittest names

* remove input_size argument from padim model

* remove image_size argument from patchcore model

* remove image_size argument from cfa model

* remove image_size argument from cflow model

* remove image_size argument from dfa model

* remove image_size argument from efficient_ad model

* replace albumentations with torchvision in efficient_ad

* use _setup to build csflow torch model

* use _setup to build fastflow torch model

* use _setup to build ganomaly torch model

* remove input_size parameter from stfpm model

* use _setup to build revdist torch model

* use _setup to build uflow torch model

* add read_mask function

* replace to_tensor

* allow default model transform in export

* fix trainer availability check

* setup transforms in dataloaders during predict

* update cli tests

* rename updated class

* remove reference to albumentations from docstrings

* remove albumentations from conftest

* revert to using padim for predict tests

* fix cli predict on PredictDataset bug

* add image_size parameter to folder datamodule

* update notebooks

* remove albumentations from requirements

* fix torch inference test

* read mask as uint8

* create default resize transform in datamodule

* fix fastflow notebook

* pass image_size to folder datamodule in notebook

* reduce num workers in notebooks

* read numpy image in [0-1] range

* use read_image in openvino inference

* update getting_started notebook

* enable antialias in default transform

* use read_image in PredictDataset

* fix fastflow notebook

---------

Co-authored-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
djdameln and ashwinvaidya17 authored Feb 29, 2024
1 parent fa5ac1a commit 80c0756
Show file tree
Hide file tree
Showing 85 changed files with 3,238 additions and 2,164 deletions.
441 changes: 382 additions & 59 deletions notebooks/000_getting_started/001_getting_started.ipynb

Large diffs are not rendered by default.

297 changes: 249 additions & 48 deletions notebooks/100_datamodules/101_btech.ipynb

Large diffs are not rendered by default.

281 changes: 235 additions & 46 deletions notebooks/100_datamodules/102_mvtec.ipynb

Large diffs are not rendered by default.

529 changes: 478 additions & 51 deletions notebooks/100_datamodules/103_folder.ipynb

Large diffs are not rendered by default.

635 changes: 210 additions & 425 deletions notebooks/200_models/201_fastflow.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion requirements/core.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
albumentations>=1.1.0
av>=10.0.0
einops>=0.3.2
freia>=0.2
Expand Down
21 changes: 12 additions & 9 deletions src/anomalib/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
_LIGHTNING_AVAILABLE = True
try:
from lightning.pytorch import Trainer
from lightning.pytorch.core.datamodule import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from anomalib.data import AnomalibDataModule, AnomalibDataset
from anomalib.data import AnomalibDataModule
from anomalib.data.predict import PredictDataset
from anomalib.engine import Engine
from anomalib.metrics.threshold import BaseThreshold
Expand Down Expand Up @@ -158,8 +157,7 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None:
parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold")
parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False)
if hasattr(parser, "subcommand") and parser.subcommand != "predict": # Predict also accepts str and Path inputs
parser.link_arguments("data.init_args.image_size", "model.init_args.input_size")
if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"):
parser.link_arguments("task", "data.init_args.task")
parser.add_argument(
"--results_dir.path",
Expand Down Expand Up @@ -242,11 +240,10 @@ def add_export_arguments(self, parser: ArgumentParser) -> None:
fail_untyped=False,
required=True,
)
parser.add_subclass_arguments((AnomalibDataModule, AnomalibDataset), "data")
added = parser.add_method_arguments(
Engine,
"export",
skip={"mo_args", "datamodule", "dataset", "model"},
skip={"mo_args", "model"},
)
self.subcommand_method_arguments["export"] = added
add_openvino_export_arguments(parser)
Expand Down Expand Up @@ -296,8 +293,7 @@ def instantiate_classes(self) -> None:
self.config_init = self.parser.instantiate_classes(self.config)
self.datamodule = self._get(self.config_init, "data")
if isinstance(self.datamodule, Dataset):
kwargs = {f"{self.config.subcommand}_dataset": self.datamodule}
self.datamodule = LightningDataModule.from_datasets(**kwargs)
self.datamodule = DataLoader(self.datamodule)
self.model = self._get(self.config_init, "model")
self._configure_optimizers_method_to_model()
self.instantiate_engine()
Expand All @@ -308,8 +304,12 @@ def instantiate_classes(self) -> None:
self.instantiate_engine()
if "model" in self.config_init[subcommand]:
self.model = self._get(self.config_init, "model")
else:
self.model = None
if "data" in self.config_init[subcommand]:
self.datamodule = self._get(self.config_init, "data")
else:
self.datamodule = None

def instantiate_engine(self) -> None:
"""Instantiate the engine.
Expand Down Expand Up @@ -479,7 +479,10 @@ def _prepare_subcommand_kwargs(self, subcommand: str) -> dict[str, Any]:
}
fn_kwargs["model"] = self.model
if self.datamodule is not None:
fn_kwargs["datamodule"] = self.datamodule
if isinstance(self.datamodule, AnomalibDataModule):
fn_kwargs["datamodule"] = self.datamodule
elif isinstance(self.datamodule, DataLoader):
fn_kwargs["dataloaders"] = self.datamodule
return fn_kwargs

def _parser(self, subcommand: str | None) -> ArgumentParser:
Expand Down
2 changes: 2 additions & 0 deletions src/anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .depth import DepthDataFormat, Folder3D, MVTec3D
from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa
from .predict import PredictDataset
from .utils import LabelName
from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -66,4 +67,5 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
"UCSDped",
"ShanghaiTech",
"Visa",
"LabelName",
]
78 changes: 67 additions & 11 deletions src/anomalib/data/base/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Anomalib datamodule base class."""

# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import logging
from abc import ABC
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from lightning.pytorch import LightningDataModule
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data.dataloader import DataLoader, default_collate
from torchvision.transforms.v2 import Resize, Transform

from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label
from anomalib.data.utils.synthetic import SyntheticAnomalyDataset
Expand Down Expand Up @@ -61,6 +62,14 @@ class AnomalibDataModule(LightningDataModule, ABC):
Defaults to ``None``.
test_split_ratio (float): Fraction of the train images held out for testing.
Defaults to ``None``.
image_size (tuple[int, int], optional): Size to which input images should be resized.
Defaults to ``None``.
transform (Transform, optional): Transforms that should be applied to the input images.
Defaults to ``None``.
train_transform (Transform, optional): Transforms that should be applied to the input images during training.
Defaults to ``None``.
eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
Defaults to ``None``.
seed (int | None, optional): Seed used during random subset splitting.
Defaults to ``None``.
"""
Expand All @@ -74,6 +83,10 @@ def __init__(
val_split_ratio: float,
test_split_mode: TestSplitMode | str | None = None,
test_split_ratio: float | None = None,
image_size: tuple[int, int] | None = None,
transform: Transform | None = None,
train_transform: Transform | None = None,
eval_transform: Transform | None = None,
seed: int | None = None,
) -> None:
super().__init__()
Expand All @@ -84,8 +97,13 @@ def __init__(
self.test_split_ratio = test_split_ratio
self.val_split_mode = ValSplitMode(val_split_mode)
self.val_split_ratio = val_split_ratio
self.image_size = image_size
self.seed = seed

# set transforms
self._train_transform = train_transform or transform
self._eval_transform = eval_transform or transform

self.train_data: AnomalibDataset
self.val_data: AnomalibDataset
self.test_data: AnomalibDataset
Expand All @@ -101,8 +119,11 @@ def setup(self, stage: str | None = None) -> None:
"""
if not self.is_setup:
self._setup(stage)
self._create_test_split()
self._create_val_split()
assert self.is_setup

@abstractmethod
def _setup(self, _stage: str | None = None) -> None:
"""Set up the datasets and perform dynamic subset splitting.
Expand All @@ -115,14 +136,7 @@ def _setup(self, _stage: str | None = None) -> None:
the test set must therefore be created as early as the `fit` stage.
"""
assert self.train_data is not None
assert self.test_data is not None

self.train_data.setup()
self.test_data.setup()

self._create_test_split()
self._create_val_split()
raise NotImplementedError

def _create_test_split(self) -> None:
"""Obtain the test set based on the settings in the config."""
Expand Down Expand Up @@ -176,7 +190,7 @@ def is_setup(self) -> bool:
"""
_is_setup: bool = False
for data in ("train_data", "val_data", "test_data"):
if hasattr(self, data) and getattr(self, data).is_setup:
if hasattr(self, data):
_is_setup = True

return _is_setup
Expand Down Expand Up @@ -213,3 +227,45 @@ def test_dataloader(self) -> EVAL_DATALOADERS:
def predict_dataloader(self) -> EVAL_DATALOADERS:
"""Use the test dataloader for inference unless overridden."""
return self.test_dataloader()

@property
def transform(self) -> Transform:
"""Property that returns the user-specified transform for the datamodule, if any.
This property is accessed by the engine to set the transform for the model. The eval_transform takes precedence
over the train_transform, because the transform that we store in the model is the one that should be used during
inference.
"""
if self._eval_transform:
return self._eval_transform
if self._train_transform:
return self._train_transform
return None

@property
def train_transform(self) -> Transform:
"""Get the transforms that will be passed to the train dataset.
If the train_transform is not set, the engine will request the transform from the model.
"""
if self._train_transform:
return self._train_transform
if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
return self.trainer.model.transform
if self.image_size:
return Resize(self.image_size, antialias=True)
return None

@property
def eval_transform(self) -> Transform:
"""Get the transform that will be passed to the val/test/predict datasets.
If the eval_transform is not set, the engine will request the transform from the model.
"""
if self._eval_transform:
return self._eval_transform
if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
return self.trainer.model.transform
if self.image_size:
return Resize(self.image_size, antialias=True)
return None
Loading

0 comments on commit 80c0756

Please sign in to comment.