diff --git a/configs/data/folder.yaml b/configs/data/folder.yaml index 3a5f648947..9d9abee078 100644 --- a/configs/data/folder.yaml +++ b/configs/data/folder.yaml @@ -1,5 +1,6 @@ class_path: anomalib.data.Folder init_args: + name: bottle root: "datasets/MVTec/bottle" normal_dir: "train/good" abnormal_dir: "test/broken_large" diff --git a/notebooks/100_datamodules/103_folder.ipynb b/notebooks/100_datamodules/103_folder.ipynb index 8aa8e86c12..f870606175 100644 --- a/notebooks/100_datamodules/103_folder.ipynb +++ b/notebooks/100_datamodules/103_folder.ipynb @@ -96,6 +96,7 @@ "outputs": [], "source": [ "folder_datamodule = Folder(\n", + " name=\"hazelnut_toy\",\n", " root=dataset_root,\n", " normal_dir=\"good\",\n", " abnormal_dir=\"crack\",\n", @@ -331,6 +332,7 @@ ], "source": [ "folder_dataset_classification_train = FolderDataset(\n", + " name=\"hazelnut_toy\",\n", " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"train\",\n", @@ -476,6 +478,7 @@ "source": [ "# Folder Classification Test Set\n", "folder_dataset_classification_test = FolderDataset(\n", + " name=\"hazelnut_toy\",\n", " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"test\",\n", @@ -615,6 +618,7 @@ "source": [ "# Folder Segmentation Train Set\n", "folder_dataset_segmentation_train = FolderDataset(\n", + " name=\"hazelnut_toy\",\n", " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"train\",\n", @@ -727,6 +731,7 @@ "source": [ "# Folder Segmentation Test Set\n", "folder_dataset_segmentation_test = FolderDataset(\n", + " name=\"hazelnut_toy\",\n", " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"test\",\n", diff --git a/src/anomalib/callbacks/__init__.py b/src/anomalib/callbacks/__init__.py index eccc3cabc9..12ec54d8f3 100644 --- a/src/anomalib/callbacks/__init__.py +++ b/src/anomalib/callbacks/__init__.py @@ -20,6 +20,7 @@ from .timer import TimerCallback __all__ = [ + "ModelCheckpoint", "GraphLogger", "LoadModelCallback", "TilerConfigurationCallback", @@ -43,21 +44,6 @@ def get_callbacks(config: DictConfig | ListConfig | Namespace) -> list[Callback] callbacks: list[Callback] = [] - monitor_metric = ( - None if "early_stopping" not in config.model.init_args else config.model.init_args.early_stopping.metric - ) - monitor_mode = "max" if "early_stopping" not in config.model.init_args else config.model.early_stopping.mode - - checkpoint = ModelCheckpoint( - dirpath=Path(config.trainer.default_root_dir) / "weights" / "lightning", - filename="model", - monitor=monitor_metric, - mode=monitor_mode, - auto_insert_metric_name=False, - ) - - callbacks.extend([checkpoint, TimerCallback()]) - if "ckpt_path" in config.trainer and config.ckpt_path is not None: load_model = LoadModelCallback(config.ckpt_path) callbacks.append(load_model) diff --git a/src/anomalib/callbacks/visualizer.py b/src/anomalib/callbacks/visualizer.py index 6afe9d2cfb..c78c1f6ab8 100644 --- a/src/anomalib/callbacks/visualizer.py +++ b/src/anomalib/callbacks/visualizer.py @@ -97,7 +97,18 @@ def on_test_batch_end( if result.file_name is None: msg = "``save`` is set to ``True`` but file name is ``None``" raise ValueError(msg) - save_image(image=result.image, root=self.root, filename=result.file_name) + + # Get the filename to save the image. + # Filename is split based on the datamodule name and category. + # For example, if the filename is `MVTec/bottle/000.png`, then the + # filename is split based on `MVTec/bottle` and `000.png` is saved. + if trainer.datamodule is not None: + filename = str(result.file_name).split( + sep=f"{trainer.datamodule.name}/{trainer.datamodule.category}", + )[-1] + else: + filename = Path(result.file_name).name + save_image(image=result.image, root=self.root, filename=filename) if self.show: show_image(image=result.image, title=str(result.file_name)) if self.log: diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index 427260c59b..9de4a1d9c1 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -34,7 +34,6 @@ from anomalib.metrics.threshold import BaseThreshold from anomalib.models import AnomalyModule from anomalib.utils.config import update_config - from anomalib.utils.visualization.base import BaseVisualizer except ImportError: _LIGHTNING_AVAILABLE = False @@ -143,15 +142,6 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None: from anomalib.callbacks.normalization import get_normalization_callback parser.add_function_arguments(get_normalization_callback, "normalization") - # visualization takes task from the project - parser.add_argument( - "--visualization.visualizers", - type=BaseVisualizer | list[BaseVisualizer] | None, - default=None, - ) - parser.add_argument("--visualization.save", type=bool, default=False) - parser.add_argument("--visualization.log", type=bool, default=False) - parser.add_argument("--visualization.show", type=bool, default=False) parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION) parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"]) parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False) @@ -160,13 +150,14 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None: if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"): parser.link_arguments("task", "data.init_args.task") parser.add_argument( - "--results_dir.path", + "--default_root_dir", type=Path, help="Path to save the results.", default=Path("./results"), ) - parser.add_argument("--results_dir.unique", type=bool, help="Whether to create a unique folder.", default=False) - parser.link_arguments("results_dir.path", "trainer.default_root_dir") + parser.link_arguments("default_root_dir", "trainer.default_root_dir") + # TODO(ashwinvaidya17): Tiling should also be a category of its own + # CVS-122659 def add_trainer_arguments(self, parser: ArgumentParser, subcommand: str) -> None: """Add train arguments to the parser.""" @@ -329,7 +320,6 @@ def instantiate_engine(self) -> None: "task": self._get(self.config_init, "task"), "image_metrics": self._get(self.config_init, "metrics.image"), "pixel_metrics": self._get(self.config_init, "metrics.pixel"), - **self._get_visualization_parameters(), } trainer_config = {**self._get(self.config_init, "trainer", default={}), **engine_args} key = "callbacks" @@ -348,16 +338,6 @@ def instantiate_engine(self) -> None: trainer_config[key].extend(get_callbacks(self.config[self.subcommand])) self.engine = Engine(**trainer_config) - def _get_visualization_parameters(self) -> dict[str, Any]: - """Return visualization parameters.""" - subcommand = self.config.subcommand - return { - "visualizers": self.config_init[subcommand].visualization.visualizers, - "save_image": self.config[subcommand].visualization.save, - "log_image": self.config[subcommand].visualization.log, - "show_image": self.config[subcommand].visualization.show, - } - def _run_subcommand(self) -> None: """Run subcommand depending on the subcommand. diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 17a7c6b737..5f4b49cf9e 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -110,9 +110,15 @@ def __init__( self.test_data: AnomalibDataset self._samples: DataFrame | None = None + self._category: str = "" self._is_setup = False # flag to track if setup has been called from the trainer + @property + def name(self) -> str: + """Name of the datamodule.""" + return self.__class__.__name__ + def setup(self, stage: str | None = None) -> None: """Set up train, validation and test data. @@ -144,6 +150,16 @@ def _setup(self, _stage: str | None = None) -> None: """ raise NotImplementedError + @property + def category(self) -> str: + """Get the category of the datamodule.""" + return self._category + + @category.setter + def category(self, category: str) -> None: + """Set the category of the datamodule.""" + self._category = category + def _create_test_split(self) -> None: """Obtain the test set based on the settings in the config.""" if self.test_data.has_normal: diff --git a/src/anomalib/data/base/dataset.py b/src/anomalib/data/base/dataset.py index 493db18e72..0635782b3f 100644 --- a/src/anomalib/data/base/dataset.py +++ b/src/anomalib/data/base/dataset.py @@ -67,6 +67,18 @@ def __init__(self, task: TaskType, transform: Transform | None = None) -> None: self.task = task self.transform = transform self._samples: DataFrame | None = None + self._category: str | None = None + + @property + def name(self) -> str: + """Name of the dataset.""" + class_name = self.__class__.__name__ + + # Remove the `_dataset` suffix from the class name + if class_name.endswith("Dataset"): + class_name = class_name[:-7] + + return class_name def __len__(self) -> int: """Get length of the dataset.""" @@ -113,6 +125,16 @@ def samples(self, samples: DataFrame) -> None: self._samples = samples.sort_values(by="image_path", ignore_index=True) + @property + def category(self) -> str | None: + """Get the category of the dataset.""" + return self._category + + @category.setter + def category(self, category: str) -> None: + """Set the category of the dataset.""" + self._category = category + @property def has_normal(self) -> bool: """Check if the dataset contains any normal samples.""" diff --git a/src/anomalib/data/depth/folder_3d.py b/src/anomalib/data/depth/folder_3d.py index 78ef9be701..a1fab24591 100644 --- a/src/anomalib/data/depth/folder_3d.py +++ b/src/anomalib/data/depth/folder_3d.py @@ -185,6 +185,7 @@ class Folder3DDataset(AnomalibDepthDataset): """Folder dataset. Args: + name (str): Name of the dataset. task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``). transform (Transform): Transforms that should be applied to the input images. normal_dir (str | Path): Path to the directory containing normal images. @@ -222,6 +223,7 @@ class Folder3DDataset(AnomalibDepthDataset): def __init__( self, + name: str, task: TaskType, normal_dir: str | Path, root: str | Path | None = None, @@ -237,6 +239,7 @@ def __init__( ) -> None: super().__init__(task, transform) + self._name = name self.split = split self.root = root self.normal_dir = normal_dir @@ -261,11 +264,20 @@ def __init__( extensions=self.extensions, ) + @property + def name(self) -> str: + """Name of the dataset. + + Folder3D dataset overrides the name property to provide a custom name. + """ + return self._name + class Folder3D(AnomalibDataModule): """Folder DataModule. Args: + name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. normal_dir (str | Path): Name of the directory containing normal images. root (str | Path | None): Path to the root folder containing normal and abnormal dirs. Defaults to ``None``. @@ -318,6 +330,7 @@ class Folder3D(AnomalibDataModule): def __init__( self, + name: str, normal_dir: str | Path, root: str | Path, abnormal_dir: str | Path | None = None, @@ -355,6 +368,7 @@ def __init__( val_split_ratio=val_split_ratio, seed=seed, ) + self._name = name self.task = TaskType(task) self.root = Path(root) self.normal_dir = normal_dir @@ -368,6 +382,7 @@ def __init__( def _setup(self, _stage: str | None = None) -> None: self.train_data = Folder3DDataset( + name=self.name, task=self.task, transform=self.train_transform, split=Split.TRAIN, @@ -383,6 +398,7 @@ def _setup(self, _stage: str | None = None) -> None: ) self.test_data = Folder3DDataset( + name=self.name, task=self.task, transform=self.eval_transform, split=Split.TEST, @@ -396,3 +412,11 @@ def _setup(self, _stage: str | None = None) -> None: mask_dir=self.mask_dir, extensions=self.extensions, ) + + @property + def name(self) -> str: + """Name of the datamodule. + + Folder3D datamodule overrides the name property to provide a custom name. + """ + return self._name diff --git a/src/anomalib/data/image/folder.py b/src/anomalib/data/image/folder.py index 1266a490f8..4b52bd691a 100644 --- a/src/anomalib/data/image/folder.py +++ b/src/anomalib/data/image/folder.py @@ -183,6 +183,7 @@ class FolderDataset(AnomalibDataset): This class is used to create a dataset from a folder. The class utilizes the Torch Dataset class. Args: + name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. task (TaskType): Task type. (``classification``, ``detection`` or ``segmentation``). transform (Transform, optional): Transforms that should be applied to the input images. Defaults to ``None``. @@ -230,6 +231,7 @@ class FolderDataset(AnomalibDataset): def __init__( self, + name: str, task: TaskType, normal_dir: str | Path | Sequence[str | Path], transform: Transform | None = None, @@ -242,6 +244,7 @@ def __init__( ) -> None: super().__init__(task, transform) + self._name = name self.split = split self.root = root self.normal_dir = normal_dir @@ -260,11 +263,20 @@ def __init__( extensions=self.extensions, ) + @property + def name(self) -> str: + """Name of the dataset. + + Folder dataset overrides the name property to provide a custom name. + """ + return self._name + class Folder(AnomalibDataModule): """Folder DataModule. Args: + name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving. normal_dir (str | Path | Sequence): Name of the directory containing normal images. root (str | Path | None): Path to the root folder containing normal and abnormal dirs. Defaults to ``None``. @@ -367,6 +379,7 @@ class Folder(AnomalibDataModule): def __init__( self, + name: str, normal_dir: str | Path | Sequence[str | Path], root: str | Path | None = None, abnormal_dir: str | Path | Sequence[str | Path] | None = None, @@ -388,6 +401,7 @@ def __init__( val_split_ratio: float = 0.5, seed: int | None = None, ) -> None: + self._name = name self.root = root self.normal_dir = normal_dir self.abnormal_dir = abnormal_dir @@ -425,6 +439,7 @@ def __init__( def _setup(self, _stage: str | None = None) -> None: self.train_data = FolderDataset( + name=self.name, task=self.task, transform=self.train_transform, split=Split.TRAIN, @@ -437,6 +452,7 @@ def _setup(self, _stage: str | None = None) -> None: ) self.test_data = FolderDataset( + name=self.name, task=self.task, transform=self.eval_transform, split=Split.TEST, @@ -447,3 +463,11 @@ def _setup(self, _stage: str | None = None) -> None: mask_dir=self.mask_dir, extensions=self.extensions, ) + + @property + def name(self) -> str: + """Name of the datamodule. + + Folder datamodule overrides the name property to provide a custom name. + """ + return self._name diff --git a/src/anomalib/data/image/mvtec.py b/src/anomalib/data/image/mvtec.py index 3e06c72c71..c23add93ab 100644 --- a/src/anomalib/data/image/mvtec.py +++ b/src/anomalib/data/image/mvtec.py @@ -229,6 +229,7 @@ def __init__( super().__init__(task=task, transform=transform) self.root_category = Path(root) / Path(category) + self.category = category self.split = split self.samples = make_mvtec_dataset(self.root_category, split=self.split, extensions=IMG_EXTENSIONS) diff --git a/src/anomalib/data/utils/image.py b/src/anomalib/data/utils/image.py index 04369f8a25..8e33951776 100644 --- a/src/anomalib/data/utils/image.py +++ b/src/anomalib/data/utils/image.py @@ -445,6 +445,9 @@ def save_image(filename: Path | str, image: np.ndarray | Figure, root: Path | No if root: file_path = root / file_path + # Make unique file_path if file already exists + file_path = duplicate_filename(file_path) + file_path.parent.mkdir(parents=True, exist_ok=True) image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) cv2.imwrite(str(file_path), image) diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 181ca4fcc9..2f2f14279b 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -4,30 +4,34 @@ # SPDX-License-Identifier: Apache-2.0 import logging +from collections.abc import Iterable from pathlib import Path from typing import Any import torch from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loggers import Logger from lightning.pytorch.trainer import Trainer -from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader, Dataset from torchvision.transforms.v2 import Transform from anomalib import LearningType, TaskType +from anomalib.callbacks.checkpoint import ModelCheckpoint from anomalib.callbacks.metrics import _MetricsCallback from anomalib.callbacks.normalization import get_normalization_callback from anomalib.callbacks.normalization.base import NormalizationCallback from anomalib.callbacks.post_processor import _PostProcessorCallback from anomalib.callbacks.thresholding import _ThresholdCallback +from anomalib.callbacks.timer import TimerCallback from anomalib.callbacks.visualizer import _VisualizationCallback from anomalib.data import AnomalibDataModule, AnomalibDataset, PredictDataset from anomalib.deploy.export import ExportType, export_to_onnx, export_to_openvino, export_to_torch from anomalib.models import AnomalyModule from anomalib.utils.normalization import NormalizationMethod +from anomalib.utils.path import create_versioned_dir from anomalib.utils.types import NORMALIZATION, THRESHOLD -from anomalib.utils.visualization import BaseVisualizer +from anomalib.utils.visualization import ImageVisualizer logger = logging.getLogger(__name__) @@ -83,10 +87,7 @@ def update(self, model: AnomalyModule) -> None: self._cached_args[key] = value def requires_update(self, model: AnomalyModule) -> bool: - for key, value in model.trainer_arguments.items(): - if key in self._cached_args and self._cached_args[key] != value: - return True - return False + return any(self._cached_args.get(key, None) != value for key, value in model.trainer_arguments.items()) @property def args(self) -> dict[str, Any]: @@ -112,8 +113,9 @@ class Engine: Defaults to None. pixel_metrics (str | list[str] | None, optional): Pixel metrics to be used for evaluation. Defaults to None. - visualizers (BaseVisualizationGenerator | list[BaseVisualizationGenerator] | None): - Visualization parameters. Defaults to None. + default_root_dir (str, optional): Default root directory for the trainer. + The results will be saved in this directory. + Defaults to ``results``. **kwargs: PyTorch Lightning Trainer arguments. """ @@ -125,10 +127,8 @@ def __init__( task: TaskType | str = TaskType.SEGMENTATION, image_metrics: str | list[str] | None = None, pixel_metrics: str | list[str] | None = None, - visualizers: BaseVisualizer | list[BaseVisualizer] | None = None, - save_image: bool = False, - log_image: bool = False, - show_image: bool = False, + logger: Logger | Iterable[Logger] | bool | None = None, + default_root_dir: str | Path = "results", **kwargs, ) -> None: # TODO(ashwinvaidya17): Add model argument to engine constructor @@ -136,7 +136,15 @@ def __init__( if callbacks is None: callbacks = [] - self._cache = _TrainerArgumentsCache(callbacks=[*callbacks], **kwargs) + # Cache the Lightning Trainer arguments. + logger = False if logger is None else logger + self._cache = _TrainerArgumentsCache( + callbacks=[*callbacks], + logger=logger, + default_root_dir=Path(default_root_dir), + **kwargs, + ) + self.normalization = normalization self.threshold = threshold self.task = TaskType(task) @@ -147,12 +155,6 @@ def __init__( if self.task == TaskType.SEGMENTATION: self.pixel_metric_names = pixel_metrics if pixel_metrics is not None else ["AUROC", "F1Score"] - self.visualizers = visualizers - - self.save_image = save_image - self.log_image = log_image - self.show_image = show_image - self._trainer: Trainer | None = None @property @@ -170,33 +172,6 @@ def trainer(self) -> Trainer: raise UnassignedError(msg) return self._trainer - @property - def visualizers(self) -> BaseVisualizer | list[BaseVisualizer] | None: - """Get visualization generators.""" - return self._visualizers - - @visualizers.setter - def visualizers(self, visualizers: BaseVisualizer | list[BaseVisualizer] | None) -> None: - """Set the visualizers. - - Args: - visualizers (BaseVisualizer | list[BaseVisualizer] | None): Visualizers to be used for visualization. - """ - self._visualizers = visualizers - # override the task in the visualizers if it is not the same as the task of the engine - if self.visualizers: - visualizers = ( - self.visualizers - if isinstance(self.visualizers, list) - else [ - self.visualizers, - ] - ) - for visualizer in visualizers: - if hasattr(visualizer, "task") and visualizer.task != self.task: - logger.info(f"Overriding task of {visualizer} to {self.task}") - visualizer.task = self.task - @property def model(self) -> AnomalyModule: """Property to get the model. @@ -250,15 +225,85 @@ def threshold_callback(self) -> _ThresholdCallback | None: raise ValueError(msg) return callbacks[0] if len(callbacks) > 0 else None + def _setup_workspace( + self, + model: AnomalyModule, + train_dataloaders: TRAIN_DATALOADERS | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + test_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: AnomalibDataModule | None = None, + dataset: AnomalibDataset | None = None, + versioned_dir: bool = False, + ) -> None: + """Setup the workspace for the model. + + This method sets up the default root directory for the model based on + the model name, dataset name, and category. Model checkpoints, logs, and + other artifacts will be saved in this directory. + + Args: + model (AnomalyModule): Input model. + train_dataloaders (TRAIN_DATALOADERS | None, optional): Train dataloaders. + Defaults to ``None``. + val_dataloaders (EVAL_DATALOADERS | None, optional): Validation dataloaders. + Defaults to ``None``. + test_dataloaders (EVAL_DATALOADERS | None, optional): Test dataloaders. + Defaults to ``None``. + datamodule (AnomalibDataModule | None, optional): Lightning datamodule. + Defaults to ``None``. + dataset (AnomalibDataset | None, optional): Anomalib dataset. + Defaults to ``None``. + versioned_dir (bool, optional): Whether to create a versioned directory. + Defaults to ``True``. + + Raises: + TypeError: If the dataloader type is unknown. + """ + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # 1. Get the dataset name and category from the dataloaders, datamodule, or dataset. + dataset_name: str = "" + category: str | None = None + + # Check datamodule and dataset directly + if datamodule is not None: + dataset_name = datamodule.name + category = datamodule.category + elif dataset is not None: + dataset_name = dataset.name + category = dataset.category + + # Check dataloaders if dataset_name and category are not set + dataloaders = [train_dataloaders, val_dataloaders, test_dataloaders] + if not dataset_name or category is None: + for dataloader in dataloaders: + if dataloader is not None: + if hasattr(dataloader, "train_data"): + dataset_name = getattr(dataloader.train_data, "name", "") + category = getattr(dataloader.train_data, "category", "") + break + if dataset_name and category is not None: + break + + # Check if category is None and set it to empty string + category = category if category is not None else "" + + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # 2. Update the default root directory + root_dir = Path(self._cache.args["default_root_dir"]) / model.name / dataset_name / category + self._cache.args["default_root_dir"] = create_versioned_dir(root_dir) if versioned_dir else root_dir / "latest" + def _setup_trainer(self, model: AnomalyModule) -> None: """Instantiate the trainer based on the model parameters.""" - if self._cache.requires_update(model) or self._trainer is None: + # Check if the cache requires an update + if self._cache.requires_update(model): self._cache.update(model) + + # Setup anomalib callbacks to be used with the trainer + self._setup_anomalib_callbacks() + + # Instantiate the trainer if it is not already instantiated + if self._trainer is None: self._trainer = Trainer(**self._cache.args) - # Callbacks need to be setup later as they depend on default_root_dir from the trainer - # TODO(djdameln): set up callbacks before instantiating trainer - # https://github.com/openvinotoolkit/anomalib/issues/1642 - self._setup_anomalib_callbacks() def _setup_dataset_task( self, @@ -336,36 +381,50 @@ def _setup_transform( def _setup_anomalib_callbacks(self) -> None: """Set up callbacks for the trainer.""" - _callbacks: list[Callback] = [_PostProcessorCallback()] + _callbacks: list[Callback] = [] + + # Add ModelCheckpoint if it is not in the callbacks list. + has_checkpoint_callback = any(isinstance(c, ModelCheckpoint) for c in self._cache.args["callbacks"]) + if has_checkpoint_callback is False: + _callbacks.append( + ModelCheckpoint( + dirpath=self._cache.args["default_root_dir"] / "weights" / "lightning", + filename="model", + auto_insert_metric_name=False, + ), + ) + + # Add the post-processor callbacks. + _callbacks.append(_PostProcessorCallback()) + + # Add the the normalization callback. normalization_callback = get_normalization_callback(self.normalization) if normalization_callback is not None: _callbacks.append(normalization_callback) + # Add the thresholding and metrics callbacks. _callbacks.append(_ThresholdCallback(self.threshold)) _callbacks.append(_MetricsCallback(self.task, self.image_metric_names, self.pixel_metric_names)) - if self.visualizers is not None: - image_save_path = Path(self.trainer.default_root_dir) / "images" - _callbacks.append( - _VisualizationCallback( - visualizers=self.visualizers, - save=self.save_image, - root=image_save_path, - log=self.log_image, - show=self.show_image, - ), - ) - - self.trainer.callbacks = _CallbackConnector._reorder_callbacks( # noqa: SLF001 - self.trainer.callbacks + _callbacks, + _callbacks.append( + _VisualizationCallback( + visualizers=ImageVisualizer(task=self.task), + save=True, + root=self._cache.args["default_root_dir"] / "images", + ), ) + _callbacks.append(TimerCallback()) + + # Combine the callbacks, and update the trainer callbacks. + self._cache.args["callbacks"] = _callbacks + self._cache.args["callbacks"] + def _should_run_validation( self, model: AnomalyModule, dataloaders: EVAL_DATALOADERS | None, datamodule: AnomalibDataModule | None, - ckpt_path: str | None, + ckpt_path: str | Path | None, ) -> bool: """Check if we need to run validation to collect normalization statistics and thresholds. @@ -383,7 +442,7 @@ def _should_run_validation( model (AnomalyModule): Model passed to the entrypoint. dataloaders (EVAL_DATALOADERS | None): Dataloaders passed to the entrypoint. datamodule (AnomalibDataModule | None): Lightning datamodule passed to the entrypoint. - ckpt_path (str | None): Checkpoint path passed to the entrypoint. + ckpt_path (str | Path | None): Checkpoint path passed to the entrypoint. Returns: bool: Whether it is needed to run a validation sequence. @@ -406,7 +465,7 @@ def fit( train_dataloaders: TRAIN_DATALOADERS | None = None, val_dataloaders: EVAL_DATALOADERS | None = None, datamodule: AnomalibDataModule | None = None, - ckpt_path: str | None = None, + ckpt_path: str | Path | None = None, ) -> None: """Fit the model using the trainer. @@ -436,6 +495,16 @@ def fit( anomalib fit --config ``` """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + + self._setup_workspace( + model=model, + train_dataloaders=train_dataloaders, + val_dataloaders=val_dataloaders, + datamodule=datamodule, + versioned_dir=True, + ) self._setup_trainer(model) self._setup_dataset_task(train_dataloaders, val_dataloaders, datamodule) self._setup_transform(model, datamodule=datamodule, ckpt_path=ckpt_path) @@ -449,7 +518,7 @@ def validate( self, model: AnomalyModule | None = None, dataloaders: EVAL_DATALOADERS | None = None, - ckpt_path: str | None = None, + ckpt_path: str | Path | None = None, verbose: bool = True, datamodule: AnomalibDataModule | None = None, ) -> _EVALUATE_OUTPUT | None: @@ -487,6 +556,8 @@ def validate( anomalib validate --config ``` """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() if model: self._setup_trainer(model) self._setup_dataset_task(dataloaders) @@ -497,7 +568,7 @@ def test( self, model: AnomalyModule | None = None, dataloaders: EVAL_DATALOADERS | None = None, - ckpt_path: str | None = None, + ckpt_path: str | Path | None = None, verbose: bool = True, datamodule: AnomalibDataModule | None = None, ) -> _EVALUATE_OUTPUT: @@ -572,11 +643,17 @@ def test( anomalib test --config ``` """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + + self._setup_workspace(model=model or self.model, datamodule=datamodule, test_dataloaders=dataloaders) + if model: self._setup_trainer(model) elif not self.model: msg = "`Engine.test()` requires an `AnomalyModule` when it hasn't been passed in a previous run." raise RuntimeError(msg) + self._setup_dataset_task(dataloaders) self._setup_transform(model or self.model, datamodule=datamodule, ckpt_path=ckpt_path) if self._should_run_validation(model or self.model, dataloaders, datamodule, ckpt_path): @@ -591,7 +668,7 @@ def predict( datamodule: AnomalibDataModule | None = None, dataset: Dataset | PredictDataset | None = None, return_predictions: bool | None = None, - ckpt_path: str | None = None, + ckpt_path: str | Path | None = None, ) -> _PREDICT_OUTPUT | None: """Predict using the model using the trainer. @@ -651,6 +728,12 @@ def predict( assert ( model or self.model ), "`Engine.predict()` requires an `AnomalyModule` when it hasn't been passed in a previous run." + + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + + self._setup_workspace(model=model or self.model, datamodule=datamodule, test_dataloaders=dataloaders) + if model: self._setup_trainer(model) @@ -692,7 +775,7 @@ def train( val_dataloaders: EVAL_DATALOADERS | None = None, test_dataloaders: EVAL_DATALOADERS | None = None, datamodule: AnomalibDataModule | None = None, - ckpt_path: str | None = None, + ckpt_path: str | Path | None = None, ) -> _EVALUATE_OUTPUT: """Fits the model and then calls test on it. @@ -724,6 +807,16 @@ def train( anomalib train --config ``` """ + if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() + self._setup_workspace( + model, + train_dataloaders, + val_dataloaders, + test_dataloaders, + datamodule, + versioned_dir=True, + ) self._setup_trainer(model) self._setup_dataset_task( train_dataloaders, @@ -746,7 +839,7 @@ def export( export_root: str | Path | None = None, transform: Transform | None = None, ov_args: dict[str, Any] | None = None, - ckpt_path: str | None = None, + ckpt_path: str | Path | None = None, ) -> Path | None: """Export the model in PyTorch, ONNX or OpenVINO format. @@ -759,7 +852,7 @@ def export( the engine will try to use the transform from the datamodule or dataset. Defaults to None. ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer. Defaults to None. - ckpt_path (str | None): Checkpoint path. If provided, the model will be loaded from this path. + ckpt_path (str | Path | None): Checkpoint path. If provided, the model will be loaded from this path. Returns: Path: Path to the exported model. @@ -789,6 +882,7 @@ def export( """ self._setup_trainer(model) if ckpt_path: + ckpt_path = Path(ckpt_path).resolve() model = model.__class__.load_from_checkpoint(ckpt_path) if export_root is None: diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index bcd4c84a8d..46f2fd34ea 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -5,13 +5,13 @@ import logging -import re from importlib import import_module from jsonargparse import Namespace from omegaconf import DictConfig, OmegaConf from anomalib.models.components import AnomalyModule +from anomalib.utils.path import convert_to_snake_case from .image import ( Cfa, @@ -63,25 +63,6 @@ class UnknownModelError(ModuleNotFoundError): logger = logging.getLogger(__name__) -def convert_pascal_to_snake_case(pascal_case: str) -> str: - """Convert PascalCase to snake_case. - - Args: - pascal_case (str): Input string in PascalCase - - Returns: - str: Output string in snake_case - - Examples: - >>> _convert_pascal_to_snake_case("EfficientAd") - efficient_ad - - >>> _convert_pascal_to_snake_case("Patchcore") - patchcore - """ - return re.sub(r"(? str: """Convert snake_case to PascalCase. @@ -111,7 +92,7 @@ def get_available_models() -> set[str]: >>> get_available_models() ['ai_vad', 'cfa', 'cflow', 'csflow', 'dfkde', 'dfm', 'draem', 'efficient_ad', 'fastflow', ...] """ - return {convert_pascal_to_snake_case(cls.__name__) for cls in AnomalyModule.__subclasses__()} + return {convert_to_snake_case(cls.__name__) for cls in AnomalyModule.__subclasses__()} def _get_model_class_by_name(name: str) -> type[AnomalyModule]: diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py index 9834d291fd..b246d5c1af 100644 --- a/src/anomalib/models/components/base/anomaly_module.py +++ b/src/anomalib/models/components/base/anomaly_module.py @@ -56,6 +56,11 @@ def __init__(self) -> None: self._is_setup = False # flag to track if setup has been called from the trainer + @property + def name(self) -> str: + """Name of the model.""" + return self.__class__.__name__ + def setup(self, stage: str | None = None) -> None: """Calls the _setup method to build the model if the model is not already built.""" if getattr(self, "model", None) is None or not self._is_setup: @@ -67,7 +72,7 @@ def setup(self, stage: str | None = None) -> None: def _setup(self) -> None: """The _setup method is used to build the torch model dynamically or adjust something about them. - The model implementer may override this method to build the model. This is useful when the model canot be set + The model implementer may override this method to build the model. This is useful when the model cannot be set in the `__init__` method because it requires some information or data that is not available at the time of initialization. """ diff --git a/src/anomalib/utils/config.py b/src/anomalib/utils/config.py index 708eae6248..27f4605419 100644 --- a/src/anomalib/utils/config.py +++ b/src/anomalib/utils/config.py @@ -6,7 +6,6 @@ import logging from collections.abc import Sequence -from datetime import datetime from pathlib import Path from typing import Any, cast @@ -17,18 +16,6 @@ logger = logging.getLogger(__name__) -def get_default_root_directory(config: DictConfig | ListConfig) -> Path: - """Set the default root directory.""" - root_dir = config.results_dir.path if config.results_dir.path else "./results" - model_name = config.model.class_path.split(".")[-1].lower() - data_name = config.data.class_path.split(".")[-1].lower() - category = config.data.init_args.category if "category" in config.data.init_args else "" - # add datetime folder to the path as well so that runs with same configuration are not overwritten - time_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") if config.results_dir.unique else "" - # loggers should write to results/model/dataset/category/ folder - return Path(root_dir, model_name, data_name, category, time_stamp) - - def _convert_nested_path_to_str(config: Any) -> Any: # noqa: ANN401 """Goes over the dictionary and converts all path values to str.""" if isinstance(config, dict): @@ -99,29 +86,7 @@ def update_config(config: DictConfig | ListConfig | Namespace) -> DictConfig | L """ _show_warnings(config) - # keep track of the original config file because it will be modified - config_original: DictConfig | ListConfig | Namespace = ( - config.copy() if isinstance(config, DictConfig | ListConfig) else config.clone() - ) - - # Project Configs - project_path = get_default_root_directory(config) - logger.info(f"Project path set to {(project_path)}") - - (project_path / "weights").mkdir(parents=True, exist_ok=True) - (project_path / "images").mkdir(parents=True, exist_ok=True) - - config.trainer.default_root_dir = str(project_path) - config.results_dir.path = str(project_path) - - config = _update_nncf_config(config) - - # write the original config for eventual debug (modified config at the end of the function) - (project_path / "config_original.yaml").write_text(to_yaml(config_original)) - - (project_path / "config.yaml").write_text(to_yaml(config)) - - return config + return _update_nncf_config(config) def _update_nncf_config(config: DictConfig | ListConfig) -> DictConfig | ListConfig: diff --git a/src/anomalib/utils/path.py b/src/anomalib/utils/path.py new file mode 100644 index 0000000000..47cc77652f --- /dev/null +++ b/src/anomalib/utils/path.py @@ -0,0 +1,97 @@ +"""Anomalib Path Utils.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import re +from pathlib import Path + + +def create_versioned_dir(root_dir: str | Path) -> Path: + """Create a new version directory and update the ``latest`` symbolic link. + + Args: + root_dir (Path): The root directory where the version directories are stored. + + Returns: + latest_link_path (Path): The path to the ``latest`` symbolic link. + + Examples: + >>> version_dir = create_version_dir(Path('path/to/experiments/')) + PosixPath('/path/to/experiments/latest') + + >>> version_dir.resolve().name + v1 + + Calling the function again will create a new version directory and + update the ``latest`` symbolic link: + + >>> version_dir = create_version_dir('path/to/experiments/') + PosixPath('/path/to/experiments/latest') + + >>> version_dir.resolve().name + v2 + + """ + # Compile a regular expression to match version directories + version_pattern = re.compile(r"^v(\d+)$") + + # Resolve the path + root_dir = Path(root_dir).resolve() + root_dir.mkdir(parents=True, exist_ok=True) + + # Find the highest existing version number + highest_version = -1 + for version_dir in root_dir.iterdir(): + if version_dir.is_dir(): + match = version_pattern.match(version_dir.name) + if match: + version_number = int(match.group(1)) + highest_version = max(highest_version, version_number) + + # The new directory will have the next highest version number + new_version_number = highest_version + 1 + new_version_dir = root_dir / f"v{new_version_number}" + + # Create the new version directory + new_version_dir.mkdir() + + # Update the 'latest' symbolic link to point to the new version directory + latest_link_path = root_dir / "latest" + if latest_link_path.is_symlink() or latest_link_path.exists(): + latest_link_path.unlink() + latest_link_path.symlink_to(new_version_dir, target_is_directory=True) + + return latest_link_path + + +def convert_to_snake_case(s: str) -> str: + """Converts a string to snake case. + + Args: + s (str): The input string to be converted. + + Returns: + str: The converted string in snake case. + + Examples: + >>> convert_to_snake_case("Snake Case") + 'snake_case' + + >>> convert_to_snake_case("snakeCase") + 'snake_case' + + >>> convert_to_snake_case("snake_case") + 'snake_case' + """ + # Replace whitespace, hyphens, periods, and apostrophes with underscores + s = re.sub(r"\s+|[-.\']", "_", s) + + # Insert underscores before capital letters (except at the beginning of the string) + s = re.sub(r"(? np.ndarray: image_grid.add_image(image=image_result.segmentations, title="Segmentation Result") elif self.task == TaskType.CLASSIFICATION: image_grid.add_image(image_result.image, title="Image") - if hasattr(image_result, "heat_map"): + if image_result.heat_map is not None: image_grid.add_image(image_result.heat_map, "Predicted Heat Map") if image_result.pred_label: image_classified = add_anomalous_label(image_result.image, image_result.pred_score) diff --git a/tests/conftest.py b/tests/conftest.py index 472001360c..8367539957 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,6 @@ from tempfile import TemporaryDirectory import pytest -from lightning.pytorch.callbacks import ModelCheckpoint from anomalib.data import ImageDataFormat, MVTec, VideoDataFormat from anomalib.engine import Engine @@ -88,23 +87,14 @@ def checkpoint(model_name: str) -> Path: Since integration tests train all the models, model training occurs when running unit tests invididually. """ - _ckpt_path = project_path / model_name.lower() / "dummy" / "weights" / "last.ckpt" + model = get_model(model_name) + _ckpt_path = project_path / model.name / "MVTec" / "dummy" / "latest" / "weights" / "lightning" / "model.ckpt" if not _ckpt_path.exists(): - model = get_model(model_name) engine = Engine( logger=False, default_root_dir=project_path, max_epochs=1, devices=1, - callbacks=[ - ModelCheckpoint( - dirpath=project_path / model_name.lower() / "dummy" / "weights", - monitor=None, - filename="last", - save_last=True, - auto_insert_metric_name=False, - ), - ], ) dataset = MVTec(root=dataset_path / "mvtec", category="dummy") engine.fit(model=model, datamodule=dataset) diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py index daf9151fbe..a163ef73ad 100644 --- a/tests/integration/cli/test_cli.py +++ b/tests/integration/cli/test_cli.py @@ -46,7 +46,7 @@ def test_test(self, dataset_path: Path, project_path: Path) -> None: "test", *self._get_common_cli_args(dataset_path, project_path), "--ckpt_path", - f"{project_path}/padim/dummy/weights/last.ckpt", + f"{project_path}/Padim/MVTec/dummy/v0/weights/lightning/model.ckpt", ], ) torch.cuda.empty_cache() @@ -63,7 +63,7 @@ def test_train(self, dataset_path: Path, project_path: Path) -> None: "train", *self._get_common_cli_args(dataset_path, project_path), "--ckpt_path", - f"{project_path}/padim/dummy/weights/last.ckpt", + f"{project_path}/Padim/MVTec/dummy/v0/weights/lightning/model.ckpt", ], ) torch.cuda.empty_cache() @@ -80,7 +80,7 @@ def test_validate(self, dataset_path: Path, project_path: Path) -> None: "validate", *self._get_common_cli_args(dataset_path, project_path), "--ckpt_path", - f"{project_path}/padim/dummy/weights/last.ckpt", + f"{project_path}/Padim/MVTec/dummy/v0/weights/lightning/model.ckpt", ], ) torch.cuda.empty_cache() @@ -103,7 +103,7 @@ def test_predict_with_dataloader(self, dataset_path: Path, project_path: Path) - project_path, ), "--ckpt_path", - f"{project_path}/padim/dummy/weights/last.ckpt", + f"{project_path}/Padim/MVTec/dummy/v0/weights/lightning/model.ckpt", ], ) torch.cuda.empty_cache() @@ -127,7 +127,7 @@ def test_predict_with_image_folder(self, project_path: Path) -> None: project_path, ), "--ckpt_path", - f"{project_path}/padim/dummy/weights/last.ckpt", + f"{project_path}/Padim/MVTec/dummy/v0/weights/lightning/model.ckpt", ], ) torch.cuda.empty_cache() @@ -151,7 +151,7 @@ def test_predict_with_image_path(self, project_path: Path) -> None: project_path, ), "--ckpt_path", - f"{project_path}/padim/dummy/weights/last.ckpt", + f"{project_path}/Padim/MVTec/dummy/v0/weights/lightning/model.ckpt", ], ) torch.cuda.empty_cache() @@ -176,7 +176,7 @@ def test_export( export_type, *self._get_common_cli_args(None, project_path), "--ckpt_path", - f"{project_path}/padim/dummy/weights/last.ckpt", + f"{project_path}/Padim/MVTec/dummy/v0/weights/lightning/model.ckpt", ], ) @@ -208,23 +208,10 @@ def _get_common_cli_args(dataset_path: Path | None, project_path: Path) -> list[ "--model", "Padim", *data_args, - "--results_dir.path", + "--default_root_dir", str(project_path), - "--results_dir.unique", - "false", "--task", "SEGMENTATION", "--trainer.max_epochs", "1", - "--trainer.callbacks+=anomalib.callbacks.ModelCheckpoint", - "--trainer.callbacks.dirpath", - f"{project_path}/padim/dummy/weights", - "--trainer.callbacks.monitor", - "null", - "--trainer.callbacks.filename", - "last", - "--trainer.callbacks.save_last", - "true", - "--trainer.callbacks.auto_insert_metric_name", - "false", ] diff --git a/tests/integration/model/test_models.py b/tests/integration/model/test_models.py index b135c6e5e0..4be832d20b 100644 --- a/tests/integration/model/test_models.py +++ b/tests/integration/model/test_models.py @@ -12,8 +12,7 @@ import pytest from anomalib import TaskType -from anomalib.callbacks import ModelCheckpoint -from anomalib.data import AnomalibDataModule, MVTec, UCSDped +from anomalib.data import AnomalibDataModule, MVTec from anomalib.deploy.export import ExportType from anomalib.engine import Engine from anomalib.models import AnomalyModule, get_available_models, get_model @@ -62,7 +61,11 @@ def test_test(self, model_name: str, dataset_path: Path, project_path: Path) -> dataset_path=dataset_path, project_path=project_path, ) - engine.test(model=model, datamodule=dataset, ckpt_path=f"{project_path}/{model_name}/dummy/weights/last.ckpt") + engine.test( + model=model, + datamodule=dataset, + ckpt_path=f"{project_path}/{model.name}/{dataset.name}/dummy/v0/weights/lightning/model.ckpt", + ) @pytest.mark.parametrize("model_name", models()) def test_train(self, model_name: str, dataset_path: Path, project_path: Path) -> None: @@ -78,7 +81,11 @@ def test_train(self, model_name: str, dataset_path: Path, project_path: Path) -> dataset_path=dataset_path, project_path=project_path, ) - engine.train(model=model, datamodule=dataset, ckpt_path=f"{project_path}/{model_name}/dummy/weights/last.ckpt") + engine.train( + model=model, + datamodule=dataset, + ckpt_path=f"{project_path}/{model.name}/{dataset.name}/dummy/v0/weights/lightning/model.ckpt", + ) @pytest.mark.parametrize("model_name", models()) def test_validate(self, model_name: str, dataset_path: Path, project_path: Path) -> None: @@ -97,7 +104,7 @@ def test_validate(self, model_name: str, dataset_path: Path, project_path: Path) engine.validate( model=model, datamodule=dataset, - ckpt_path=f"{project_path}/{model_name}/dummy/weights/last.ckpt", + ckpt_path=f"{project_path}/{model.name}/{dataset.name}/dummy/v0/weights/lightning/model.ckpt", ) @pytest.mark.parametrize("model_name", models()) @@ -116,7 +123,7 @@ def test_predict(self, model_name: str, dataset_path: Path, project_path: Path) ) engine.predict( model=model, - ckpt_path=f"{project_path}/{model_name}/dummy/weights/last.ckpt", + ckpt_path=f"{project_path}/{model.name}/{datamodule.name}/dummy/v0/weights/lightning/model.ckpt", datamodule=datamodule, ) @@ -141,19 +148,17 @@ def test_export( # TODO(ashwinvaidya17): Restore this test after fixing reverse distillation # https://github.com/openvinotoolkit/anomalib/issues/1513 pytest.skip("Reverse distillation fails to convert to ONNX") - elif model_name == "ai_vad": - pytest.skip("Export fails for video models.") elif model_name == "rkde" and export_type == ExportType.OPENVINO: pytest.skip("RKDE fails to convert to OpenVINO") - model, _, engine = self._get_objects( + model, dataset, engine = self._get_objects( model_name=model_name, dataset_path=dataset_path, project_path=project_path, ) engine.export( model=model, - ckpt_path=f"{project_path}/{model_name}/dummy/weights/last.ckpt", + ckpt_path=f"{project_path}/{model.name}/{dataset.name}/dummy/v0/weights/lightning/model.ckpt", export_type=export_type, ) @@ -189,11 +194,10 @@ def _get_objects( extra_args = {} if model_name in ("rkde", "dfkde"): extra_args["n_pca_components"] = 2 + if model_name == "ai_vad": + pytest.skip("Revisit AI-VAD test") # select dataset - if model_name == "ai_vad": - # aivad expects UCSD dataset - dataset = UCSDped(root=dataset_path / "ucsdped", category="dummy", task=task_type) elif model_name == "win_clip": dataset = MVTec(root=dataset_path / "mvtec", category="dummy", image_size=240, task=task_type) else: @@ -214,15 +218,6 @@ def _get_objects( devices=1, pixel_metrics=["F1Score", "AUROC"], task=task_type, - callbacks=[ - ModelCheckpoint( - dirpath=f"{project_path}/{model_name}/dummy/weights", - monitor=None, - filename="last", - save_last=True, - auto_insert_metric_name=False, - ), - ], # TODO(ashwinvaidya17): Fix these Edge cases # https://github.com/openvinotoolkit/anomalib/issues/1478 max_steps=70000 if model_name == "efficient_ad" else -1, diff --git a/tests/integration/tools/test_gradio_entrypoint.py b/tests/integration/tools/test_gradio_entrypoint.py index 88776e5194..bbdcdd1444 100644 --- a/tests/integration/tools/test_gradio_entrypoint.py +++ b/tests/integration/tools/test_gradio_entrypoint.py @@ -47,14 +47,14 @@ def test_torch_inference( # export torch model export_to_torch( model=model, - export_root=_ckpt_path.parent.parent, + export_root=_ckpt_path.parent.parent.parent, task=TaskType.SEGMENTATION, ) arguments = parser().parse_args( [ "--weights", - str(_ckpt_path.parent) + "/torch/model.pt", + str(_ckpt_path.parent.parent) + "/torch/model.pt", ], ) assert isinstance(inferencer(arguments.weights, arguments.metadata), TorchInferencer) @@ -71,7 +71,7 @@ def test_openvino_inference( # export OpenVINO model export_to_openvino( - export_root=_ckpt_path.parent.parent, + export_root=_ckpt_path.parent.parent.parent, model=model, ov_args={}, task=TaskType.SEGMENTATION, @@ -80,9 +80,9 @@ def test_openvino_inference( arguments = parser().parse_args( [ "--weights", - str(_ckpt_path.parent) + "/openvino/model.bin", + str(_ckpt_path.parent.parent) + "/openvino/model.bin", "--metadata", - str(_ckpt_path.parent) + "/openvino/metadata.json", + str(_ckpt_path.parent.parent) + "/openvino/metadata.json", ], ) assert isinstance(inferencer(arguments.weights, arguments.metadata), OpenVINOInferencer) diff --git a/tests/integration/tools/test_openvino_entrypoint.py b/tests/integration/tools/test_openvino_entrypoint.py index 704d6ad05b..31da21a138 100644 --- a/tests/integration/tools/test_openvino_entrypoint.py +++ b/tests/integration/tools/test_openvino_entrypoint.py @@ -44,7 +44,7 @@ def test_openvino_inference( # export OpenVINO model export_to_openvino( - export_root=_ckpt_path.parent.parent, + export_root=_ckpt_path.parent.parent.parent, model=model, ov_args={}, task=TaskType.SEGMENTATION, @@ -53,13 +53,13 @@ def test_openvino_inference( arguments = get_parser().parse_args( [ "--weights", - str(_ckpt_path.parent) + "/openvino/model.bin", + str(_ckpt_path.parent.parent) + "/openvino/model.bin", "--metadata", - str(_ckpt_path.parent) + "/openvino/metadata.json", + str(_ckpt_path.parent.parent) + "/openvino/metadata.json", "--input", get_dummy_inference_image, "--output", - str(_ckpt_path.parent) + "/output.png", + str(_ckpt_path.parent.parent) + "/output.png", ], ) infer(arguments) diff --git a/tests/integration/tools/test_torch_entrypoint.py b/tests/integration/tools/test_torch_entrypoint.py index 2674f75957..5a8f5848a8 100644 --- a/tests/integration/tools/test_torch_entrypoint.py +++ b/tests/integration/tools/test_torch_entrypoint.py @@ -44,13 +44,13 @@ def test_torch_inference( model = Padim.load_from_checkpoint(_ckpt_path) export_to_torch( model=model, - export_root=_ckpt_path.parent.parent, + export_root=_ckpt_path.parent.parent.parent, task=TaskType.SEGMENTATION, ) arguments = get_parser().parse_args( [ "--weights", - str(_ckpt_path.parent) + "/torch/model.pt", + str(_ckpt_path.parent.parent) + "/torch/model.pt", "--input", get_dummy_inference_image, "--output", diff --git a/tests/unit/data/image/test_folder.py b/tests/unit/data/image/test_folder.py index 3c0825b8f2..8a324cfc0f 100644 --- a/tests/unit/data/image/test_folder.py +++ b/tests/unit/data/image/test_folder.py @@ -27,6 +27,7 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> Folder: # Create and prepare the dataset _datamodule = Folder( + name="dummy", root=dataset_path / "mvtec" / "dummy", normal_dir="train/good", abnormal_dir="test/bad", diff --git a/tests/unit/data/image/test_folder_3d.py b/tests/unit/data/image/test_folder_3d.py index 7255fe13a0..0bb806ebd7 100644 --- a/tests/unit/data/image/test_folder_3d.py +++ b/tests/unit/data/image/test_folder_3d.py @@ -19,6 +19,7 @@ class TestFolder3D(_TestAnomalibDepthDatamodule): def datamodule(self, dataset_path: Path, task_type: TaskType) -> Folder3D: """Create and return a Folder 3D datamodule.""" _datamodule = Folder3D( + name="dummy", root=dataset_path / "mvtec_3d/dummy", normal_dir="train/good/rgb", abnormal_dir="test/bad/rgb", diff --git a/tests/unit/data/utils/test_synthetic.py b/tests/unit/data/utils/test_synthetic.py index 932db031f8..599cf5cc68 100644 --- a/tests/unit/data/utils/test_synthetic.py +++ b/tests/unit/data/utils/test_synthetic.py @@ -17,6 +17,7 @@ def folder_dataset(dataset_path: Path) -> FolderDataset: """Fixture that returns a FolderDataset instance.""" return FolderDataset( + name="dummy", task=TaskType.SEGMENTATION, root=dataset_path / "mvtec" / "dummy", normal_dir="train/good", diff --git a/tests/unit/utils/callbacks/visualizer_callback/test_visualizer.py b/tests/unit/utils/callbacks/visualizer_callback/test_visualizer.py index 13abab1d5a..2ae53e90fc 100644 --- a/tests/unit/utils/callbacks/visualizer_callback/test_visualizer.py +++ b/tests/unit/utils/callbacks/visualizer_callback/test_visualizer.py @@ -12,7 +12,6 @@ from anomalib.data import MVTec from anomalib.engine import Engine from anomalib.loggers import AnomalibTensorBoardLogger -from anomalib.utils.visualization.image import ImageVisualizer from .dummy_lightning_model import DummyModule @@ -25,12 +24,9 @@ def test_add_images(task: TaskType, dataset_path: Path) -> None: model = DummyModule(dataset_path) engine = Engine( logger=logger, - enable_checkpointing=False, default_root_dir=dir_loc, task=task, limit_test_batches=1, - save_image=True, - visualizers=ImageVisualizer(), accelerator="cpu", ) engine.test(model=model, datamodule=MVTec(root=dataset_path / "mvtec", category="dummy")) diff --git a/tests/unit/utils/test_visualizer.py b/tests/unit/utils/test_visualizer.py index 860ba903dd..19a905e558 100644 --- a/tests/unit/utils/test_visualizer.py +++ b/tests/unit/utils/test_visualizer.py @@ -15,7 +15,7 @@ from anomalib.data import MVTec, PredictDataset from anomalib.engine import Engine from anomalib.models import get_model -from anomalib.utils.visualization.image import ImageVisualizer, VisualizationMode, _ImageGrid +from anomalib.utils.visualization.image import _ImageGrid def test_visualize_fully_defected_masks() -> None: @@ -39,14 +39,12 @@ class TestVisualizer: """Test visualization callback for test and predict with different task types.""" @pytest.mark.parametrize("task", [TaskType.CLASSIFICATION, TaskType.SEGMENTATION, TaskType.DETECTION]) - @pytest.mark.parametrize("mode", [VisualizationMode.FULL, VisualizationMode.SIMPLE]) def test_model_visualizer_mode( self, ckpt_path: Callable[[str], Path], project_path: Path, dataset_path: Path, task: TaskType, - mode: VisualizationMode, ) -> None: """Test combination of model/visualizer/mode on only 1 epoch as a sanity check before merge.""" _ckpt_path: Path = ckpt_path("Padim") @@ -54,8 +52,6 @@ def test_model_visualizer_mode( engine = Engine( default_root_dir=project_path, fast_dev_run=True, - visualizers=ImageVisualizer(mode=mode), - save_image=True, devices=1, task=task, ) diff --git a/tools/inference/lightning_inference.py b/tools/inference/lightning_inference.py index acea297942..4f5103dd74 100644 --- a/tools/inference/lightning_inference.py +++ b/tools/inference/lightning_inference.py @@ -11,7 +11,6 @@ from anomalib.data import PredictDataset from anomalib.engine import Engine from anomalib.models import AnomalyModule, get_model -from anomalib.utils.visualization import ImageVisualizer def get_parser() -> LightningArgumentParser: @@ -26,14 +25,6 @@ def get_parser() -> LightningArgumentParser: parser.add_argument("--ckpt_path", type=str, required=True, help="Path to model weights") parser.add_class_arguments(PredictDataset, "--data", instantiate=False) parser.add_argument("--output", type=str, required=False, help="Path to save the output image(s).") - parser.add_argument( - "--visualization_mode", - type=str, - required=False, - default="simple", - help="Visualization mode.", - choices=["full", "simple"], - ) parser.add_argument( "--show", action="store_true", @@ -52,14 +43,10 @@ def get_parser() -> LightningArgumentParser: def infer(args: Namespace) -> None: """Run inference.""" - save_images = bool(args.output) - callbacks = None if not hasattr(args, "callbacks") else args.callbacks engine = Engine( default_root_dir=args.output, callbacks=callbacks, - visualizers=ImageVisualizer(mode=args.visualization_mode), - save_image=save_images, devices=1, ) model = get_model(args.model)