From 644cb3376202d3d8f1b229a27c9533594157f1a3 Mon Sep 17 00:00:00 2001 From: Kang Wenjing Date: Thu, 27 Apr 2023 23:23:19 +0800 Subject: [PATCH] [Refactor: 644] Refactor strings and ints into enum.Enum This tries to refactor strings and ints into enum.Enum adding: 1. LabelName 2. DirType 3. DataFormat 4. AnomalyMapGenerationMode 5. ImageUpscaleMode 6. VisualizationMode Fixes #644 Signed-off-by: FanJiangIntel Signed-off-by: Kang Wenjing --- src/anomalib/data/__init__.py | 33 +++++++++---- src/anomalib/data/btech.py | 5 +- src/anomalib/data/folder.py | 28 ++++++----- src/anomalib/data/folder_3d.py | 46 ++++++++++--------- src/anomalib/data/mvtec.py | 11 +++-- src/anomalib/data/mvtec_3d.py | 9 ++-- src/anomalib/data/utils/__init__.py | 5 +- src/anomalib/data/utils/label.py | 8 ++++ src/anomalib/data/utils/path.py | 13 ++++++ .../reverse_distillation/anomaly_map.py | 31 +++++++++---- .../reverse_distillation/lightning_model.py | 3 +- .../reverse_distillation/torch_model.py | 4 +- src/anomalib/post_processing/visualizer.py | 18 ++++++-- src/anomalib/pre_processing/tiler.py | 22 ++++++--- tests/pre_merge/datasets/test_dataset.py | 12 ++--- 15 files changed, 167 insertions(+), 81 deletions(-) create mode 100644 src/anomalib/data/utils/label.py diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index c2f3be9b03..3b4b8b0a50 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging +from enum import Enum from omegaconf import DictConfig, ListConfig @@ -25,6 +26,20 @@ logger = logging.getLogger(__name__) +class DataFormat(str, Enum): + """Supported Dataset Types""" + + MVTEC = "mvtec" + MVTEC_3D = "mvtec_3d" + BTECH = "btech" + FOLDER = "folder" + FOLDER_3D = "folder_3d" + UCSDPED = "ucsdped" + AVENUE = "avenue" + VISA = "visa" + SHANGHAITECH = "shanghaitech" + + def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: """Get Anomaly Datamodule. 
@@ -43,7 +58,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: if center_crop is not None: center_crop = (center_crop[0], center_crop[1]) - if config.dataset.format.lower() == "mvtec": + if config.dataset.format.lower() == DataFormat.MVTEC: datamodule = MVTec( root=config.dataset.path, category=config.dataset.category, @@ -61,7 +76,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "mvtec_3d": + elif config.dataset.format.lower() == DataFormat.MVTEC_3D: datamodule = MVTec3D( root=config.dataset.path, category=config.dataset.category, @@ -79,7 +94,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "btech": + elif config.dataset.format.lower() == DataFormat.BTECH: datamodule = BTech( root=config.dataset.path, category=config.dataset.category, @@ -97,7 +112,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "folder": + elif config.dataset.format.lower() == DataFormat.FOLDER: datamodule = Folder( root=config.dataset.root, normal_dir=config.dataset.normal_dir, @@ -119,7 +134,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "folder_3d": + elif config.dataset.format.lower() == DataFormat.FOLDER_3D: datamodule = Folder3D( root=config.dataset.root, normal_dir=config.dataset.normal_dir, @@ -144,7 +159,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: 
val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "ucsdped": + elif config.dataset.format.lower() == DataFormat.UCSDPED: datamodule = UCSDped( root=config.dataset.path, category=config.dataset.category, @@ -162,7 +177,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "avenue": + elif config.dataset.format.lower() == DataFormat.AVENUE: datamodule = Avenue( root=config.dataset.path, gt_dir=config.dataset.gt_dir, @@ -180,7 +195,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "visa": + elif config.dataset.format.lower() == DataFormat.VISA: datamodule = Visa( root=config.dataset.path, category=config.dataset.category, @@ -198,7 +213,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: val_split_mode=config.dataset.val_split_mode, val_split_ratio=config.dataset.val_split_ratio, ) - elif config.dataset.format.lower() == "shanghaitech": + elif config.dataset.format.lower() == DataFormat.SHANGHAITECH: datamodule = ShanghaiTech( root=config.dataset.path, scene=config.dataset.scene, diff --git a/src/anomalib/data/btech.py b/src/anomalib/data/btech.py index 99b2a4484a..491441cc4c 100644 --- a/src/anomalib/data/btech.py +++ b/src/anomalib/data/btech.py @@ -26,6 +26,7 @@ from anomalib.data.utils import ( DownloadInfo, InputNormalizationMethod, + LabelName, Split, TestSplitMode, ValSplitMode, @@ -104,8 +105,8 @@ def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFram samples.loc[(samples.split == "test") & (samples.label == "ok"), "mask_path"] = "" # Create label index for normal (0) and anomalous (1) images. 
- samples.loc[(samples.label == "ok"), "label_index"] = 0 - samples.loc[(samples.label != "ok"), "label_index"] = 1 + samples.loc[(samples.label == "ok"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "ok"), "label_index"] = LabelName.ABNORMAL samples.label_index = samples.label_index.astype(int) # Get the data frame for the split. diff --git a/src/anomalib/data/folder.py b/src/anomalib/data/folder.py index 3f782509f6..c5093c60ea 100644 --- a/src/anomalib/data/folder.py +++ b/src/anomalib/data/folder.py @@ -15,7 +15,9 @@ from anomalib.data.base import AnomalibDataModule, AnomalibDataset from anomalib.data.task_type import TaskType from anomalib.data.utils import ( + DirType, InputNormalizationMethod, + LabelName, Split, TestSplitMode, ValSplitMode, @@ -58,16 +60,16 @@ def make_folder_dataset( filenames = [] labels = [] - dirs = {"normal": normal_dir} + dirs = {DirType.NORMAL: normal_dir} if abnormal_dir: - dirs = {**dirs, **{"abnormal": abnormal_dir}} + dirs = {**dirs, **{DirType.ABNORMAL: abnormal_dir}} if normal_test_dir: - dirs = {**dirs, **{"normal_test": normal_test_dir}} + dirs = {**dirs, **{DirType.NORMAL_TEST: normal_test_dir}} if mask_dir: - dirs = {**dirs, **{"mask_dir": mask_dir}} + dirs = {**dirs, **{DirType.MASK: mask_dir}} for dir_type, path in dirs.items(): filename, label = _prepare_files_labels(path, dir_type, extensions) @@ -78,22 +80,24 @@ def make_folder_dataset( samples = samples.sort_values(by="image_path", ignore_index=True) # Create label index for normal (0) and abnormal (1) images. 
- samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0 - samples.loc[(samples.label == "abnormal"), "label_index"] = 1 + samples.loc[ + (samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST), "label_index" + ] = LabelName.NORMAL + samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL samples.label_index = samples.label_index.astype("Int64") # If a path to mask is provided, add it to the sample dataframe. if mask_dir is not None and abnormal_dir is not None: - samples.loc[samples.label == "abnormal", "mask_path"] = samples.loc[ - samples.label == "mask_dir" + samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[ + samples.label == DirType.MASK ].image_path.values samples["mask_path"].fillna("", inplace=True) samples = samples.astype({"mask_path": "str"}) # make sure all every rgb image has a corresponding mask image. assert ( - samples.loc[samples.label_index == 1] + samples.loc[samples.label_index == LabelName.ABNORMAL] .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) .all() ), "Mismatch between anomalous images and mask images. Make sure the mask files \ @@ -104,7 +108,7 @@ def make_folder_dataset( # remove all the rows with temporal image samples that have already been assigned samples = samples.loc[ - (samples.label == "normal") | (samples.label == "abnormal") | (samples.label == "normal_test") + (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST) ] # Ensure the pathlib objects are converted to str. @@ -114,8 +118,8 @@ def make_folder_dataset( # Create train/test split. # By default, all the normal samples are assigned as train. # and all the abnormal samples are test. 
- samples.loc[(samples.label == "normal"), "split"] = "train" - samples.loc[(samples.label == "abnormal") | (samples.label == "normal_test"), "split"] = "test" + samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN + samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST # Get the data frame for the split. if split: diff --git a/src/anomalib/data/folder_3d.py b/src/anomalib/data/folder_3d.py index 6ddda99d3e..87dcf37e3b 100644 --- a/src/anomalib/data/folder_3d.py +++ b/src/anomalib/data/folder_3d.py @@ -16,7 +16,9 @@ from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset from anomalib.data.task_type import TaskType from anomalib.data.utils import ( + DirType, InputNormalizationMethod, + LabelName, Split, TestSplitMode, ValSplitMode, @@ -75,25 +77,25 @@ def make_folder3d_dataset( filenames = [] labels = [] - dirs = {"normal": normal_dir} + dirs = {DirType.NORMAL: normal_dir} if abnormal_dir: - dirs = {**dirs, **{"abnormal": abnormal_dir}} + dirs = {**dirs, **{DirType.ABNORMAL: abnormal_dir}} if normal_test_dir: - dirs = {**dirs, **{"normal_test": normal_test_dir}} + dirs = {**dirs, **{DirType.NORMAL_TEST: normal_test_dir}} if normal_depth_dir: - dirs = {**dirs, **{"normal_depth": normal_depth_dir}} + dirs = {**dirs, **{DirType.NORMAL_DEPTH: normal_depth_dir}} if abnormal_depth_dir: - dirs = {**dirs, **{"abnormal_depth": abnormal_depth_dir}} + dirs = {**dirs, **{DirType.ABNORMAL_DEPTH: abnormal_depth_dir}} if normal_test_depth_dir: - dirs = {**dirs, **{"normal_test_depth": normal_test_depth_dir}} + dirs = {**dirs, **{DirType.NORMAL_TEST_DEPTH: normal_test_depth_dir}} if mask_dir: - dirs = {**dirs, **{"mask_dir": mask_dir}} + dirs = {**dirs, **{DirType.MASK: mask_dir}} for dir_type, path in dirs.items(): filename, label = _prepare_files_labels(path, dir_type, extensions) @@ -104,27 +106,29 @@ def make_folder3d_dataset( samples = samples.sort_values(by="image_path", 
ignore_index=True) # Create label index for normal (0) and abnormal (1) images. - samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0 - samples.loc[(samples.label == "abnormal"), "label_index"] = 1 + samples.loc[ + (samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST), "label_index" + ] = LabelName.NORMAL + samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL samples.label_index = samples.label_index.astype("Int64") # If a path to mask is provided, add it to the sample dataframe. if normal_depth_dir is not None: - samples.loc[samples.label == "normal", "depth_path"] = samples.loc[ - samples.label == "normal_depth" + samples.loc[samples.label == DirType.NORMAL, "depth_path"] = samples.loc[ + samples.label == DirType.NORMAL_DEPTH ].image_path.values - samples.loc[samples.label == "abnormal", "depth_path"] = samples.loc[ - samples.label == "abnormal_depth" + samples.loc[samples.label == DirType.ABNORMAL, "depth_path"] = samples.loc[ + samples.label == DirType.ABNORMAL_DEPTH ].image_path.values if normal_test_dir is not None: - samples.loc[samples.label == "normal_test", "depth_path"] = samples.loc[ - samples.label == "normal_test_depth" + samples.loc[samples.label == DirType.NORMAL_TEST, "depth_path"] = samples.loc[ + samples.label == DirType.NORMAL_TEST_DEPTH ].image_path.values # make sure every rgb image has a corresponding depth image and that the file exists assert ( - samples.loc[samples.label_index == 1] + samples.loc[samples.label_index == LabelName.ABNORMAL] .apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1) .all() ), "Mismatch between anomalous images and depth images. Make sure the mask files in 'xyz' \ @@ -139,8 +143,8 @@ def make_folder3d_dataset( # If a path to mask is provided, add it to the sample dataframe. 
if mask_dir is not None and abnormal_dir is not None: - samples.loc[samples.label == "abnormal", "mask_path"] = samples.loc[ - samples.label == "mask_dir" + samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[ + samples.label == DirType.MASK ].image_path.values samples["mask_path"].fillna("", inplace=True) samples = samples.astype({"mask_path": "str"}) @@ -154,7 +158,7 @@ def make_folder3d_dataset( # remove all the rows with temporal image samples that have already been assigned samples = samples.loc[ - (samples.label == "normal") | (samples.label == "abnormal") | (samples.label == "normal_test") + (samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST) ] # Ensure the pathlib objects are converted to str. @@ -164,8 +168,8 @@ def make_folder3d_dataset( # Create train/test split. # By default, all the normal samples are assigned as train. # and all the abnormal samples are test. - samples.loc[(samples.label == "normal"), "split"] = "train" - samples.loc[(samples.label == "abnormal") | (samples.label == "normal_test"), "split"] = "test" + samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN + samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST # Get the data frame for the split. if split: diff --git a/src/anomalib/data/mvtec.py b/src/anomalib/data/mvtec.py index 36dc7d0399..f7f362bf41 100644 --- a/src/anomalib/data/mvtec.py +++ b/src/anomalib/data/mvtec.py @@ -37,6 +37,7 @@ from anomalib.data.utils import ( DownloadInfo, InputNormalizationMethod, + LabelName, Split, TestSplitMode, ValSplitMode, @@ -119,8 +120,8 @@ def make_mvtec_dataset( samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path # Create label index for normal (0) and anomalous (1) images. 
- samples.loc[(samples.label == "good"), "label_index"] = 0 - samples.loc[(samples.label != "good"), "label_index"] = 1 + samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL samples.label_index = samples.label_index.astype(int) # separate masks from samples @@ -129,11 +130,13 @@ def make_mvtec_dataset( # assign mask paths to anomalous test images samples["mask_path"] = "" - samples.loc[(samples.split == "test") & (samples.label_index == 1), "mask_path"] = mask_samples.image_path.values + samples.loc[ + (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), "mask_path" + ] = mask_samples.image_path.values # assert that the right mask files are associated with the right test images assert ( - samples.loc[samples.label_index == 1] + samples.loc[samples.label_index == LabelName.ABNORMAL] .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) .all() ), "Mismatch between anomalous images and ground truth masks. Make sure the mask files in 'ground_truth' \ diff --git a/src/anomalib/data/mvtec_3d.py b/src/anomalib/data/mvtec_3d.py index 6cf6369202..15346f359e 100644 --- a/src/anomalib/data/mvtec_3d.py +++ b/src/anomalib/data/mvtec_3d.py @@ -34,6 +34,7 @@ from anomalib.data.utils import ( DownloadInfo, InputNormalizationMethod, + LabelName, Split, TestSplitMode, ValSplitMode, @@ -135,8 +136,8 @@ def make_mvtec_3d_dataset( ) # Create label index for normal (0) and anomalous (1) images. 
- samples.loc[(samples.label == "good"), "label_index"] = 0 - samples.loc[(samples.label != "good"), "label_index"] = 1 + samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL samples.label_index = samples.label_index.astype(int) # separate masks from samples @@ -154,7 +155,7 @@ def make_mvtec_3d_dataset( # assert that the right mask files are associated with the right test images assert ( - samples.loc[samples.label_index == 1] + samples.loc[samples.label_index == LabelName.ABNORMAL] .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) .all() ), "Mismatch between anomalous images and ground truth masks. Make sure the mask files in 'ground_truth' \ @@ -163,7 +164,7 @@ def make_mvtec_3d_dataset( # assert that the right depth image files are associated with the right test images assert ( - samples.loc[samples.label_index == 1] + samples.loc[samples.label_index == LabelName.ABNORMAL] .apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1) .all() ), "Mismatch between anomalous images and depth images. 
Make sure the mask files in 'xyz' \ diff --git a/src/anomalib/data/utils/__init__.py b/src/anomalib/data/utils/__init__.py index bbd8ac6689..25ef2f6efe 100644 --- a/src/anomalib/data/utils/__init__.py +++ b/src/anomalib/data/utils/__init__.py @@ -14,7 +14,8 @@ read_depth_image, read_image, ) -from .path import _check_and_convert_path, _prepare_files_labels, _resolve_path +from .label import LabelName +from .path import DirType, _check_and_convert_path, _prepare_files_labels, _resolve_path from .split import ( Split, TestSplitMode, @@ -38,6 +39,8 @@ "Split", "ValSplitMode", "TestSplitMode", + "LabelName", + "DirType", "Augmenter", "masks_to_boxes", "boxes_to_masks", diff --git a/src/anomalib/data/utils/label.py b/src/anomalib/data/utils/label.py new file mode 100644 index 0000000000..9e5d0a3dd0 --- /dev/null +++ b/src/anomalib/data/utils/label.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class LabelName(int, Enum): + """Name of label.""" + + NORMAL = 0 + ABNORMAL = 1 diff --git a/src/anomalib/data/utils/path.py b/src/anomalib/data/utils/path.py index 8ad38da150..c498909062 100644 --- a/src/anomalib/data/utils/path.py +++ b/src/anomalib/data/utils/path.py @@ -5,11 +5,24 @@ from __future__ import annotations +from enum import Enum from pathlib import Path from torchvision.datasets.folder import IMG_EXTENSIONS +class DirType(str, Enum): + """Dir type names.""" + + NORMAL = "normal" + ABNORMAL = "abnormal" + NORMAL_TEST = "normal_test" + NORMAL_DEPTH = "normal_depth" + ABNORMAL_DEPTH = "abnormal_depth" + NORMAL_TEST_DEPTH = "normal_test_depth" + MASK = "mask_dir" + + def _check_and_convert_path(path: str | Path) -> Path: """Check an input path, and convert to Pathlib object. 
diff --git a/src/anomalib/models/reverse_distillation/anomaly_map.py b/src/anomalib/models/reverse_distillation/anomaly_map.py index 52e2778feb..4fddce41b5 100644 --- a/src/anomalib/models/reverse_distillation/anomaly_map.py +++ b/src/anomalib/models/reverse_distillation/anomaly_map.py @@ -11,6 +11,8 @@ from __future__ import annotations +from enum import Enum + import torch import torch.nn.functional as F from kornia.filters import gaussian_blur2d @@ -18,26 +20,39 @@ from torch import Tensor, nn +class AnomalyMapGenerationMode(str, Enum): + """Type of mode when generating anomaly map.""" + + ADD = "add" + MULTIPLY = "multiply" + + class AnomalyMapGenerator(nn.Module): """Generate Anomaly Heatmap. Args: image_size (ListConfig, tuple): Size of original image used for upscaling the anomaly map. sigma (int): Standard deviation of the gaussian kernel used to smooth anomaly map. - mode (str, optional): Operation used to generate anomaly map. Options are `add` and `multiply`. - Defaults to "multiply". + mode (AnomalyMapGenerationMode, optional): Operation used to generate anomaly map. + Options are `AnomalyMapGenerationMode.ADD` and `AnomalyMapGenerationMode.MULTIPLY`. + Defaults to `AnomalyMapGenerationMode.MULTIPLY`. Raises: ValueError: In case modes other than multiply and add are passed. """ - def __init__(self, image_size: ListConfig | tuple, sigma: int = 4, mode: str = "multiply") -> None: + def __init__( + self, + image_size: ListConfig | tuple, + sigma: int = 4, + mode: AnomalyMapGenerationMode = AnomalyMapGenerationMode.MULTIPLY, + ) -> None: super().__init__() self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size) self.sigma = sigma self.kernel_size = 2 * int(4.0 * sigma + 0.5) + 1 - if mode not in ("add", "multiply"): + if mode not in (AnomalyMapGenerationMode.ADD, AnomalyMapGenerationMode.MULTIPLY): raise ValueError(f"Found mode {mode}. 
Only multiply and add are supported.") self.mode = mode @@ -51,11 +66,11 @@ def forward(self, student_features: list[Tensor], teacher_features: list[Tensor] Returns: Tensor: Anomaly maps of length batch. """ - if self.mode == "multiply": + if self.mode == AnomalyMapGenerationMode.MULTIPLY: anomaly_map = torch.ones( [student_features[0].shape[0], 1, *self.image_size], device=student_features[0].device ) # b c h w - elif self.mode == "add": + elif self.mode == AnomalyMapGenerationMode.ADD: anomaly_map = torch.zeros( [student_features[0].shape[0], 1, *self.image_size], device=student_features[0].device ) @@ -64,9 +79,9 @@ def forward(self, student_features: list[Tensor], teacher_features: list[Tensor] distance_map = 1 - F.cosine_similarity(student_feature, teacher_feature) distance_map = torch.unsqueeze(distance_map, dim=1) distance_map = F.interpolate(distance_map, size=self.image_size, mode="bilinear", align_corners=True) - if self.mode == "multiply": + if self.mode == AnomalyMapGenerationMode.MULTIPLY: anomaly_map *= distance_map - elif self.mode == "add": + elif self.mode == AnomalyMapGenerationMode.ADD: anomaly_map += distance_map anomaly_map = gaussian_blur2d( diff --git a/src/anomalib/models/reverse_distillation/lightning_model.py b/src/anomalib/models/reverse_distillation/lightning_model.py index 29a2f3dc31..5489daab5b 100644 --- a/src/anomalib/models/reverse_distillation/lightning_model.py +++ b/src/anomalib/models/reverse_distillation/lightning_model.py @@ -15,6 +15,7 @@ from anomalib.models.components import AnomalyModule +from .anomaly_map import AnomalyMapGenerationMode from .loss import ReverseDistillationLoss from .torch_model import ReverseDistillationModel @@ -34,7 +35,7 @@ def __init__( input_size: tuple[int, int], backbone: str, layers: list[str], - anomaly_map_mode: str, + anomaly_map_mode: AnomalyMapGenerationMode, lr: float, beta1: float, beta2: float, diff --git a/src/anomalib/models/reverse_distillation/torch_model.py 
b/src/anomalib/models/reverse_distillation/torch_model.py index 576bf56a86..5ada84573e 100644 --- a/src/anomalib/models/reverse_distillation/torch_model.py +++ b/src/anomalib/models/reverse_distillation/torch_model.py @@ -15,6 +15,8 @@ ) from anomalib.pre_processing import Tiler +from .anomaly_map import AnomalyMapGenerationMode + class ReverseDistillationModel(nn.Module): """Reverse Distillation Model. @@ -32,7 +34,7 @@ def __init__( backbone: str, input_size: tuple[int, int], layers: list[str], - anomaly_map_mode: str, + anomaly_map_mode: AnomalyMapGenerationMode, pre_trained: bool = True, ) -> None: super().__init__() diff --git a/src/anomalib/post_processing/visualizer.py b/src/anomalib/post_processing/visualizer.py index 1137da0ae0..05050b1c02 100644 --- a/src/anomalib/post_processing/visualizer.py +++ b/src/anomalib/post_processing/visualizer.py @@ -6,6 +6,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from typing import Iterator @@ -61,16 +62,23 @@ def __post_init__(self) -> None: self.anomalous_boxes = self.pred_boxes[self.box_labels.astype(bool)] +class VisualizationMode(str, Enum): + """Type of visualization mode.""" + + FULL = "full" + SIMPLE = "simple" + + class Visualizer: """Class that handles the logic of composing the visualizations. Args: - mode (str): visualization mode, either "full" or "simple" + mode (VisualizationMode): visualization mode, either "full" or "simple" task (TaskType): task type "segmentation", "detection" or "classification" """ - def __init__(self, mode: str, task: TaskType) -> None: - if mode not in ("full", "simple"): + def __init__(self, mode: VisualizationMode, task: TaskType) -> None: + if mode not in (VisualizationMode.FULL, VisualizationMode.SIMPLE): raise ValueError(f"Unknown visualization mode: {mode}. 
Please choose one of ['full', 'simple']") self.mode = mode if task not in (TaskType.CLASSIFICATION, TaskType.DETECTION, TaskType.SEGMENTATION): @@ -122,9 +130,9 @@ def visualize_image(self, image_result: ImageResult) -> np.ndarray: Returns: The full or simple visualization for the image, depending on the specified mode. """ - if self.mode == "full": + if self.mode == VisualizationMode.FULL: return self._visualize_full(image_result) - if self.mode == "simple": + if self.mode == VisualizationMode.SIMPLE: return self._visualize_simple(image_result) raise ValueError(f"Unknown visualization mode: {self.mode}") diff --git a/src/anomalib/pre_processing/tiler.py b/src/anomalib/pre_processing/tiler.py index 120392765a..fb9e807579 100644 --- a/src/anomalib/pre_processing/tiler.py +++ b/src/anomalib/pre_processing/tiler.py @@ -5,6 +5,7 @@ from __future__ import annotations +from enum import Enum from itertools import product from math import ceil from typing import Sequence @@ -15,6 +16,13 @@ from torch.nn import functional as F +class ImageUpscaleMode(str, Enum): + """Type of mode when upscaling image.""" + + PADDING = "padding" + INTERPOLATION = "interpolation" + + class StrideSizeError(Exception): """StrideSizeError to raise exception when stride size is greater than the tile size.""" @@ -53,7 +61,7 @@ def __compute_new_edge_size(edge_size: int, tile_size: int, stride: int) -> int: return resized_h, resized_w -def upscale_image(image: Tensor, size: tuple, mode: str = "padding") -> Tensor: +def upscale_image(image: Tensor, size: tuple, mode: ImageUpscaleMode = ImageUpscaleMode.PADDING) -> Tensor: """Upscale image to the desired size via either padding or interpolation. 
Args: @@ -79,12 +87,12 @@ def upscale_image(image: Tensor, size: tuple, mode: str = "padding") -> Tensor: image_h, image_w = image.shape[2:] resize_h, resize_w = size - if mode == "padding": + if mode == ImageUpscaleMode.PADDING: pad_h = resize_h - image_h pad_w = resize_w - image_w image = F.pad(image, [0, pad_w, 0, pad_h]) - elif mode == "interpolation": + elif mode == ImageUpscaleMode.INTERPOLATION: image = F.interpolate(input=image, size=(resize_h, resize_w)) else: raise ValueError(f"Unknown mode {mode}. Only padding and interpolation is available.") @@ -92,7 +100,7 @@ def upscale_image(image: Tensor, size: tuple, mode: str = "padding") -> Tensor: return image -def downscale_image(image: Tensor, size: tuple, mode: str = "padding") -> Tensor: +def downscale_image(image: Tensor, size: tuple, mode: ImageUpscaleMode = ImageUpscaleMode.PADDING) -> Tensor: """Opposite of upscaling. This image downscales image to a desired size. Args: @@ -111,7 +119,7 @@ def downscale_image(image: Tensor, size: tuple, mode: str = "padding") -> Tensor Tensor: Downscaled image """ input_h, input_w = size - if mode == "padding": + if mode == ImageUpscaleMode.PADDING: image = image[:, :, :input_h, :input_w] else: image = F.interpolate(input=image, size=(input_h, input_w)) @@ -151,7 +159,7 @@ def __init__( tile_size: int | Sequence, stride: int | Sequence | None = None, remove_border_count: int = 0, - mode: str = "padding", + mode: ImageUpscaleMode = ImageUpscaleMode.PADDING, tile_count: int = 4, ) -> None: self.tile_size_h, self.tile_size_w = self.__validate_size_type(tile_size) @@ -170,7 +178,7 @@ def __init__( "Please ensure stride size is less than or equal than tiling size." ) - if self.mode not in ("padding", "interpolation"): + if self.mode not in (ImageUpscaleMode.PADDING, ImageUpscaleMode.INTERPOLATION): raise ValueError(f"Unknown tiling mode {self.mode}. 
Available modes are padding and interpolation") self.batch_size: int diff --git a/tests/pre_merge/datasets/test_dataset.py b/tests/pre_merge/datasets/test_dataset.py index 330fdcfa3f..1ee97dba20 100644 --- a/tests/pre_merge/datasets/test_dataset.py +++ b/tests/pre_merge/datasets/test_dataset.py @@ -7,7 +7,7 @@ from anomalib.data import TaskType from anomalib.data.folder import FolderDataset -from anomalib.data.utils import get_transforms +from anomalib.data.utils import get_transforms, LabelName from anomalib.data.utils.split import concatenate_datasets, random_split from tests.helpers.dataset import get_dataset_path @@ -62,14 +62,14 @@ def test_random_split(self, folder_dataset): # label-aware subset splitting samples = folder_dataset.samples - normal_samples = samples[samples["label_index"] == 0] - anomalous_samples = samples[samples["label_index"] == 1] + normal_samples = samples[samples["label_index"] == LabelName.NORMAL] + anomalous_samples = samples[samples["label_index"] == LabelName.ABNORMAL] samples = pd.concat([normal_samples, anomalous_samples[0:5]]) folder_dataset.samples = samples subsets = random_split(folder_dataset, [0.4, 0.4, 0.2], label_aware=True) # 5 anomalous images in total, so the first two subsets should each have 2, and the last subset 1 - assert len(subsets[0].samples[subsets[0].samples["label_index"] == 1]) == 2 - assert len(subsets[1].samples[subsets[1].samples["label_index"] == 1]) == 2 - assert len(subsets[2].samples[subsets[2].samples["label_index"] == 1]) == 1 + assert len(subsets[0].samples[subsets[0].samples["label_index"] == LabelName.ABNORMAL]) == 2 + assert len(subsets[1].samples[subsets[1].samples["label_index"] == LabelName.ABNORMAL]) == 2 + assert len(subsets[2].samples[subsets[2].samples["label_index"] == LabelName.ABNORMAL]) == 1