Skip to content

Commit

Permalink
Refactor and restructure anomalib.data (#2302)
Browse files Browse the repository at this point in the history
* Move datamodules to datamodule sub-package

* Move datamodules to datamodule sub-package

* Split datamodules and datasets

* Restructure dataclasses to data

* Fix relative imports

* Use absolute imports

* Add datasets dir

* Add relative imports for torch datasets

* Update src/anomalib/data/datamodules/base/__init__.py

Co-authored-by: Ashwin Vaidya <[email protected]>

---------

Co-authored-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
2 people authored and djdameln committed Sep 12, 2024
1 parent 1500db6 commit 627be88
Show file tree
Hide file tree
Showing 107 changed files with 3,326 additions and 2,867 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Project related
datasets
!src/anomalib/data/datasets
pre_trained
!anomalib/datasets
results
!anomalib/core/results

# Test-related files and directories
tmp*
Expand Down
2 changes: 1 addition & 1 deletion notebooks/100_datamodules/101_btech.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"from torchvision.transforms.v2 import Resize\n",
"from torchvision.transforms.v2.functional import to_pil_image\n",
"\n",
"from anomalib.data.image.btech import BTech, BTechDataset\n",
"from anomalib.data import BTech, BTechDataset\n",
"from anomalib import TaskType"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/100_datamodules/102_mvtec.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"from torchvision.transforms.v2 import Resize\n",
"from torchvision.transforms.v2.functional import to_pil_image\n",
"\n",
"from anomalib.data.image.mvtec import MVTec, MVTecDataset\n",
"from anomalib.data import MVTec, MVTecDataset\n",
"from anomalib import TaskType"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/100_datamodules/103_folder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"from torchvision.transforms.v2 import Resize\n",
"from torchvision.transforms.v2.functional import to_pil_image\n",
"\n",
"from anomalib.data.image.folder import Folder, FolderDataset\n",
"from anomalib.data import Folder, FolderDataset\n",
"from anomalib import TaskType"
]
},
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/callbacks/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib import TaskType
from anomalib.dataclasses import Batch
from anomalib.data import Batch
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
from anomalib.models import AnomalyModule

Expand Down
60 changes: 55 additions & 5 deletions src/anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,36 @@

from anomalib.utils.config import to_tuple

from .base import AnomalibDataModule, AnomalibDataset
from .depth import DepthDataFormat, Folder3D, MVTec3D
from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa
# Dataclasses
from .dataclasses import (
Batch,
DatasetItem,
DepthBatch,
DepthItem,
ImageBatch,
ImageItem,
InferenceBatch,
NumpyImageBatch,
NumpyImageItem,
NumpyVideoBatch,
NumpyVideoItem,
VideoBatch,
VideoItem,
)

# Datamodules
from .datamodules.base import AnomalibDataModule
from .datamodules.depth import DepthDataFormat, Folder3D, MVTec3D
from .datamodules.image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa
from .datamodules.video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat

# Datasets
from .datasets import AnomalibDataset
from .datasets.depth import Folder3DDataset, MVTec3DDataset
from .datasets.image import BTechDataset, FolderDataset, KolektorDataset, MVTecDataset, VisaDataset
from .datasets.video import AvenueDataset, ShanghaiTechDataset, UCSDpedDataset
from .predict import PredictDataset
from .utils import LabelName
from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -63,7 +87,34 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule


__all__ = [
# Anomalib dataclasses
"DatasetItem",
"Batch",
"InferenceBatch",
"ImageItem",
"ImageBatch",
"VideoItem",
"VideoBatch",
"DepthItem",
"DepthBatch",
"NumpyImageItem",
"NumpyImageBatch",
"NumpyVideoItem",
"NumpyVideoBatch",
# Anomalib datasets
"AnomalibDataset",
"Folder3DDataset",
"MVTec3DDataset",
"BTechDataset",
"FolderDataset",
"KolektorDataset",
"MVTecDataset",
"VisaDataset",
"AvenueDataset",
"ShanghaiTechDataset",
"UCSDpedDataset",
"PredictDataset",
# Anomalib datamodules
"AnomalibDataModule",
"DepthDataFormat",
"ImageDataFormat",
Expand All @@ -72,7 +123,6 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule
"BTech",
"Folder",
"Folder3D",
"PredictDataset",
"Kolektor",
"MVTec",
"MVTec3D",
Expand Down
17 changes: 0 additions & 17 deletions src/anomalib/data/base/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,19 @@
)

__all__ = [
# Numpy
"NumpyImageItem",
"NumpyImageBatch",
"NumpyVideoItem",
"NumpyVideoBatch",
# Torch
"DatasetItem",
"Batch",
"InferenceBatch",
"ImageItem",
"ImageBatch",
"VideoItem",
"VideoBatch",
"NumpyImageItem",
"NumpyImageBatch",
"NumpyVideoItem",
"NumpyVideoBatch",
"DepthItem",
"DepthBatch",
]
File renamed without changes.
24 changes: 24 additions & 0 deletions src/anomalib/data/dataclasses/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Numpy-based dataclasses for Anomalib.
This module provides numpy-based implementations of the generic dataclasses
used in Anomalib. These classes are designed to work with numpy arrays for
efficient data handling and processing in anomaly detection tasks.
The module includes the following main classes:
- NumpyItem: Represents a single item in Anomalib datasets using numpy arrays.
- NumpyBatch: Represents a batch of items in Anomalib datasets using numpy arrays.
- NumpyImageItem: Represents a single image item with additional image-specific fields.
- NumpyImageBatch: Represents a batch of image items with batch operations.
- NumpyVideoItem: Represents a single video item with video-specific fields.
- NumpyVideoBatch: Represents a batch of video items with video-specific operations.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .base import NumpyBatch, NumpyItem
from .image import NumpyImageBatch, NumpyImageItem
from .video import NumpyVideoBatch, NumpyVideoItem

__all__ = ["NumpyBatch", "NumpyItem", "NumpyImageBatch", "NumpyImageItem", "NumpyVideoBatch", "NumpyVideoItem"]
36 changes: 36 additions & 0 deletions src/anomalib/data/dataclasses/numpy/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Numpy-based dataclasses for Anomalib.
This module provides numpy-based implementations of the generic dataclasses
used in Anomalib. These classes are designed to work with numpy arrays for
efficient data handling and processing in anomaly detection tasks.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import numpy as np

from anomalib.data.dataclasses.generic import _GenericBatch, _GenericItem


@dataclass
class NumpyItem(_GenericItem[np.ndarray, np.ndarray, np.ndarray, str]):
"""Dataclass for a single item in Anomalib datasets using numpy arrays.
This class extends _GenericItem for numpy-based data representation. It includes
both input data (e.g., images, labels) and output data (e.g., predictions,
anomaly maps) as numpy arrays. It is suitable for numpy-based processing
pipelines in Anomalib.
"""


@dataclass
class NumpyBatch(_GenericBatch[np.ndarray, np.ndarray, np.ndarray, list[str]]):
"""Dataclass for a batch of items in Anomalib datasets using numpy arrays.
This class extends _GenericBatch for batches of numpy-based data. It represents
multiple data points for batch processing in anomaly detection tasks. It includes
an additional dimension for batch size in all tensor-like fields.
"""
4 changes: 4 additions & 0 deletions src/anomalib/data/dataclasses/numpy/depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Numpy-based depth dataclasses for Anomalib."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -1,18 +1,4 @@
"""Numpy-based dataclasses for Anomalib.
This module provides numpy-based implementations of the generic dataclasses
used in Anomalib. These classes are designed to work with numpy arrays for
efficient data handling and processing in anomaly detection tasks.
The module includes the following main classes:
- NumpyItem: Represents a single item in Anomalib datasets using numpy arrays.
- NumpyBatch: Represents a batch of items in Anomalib datasets using numpy arrays.
- NumpyImageItem: Represents a single image item with additional image-specific fields.
- NumpyImageBatch: Represents a batch of image items with batch operations.
- NumpyVideoItem: Represents a single video item with video-specific fields.
- NumpyVideoBatch: Represents a batch of video items with video-specific operations.
"""
"""Numpy-based image dataclasses for Anomalib."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
Expand All @@ -21,28 +7,8 @@

import numpy as np

from .generic import BatchIterateMixin, _GenericBatch, _GenericItem, _ImageInputFields, _VideoInputFields


@dataclass
class NumpyItem(_GenericItem[np.ndarray, np.ndarray, np.ndarray, str]):
"""Dataclass for a single item in Anomalib datasets using numpy arrays.
This class extends _GenericItem for numpy-based data representation. It includes
both input data (e.g., images, labels) and output data (e.g., predictions,
anomaly maps) as numpy arrays. It is suitable for numpy-based processing
pipelines in Anomalib.
"""


@dataclass
class NumpyBatch(_GenericBatch[np.ndarray, np.ndarray, np.ndarray, list[str]]):
"""Dataclass for a batch of items in Anomalib datasets using numpy arrays.
This class extends _GenericBatch for batches of numpy-based data. It represents
multiple data points for batch processing in anomaly detection tasks. It includes
an additional dimension for batch size in all tensor-like fields.
"""
from anomalib.data.dataclasses.generic import BatchIterateMixin, _ImageInputFields
from anomalib.data.dataclasses.numpy.base import NumpyBatch, NumpyItem


@dataclass
Expand Down Expand Up @@ -175,56 +141,3 @@ def _validate_pred_label(self, pred_label: np.ndarray) -> np.ndarray:

def _validate_image_path(self, image_path: list[str]) -> list[str]:
return image_path


@dataclass
class NumpyVideoItem(_VideoInputFields[np.ndarray, np.ndarray, np.ndarray, str], NumpyItem):
"""Dataclass for a single video item in Anomalib datasets using numpy arrays.
This class combines _VideoInputFields and NumpyItem for video-based anomaly detection.
It includes video-specific fields and validation methods to ensure proper formatting
for Anomalib's video-based models.
"""

def _validate_image(self, image: np.ndarray) -> np.ndarray:
return image

def _validate_gt_label(self, gt_label: np.ndarray) -> np.ndarray:
return gt_label

def _validate_gt_mask(self, gt_mask: np.ndarray) -> np.ndarray:
return gt_mask

def _validate_mask_path(self, mask_path: str) -> str:
return mask_path


@dataclass
class NumpyVideoBatch(
BatchIterateMixin[NumpyVideoItem],
_VideoInputFields[np.ndarray, np.ndarray, np.ndarray, list[str]],
NumpyBatch,
):
"""Dataclass for a batch of video items in Anomalib datasets using numpy arrays.
This class combines BatchIterateMixin, _VideoInputFields, and NumpyBatch for batches
of video data. It supports batch operations and iteration over individual NumpyVideoItems.
It ensures proper formatting for Anomalib's video-based models.
"""

item_class = NumpyVideoItem

def _validate_image(self, image: np.ndarray) -> np.ndarray:
return image

def _validate_gt_label(self, gt_label: np.ndarray) -> np.ndarray:
return gt_label

def _validate_gt_mask(self, gt_mask: np.ndarray) -> np.ndarray:
return gt_mask

def _validate_mask_path(self, mask_path: list[str]) -> list[str]:
return mask_path

def _validate_anomaly_map(self, anomaly_map: np.ndarray) -> np.ndarray:
return anomaly_map
Loading

0 comments on commit 627be88

Please sign in to comment.