Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸš€ Replace albumentations with torchvision transforms #1706

Merged
merged 114 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
f9afbf4
replace albumentations transforms with torchvision
djdameln Feb 2, 2024
7fe9745
define model-specific transform
djdameln Feb 14, 2024
063207f
include transforms in exported model
djdameln Feb 15, 2024
6b7d5e8
update inferencers
djdameln Feb 19, 2024
270aefb
fix data types
djdameln Feb 19, 2024
42823d0
remove old import
djdameln Feb 19, 2024
93dc073
update synthetic dataset
djdameln Feb 19, 2024
68b1bb6
formatting
djdameln Feb 19, 2024
2e3212c
address pre-commit issues
djdameln Feb 19, 2024
75bbb72
rename methods
djdameln Feb 19, 2024
c89acfe
add todo
djdameln Feb 19, 2024
7a5fd51
add transform arguments to datamodule
djdameln Feb 22, 2024
ac2ec0c
simplify transform retrieval
djdameln Feb 22, 2024
913174b
make transform optional
djdameln Feb 22, 2024
6ec9140
update folder dataset
djdameln Feb 22, 2024
b68480d
read image as float in torch inferencer
djdameln Feb 22, 2024
48758f7
Compose -> Transform
djdameln Feb 22, 2024
146a4f2
fix tests
djdameln Feb 22, 2024
5ceb5dd
fix folder tests
djdameln Feb 22, 2024
6c124c0
remove image_size from config
djdameln Feb 22, 2024
49fef52
fix visualizer tests
djdameln Feb 22, 2024
3bd3f5f
Compose -> Transform
djdameln Feb 22, 2024
77780e6
update license headers
djdameln Feb 22, 2024
47e813f
merge main
djdameln Feb 22, 2024
7b7fcba
disable argument linking for image size
djdameln Feb 23, 2024
7e15a1e
add new engine method for updating transform
djdameln Feb 23, 2024
372daeb
default_transform -> configure_transforms
djdameln Feb 23, 2024
aea245c
update padim
djdameln Feb 23, 2024
ed6fbff
add image_size parameter to datamodule
djdameln Feb 23, 2024
3d0d904
improve logic in transform setup
djdameln Feb 23, 2024
81409b4
simplify datamodule/dataset setup
djdameln Feb 23, 2024
7f489eb
add unit tests for transform setup
djdameln Feb 23, 2024
2ace92f
update tests
djdameln Feb 23, 2024
5702ae5
remove get_transforms
djdameln Feb 23, 2024
9023b95
remove references to old transforms
djdameln Feb 23, 2024
350428a
disable antialias when exporting to onnx/openvino
djdameln Feb 23, 2024
db0b5b7
add default transforms for patchcore and efficientad
djdameln Feb 23, 2024
3a0f225
add model-specific transform to ai-vad
djdameln Feb 23, 2024
2b74e6a
add default transform to base model
djdameln Feb 23, 2024
f038ab9
call setup trainer in export entrypoint
djdameln Feb 23, 2024
282a5dd
add model-specific transforms to winclip model
djdameln Feb 23, 2024
37831d9
add default transform for uflow model
djdameln Feb 23, 2024
f708487
fix centercrop export
djdameln Feb 25, 2024
dc2e165
fix model and tools integration tests
djdameln Feb 25, 2024
6688e15
update cli export test
djdameln Feb 25, 2024
57ceb9c
typing and docstrings
djdameln Feb 26, 2024
791d162
make transforms argument optional in datasets
djdameln Feb 26, 2024
e30cceb
remove commented code
djdameln Feb 26, 2024
427fe0b
use conditional formatting
djdameln Feb 26, 2024
9e09838
read anomaly source images with PIL
djdameln Feb 26, 2024
a95718e
fix synthetic anomaly tests
djdameln Feb 26, 2024
373c288
docstring
djdameln Feb 26, 2024
52c0c6d
typing and docstrings
djdameln Feb 26, 2024
ed6d380
remove whitespace
djdameln Feb 26, 2024
9d98b26
use LabelName
djdameln Feb 26, 2024
b815f44
update and use read_image
djdameln Feb 26, 2024
b6437b8
replace cv2.imread
djdameln Feb 26, 2024
3a60c9d
Fix tests
ashwinvaidya17 Feb 27, 2024
5a57430
Merge pull request #1 from ashwinvaidya17/fix/cli_export_tests
djdameln Feb 27, 2024
bbb5605
update config upgrade tests
djdameln Feb 27, 2024
70f3394
dynamically set input size parameter
djdameln Feb 27, 2024
982dc20
fix minor mistakes
djdameln Feb 27, 2024
662bf1a
include transform in checkpoint
djdameln Feb 28, 2024
8f11033
fix get_model tests
djdameln Feb 28, 2024
febb895
fix viz callback tests
djdameln Feb 28, 2024
73db19d
change workflow to address minor issues
djdameln Feb 28, 2024
1e121d5
update setup_transform tests
djdameln Feb 28, 2024
f915dfc
pass explicit arguments
djdameln Feb 28, 2024
c899e39
type annotation for torch model
djdameln Feb 28, 2024
3098668
use hasatr
djdameln Feb 28, 2024
2ce2747
use getattr instead of hasattr
djdameln Feb 28, 2024
92384ee
update unittest names
djdameln Feb 28, 2024
41796e0
Merge pull request #2 from djdameln/torchvision-transforms-dyn-input-…
djdameln Feb 28, 2024
b699d7e
remove input_size argument from padim model
djdameln Feb 25, 2024
d699782
remove image_size argument from patchcore model
djdameln Feb 25, 2024
7017f37
remove image_size argument from cfa model
djdameln Feb 25, 2024
9c60e49
remove image_size argument from cflow model
djdameln Feb 25, 2024
c93e6b3
remove image_size argument from dfa model
djdameln Feb 26, 2024
b2322d8
remove image_size argument from efficient_ad model
djdameln Feb 26, 2024
2179595
replace albumentations with torchvision in efficient_ad
djdameln Feb 26, 2024
b167e1e
use _setup to build csflow torch model
djdameln Feb 28, 2024
cc96fe0
use _setup to build fastflow torch model
djdameln Feb 28, 2024
56045fd
use _setup to build ganomaly torch model
djdameln Feb 28, 2024
052f0a0
remove input_size parameter from stfpm model
djdameln Feb 28, 2024
b56b7d0
use _setup to build revdist torch model
djdameln Feb 28, 2024
d2d5266
use _setup to build uflow torch model
djdameln Feb 28, 2024
b47001a
add read_mask function
djdameln Feb 28, 2024
2c54025
replace to_tensor
djdameln Feb 28, 2024
0b83a23
Merge branch 'torchvision-transforms-main' into torchvision-transform…
djdameln Feb 28, 2024
368140e
allow default model transform in export
djdameln Feb 28, 2024
7404e2e
fix trainer availability check
djdameln Feb 28, 2024
6a5202f
setup transforms in dataloaders during predict
djdameln Feb 28, 2024
8fcabc2
update cli tests
djdameln Feb 28, 2024
f3b5e40
Merge branch 'main' into torchvision-transforms-main
djdameln Feb 28, 2024
0f09daf
rename updated class
djdameln Feb 28, 2024
cc5b577
remove reference to albumentations from docstrings
djdameln Feb 28, 2024
4f06c16
remove albumentations from conftest
djdameln Feb 28, 2024
73a1e6d
revert to using padim for predict tests
djdameln Feb 28, 2024
42fbd26
fix cli predict on PredictDataset bug
djdameln Feb 28, 2024
31d7aee
add image_size parameter to folder datamodule
djdameln Feb 28, 2024
2ca95e7
update notebooks
djdameln Feb 28, 2024
3d25526
remove albumentations from requirements
djdameln Feb 28, 2024
54ea03a
fix torch inference test
djdameln Feb 28, 2024
46e43cb
read mask as uint8
djdameln Feb 28, 2024
b173f30
create default resize transform in datamodule
djdameln Feb 28, 2024
13274bc
fix fastflow notebook
djdameln Feb 28, 2024
dd3d08e
pass image_size to folder datamodule in notebook
djdameln Feb 28, 2024
b20c069
reduce num workers in notebooks
djdameln Feb 28, 2024
eeb69b5
read numpy image in [0-1] range
djdameln Feb 28, 2024
7d4d568
use read_image in openvino inference
djdameln Feb 28, 2024
b565c1c
update getting_started notebook
djdameln Feb 28, 2024
9d42b49
enable antialias in default transform
djdameln Feb 28, 2024
c92d551
use read_image in PredictDataset
djdameln Feb 28, 2024
96101e7
fix fastflow notebook
djdameln Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions src/anomalib/data/base/dataset.py
djdameln marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
from pathlib import Path

import albumentations as A # noqa: N812
import cv2
import numpy as np
import pandas as pd
import torch
from pandas import DataFrame
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from torchvision.tv_tensors import Mask

from anomalib import TaskType
from anomalib.data.utils import masks_to_boxes, read_image
from anomalib.data.utils import masks_to_boxes

_EXPECTED_COLUMNS_CLASSIFICATION = ["image_path", "split"]
_EXPECTED_COLUMNS_SEGMENTATION = [*_EXPECTED_COLUMNS_CLASSIFICATION, "mask_path"]
Expand Down Expand Up @@ -117,24 +118,21 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
mask_path = self._samples.iloc[index].mask_path
label_index = self._samples.iloc[index].label_index

image = read_image(image_path)
image = to_tensor(Image.open(image_path))
djdameln marked this conversation as resolved.
Show resolved Hide resolved
item = {"image_path": image_path, "label": label_index}

if self.task == TaskType.CLASSIFICATION:
transformed = self.transform(image=image)
item["image"] = transformed["image"]
item["image"] = self.transform(image)
elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
# Only Anomalous (1) images have masks in anomaly datasets
# Therefore, create empty mask for Normal (0) images.

mask = np.zeros(shape=image.shape[:2]) if label_index == 0 else cv2.imread(mask_path, flags=0) / 255.0
mask = mask.astype(np.single)

transformed = self.transform(image=image, mask=mask)

item["image"] = transformed["image"]
mask = (
Mask(torch.zeros(image.shape[-2:]))
if label_index == 0
djdameln marked this conversation as resolved.
Show resolved Hide resolved
else Mask(to_tensor(Image.open(mask_path)).squeeze())
djdameln marked this conversation as resolved.
Show resolved Hide resolved
)
item["image"], item["mask"] = self.transform(image, mask)
djdameln marked this conversation as resolved.
Show resolved Hide resolved
item["mask_path"] = mask_path
item["mask"] = transformed["mask"]

if self.task == TaskType.DETECTION:
# create boxes from masks for detection task
Expand Down
27 changes: 13 additions & 14 deletions src/anomalib/data/base/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from abc import ABC

import albumentations as A # noqa: N812
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor
from torchvision.tv_tensors import Mask

from anomalib import TaskType
from anomalib.data.base.dataset import AnomalibDataset
from anomalib.data.utils import masks_to_boxes, read_depth_image, read_image
from anomalib.data.utils import masks_to_boxes, read_depth_image


class AnomalibDepthDataset(AnomalibDataset, ABC):
Expand Down Expand Up @@ -40,25 +41,23 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
label_index = self._samples.iloc[index].label_index
depth_path = self._samples.iloc[index].depth_path

image = read_image(image_path)
depth_image = read_depth_image(depth_path)
image = to_tensor(Image.open(image_path))
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
depth_image = to_tensor(read_depth_image(depth_path))
item = {"image_path": image_path, "depth_path": depth_path, "label": label_index}

if self.task == TaskType.CLASSIFICATION:
transformed = self.transform(image=image, depth_image=depth_image)
item["image"] = transformed["image"]
item["depth_image"] = transformed["depth_image"]
item["image"], item["depth_image"] = self.transform(image, depth_image)
elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION):
# Only Anomalous (1) images have masks in anomaly datasets
# Therefore, create empty mask for Normal (0) images.
mask = np.zeros(shape=image.shape[:2]) if label_index == 0 else cv2.imread(mask_path, flags=0) / 255.0
mask = (
Mask(torch.zeros(image.shape[-2:]))
if label_index == 0
djdameln marked this conversation as resolved.
Show resolved Hide resolved
else Mask(to_tensor(Image.open(mask_path)).squeeze())
)

transformed = self.transform(image=image, depth_image=depth_image, mask=mask)

item["image"] = transformed["image"]
item["depth_image"] = transformed["depth_image"]
item["image"], item["depth_image"], item["mask"] = self.transform(image, depth_image, mask)
item["mask_path"] = mask_path
item["mask"] = transformed["mask"]

if self.task == TaskType.DETECTION:
# create boxes from masks for detection task
Expand Down
17 changes: 5 additions & 12 deletions src/anomalib/data/base/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import albumentations as A # noqa: N812
import torch
from pandas import DataFrame
from torchvision.tv_tensors import Mask

from anomalib import TaskType
from anomalib.data.base.datamodule import AnomalibDataModule
Expand Down Expand Up @@ -148,22 +149,14 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]:
item["original_image"] = item["image"].to(torch.uint8)

# apply transforms
if "mask" in item and item["mask"] is not None:
processed_frames = [
self.transform(image=frame.numpy(), mask=mask)
for frame, mask in zip(item["image"], item["mask"], strict=True)
]
item["image"] = torch.stack([item["image"] for item in processed_frames]).squeeze(0)
mask = torch.as_tensor(item["mask"])
item["mask"] = torch.stack([item["mask"] for item in processed_frames]).squeeze(0)
item["label"] = torch.Tensor([1 in frame for frame in mask]).int().squeeze(0)
if item.get("mask") is not None:
item["image"], item["mask"] = self.transform(item["image"], Mask(item["mask"]))
item["label"] = torch.Tensor([1 in frame for frame in item["mask"]]).int().squeeze(0)
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
if self.task == TaskType.DETECTION:
item["boxes"], _ = masks_to_boxes(item["mask"])
item["boxes"] = item["boxes"][0] if len(item["boxes"]) == 1 else item["boxes"]
else:
item["image"] = torch.stack(
[self.transform(image=frame.numpy())["image"] for frame in item["image"]],
).squeeze(0)
item["image"] = self.transform(item["image"])

# include only target frame in gt
if self.clip_length_in_frames > 1 and self.target_frame != VideoTargetFrame.ALL:
Expand Down
8 changes: 5 additions & 3 deletions src/anomalib/data/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from typing import Any

import albumentations as A # noqa: N812
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms.functional import to_tensor

from anomalib.data.utils import get_image_filenames, get_transforms, read_image
from anomalib.data.utils import get_image_filenames, get_transforms


class PredictDataset(Dataset):
Expand Down Expand Up @@ -46,8 +48,8 @@ def __len__(self) -> int:
def __getitem__(self, index: int) -> dict[str, Any]:
"""Get the image based on the `index`."""
image_filename = self.image_filenames[index]
image = read_image(path=image_filename)
pre_processed = self.transform(image=image)
image = to_tensor(Image.open(image_filename))
pre_processed = {"image": self.transform(image)}
pre_processed["image_path"] = str(image_filename)

return pre_processed
2 changes: 1 addition & 1 deletion src/anomalib/data/utils/augmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def generate_perturbation(

# Load anomaly source image
if anomaly_source_path:
anomaly_source_img = cv2.imread(anomaly_source_path)
anomaly_source_img = cv2.imread(str(anomaly_source_path))
djdameln marked this conversation as resolved.
Show resolved Hide resolved
anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height))
else: # if no anomaly source is specified, we use the perlin noise as anomalous source
anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2)
Expand Down
141 changes: 32 additions & 109 deletions src/anomalib/data/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import logging
from enum import Enum

import albumentations as A # noqa: N812
from albumentations.pytorch import ToTensorV2
from omegaconf import DictConfig
import torch
from torchvision.transforms import v2

from anomalib.data.utils.image import get_image_height_and_width

logger = logging.getLogger(__name__)


NORMALIZATION_STATS = {
"imagenet": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
"clip": {"mean": (0.48145466, 0.4578275, 0.40821073), "std": (0.26862954, 0.26130258, 0.27577711)},
}


class InputNormalizationMethod(str, Enum):
"""Normalization method for the input images."""

Expand All @@ -25,18 +29,18 @@ class InputNormalizationMethod(str, Enum):


def get_transforms(
config: str | A.Compose | None = None,
config: str | v2.Compose | None = None,
image_size: int | tuple[int, int] | None = None,
center_crop: int | tuple[int, int] | None = None,
normalization: InputNormalizationMethod = InputNormalizationMethod.IMAGENET,
to_tensor: bool = True,
) -> A.Compose:
_to_tensor: bool = True,
djdameln marked this conversation as resolved.
Show resolved Hide resolved
) -> v2.Compose:
"""Get transforms from config or image size.

Args:
config (str | A.Compose | None, optional):
Albumentations transforms.
Either config or albumentations ``Compose`` object. Defaults to None.
config (str | v2.Compose | None, optional):
Torchvision transforms.
Either config or torchvision ``Compose`` object. Defaults to None.
image_size (int | tuple | None, optional):
Image size to transform.
Defaults to None.
Expand All @@ -46,126 +50,45 @@ def get_transforms(
normalization (InputNormalizationMethod, optional):
Normalization method for the input images.
Defaults to InputNormalizationMethod.IMAGENET.
to_tensor (bool, optional):
_to_tensor (bool, optional):
Boolean to convert the final transforms into Torch tensor.
Defaults to True.

Raises:
ValueError: When both ``config`` and ``image_size`` is ``None``.
ValueError: When ``config`` is not a ``str`` or `A.Compose`` object.

Returns:
A.Compose: Albumentation ``Compose`` object containing the image transforms.

Examples:
djdameln marked this conversation as resolved.
Show resolved Hide resolved
>>> import skimage
>>> image = skimage.data.astronaut()

>>> transforms = get_transforms(image_size=256, to_tensor=False)
>>> output = transforms(image=image)
>>> output["image"].shape
(256, 256, 3)

>>> transforms = get_transforms(image_size=256, to_tensor=True)
>>> output = transforms(image=image)
>>> output["image"].shape
torch.Size([3, 256, 256])


Transforms could be read from albumentations Compose object.

>>> import albumentations as A # noqa: N812
>>> from albumentations.pytorch import ToTensorV2
>>> config = A.Compose([A.Resize(512, 512), ToTensorV2()])
>>> transforms = get_transforms(config=config, to_tensor=False)
>>> output = transforms(image=image)
>>> output["image"].shape
(512, 512, 3)
>>> type(output["image"])
numpy.ndarray

Transforms could be deserialized from a yaml file.

>>> transforms = A.Compose([A.Resize(1024, 1024), ToTensorV2()])
>>> A.save(transforms, "/tmp/transforms.yaml", data_format="yaml")
>>> transforms = get_transforms(config="/tmp/transforms.yaml")
>>> output = transforms(image=image)
>>> output["image"].shape
torch.Size([3, 1024, 1024])
T.Compose: Torchvision Compose object containing the image transforms.
"""
transforms: A.Compose

if config is not None:
if isinstance(config, DictConfig):
logger.info("Loading transforms from config File")
transforms_list = []

if "Resize" not in config and image_size is not None:
resize_height, resize_width = get_image_height_and_width(image_size)
transforms_list.append(A.Resize(height=resize_height, width=resize_width, always_apply=True))
logger.info("Resize %s added!", (resize_height, resize_width))

for key, value in config.items():
if hasattr(A, key):
transform = getattr(A, key)(**value)
logger.info("Transform %s added!", transform)
transforms_list.append(transform)
else:
msg = f"Transformation {key} is not part of albumentations"
raise ValueError(msg)

transforms_list.append(ToTensorV2())
transforms = A.Compose(transforms_list, additional_targets={"image": "image", "depth_image": "image"})

# load transforms from config file
elif isinstance(config, str):
logger.info("Reading transforms from Albumentations config file: %s.", config)
transforms = A.load(filepath=config, data_format="yaml")
elif isinstance(config, A.Compose):
logger.info("Transforms loaded from Albumentations Compose object")
transforms = config
else:
msg = "config could be either ``str`` or ``A.Compose``"
raise TypeError(msg)
# Load torchvision transforms from config
pass # Implement your logic for loading torchvision transforms from a config
else:
logger.info("No config file has been provided. Using default transforms.")
transforms_list = []

# add resize transform
# Add resize transform
if image_size is None:
msg = (
"Both config and image_size cannot be `None`. "
"Provide either config file to de-serialize transforms or image_size to get the default transformations"
)
msg = "Both config and image_size cannot be `None`."
raise ValueError(msg)
resize_height, resize_width = get_image_height_and_width(image_size)
transforms_list.append(A.Resize(height=resize_height, width=resize_width, always_apply=True))
transforms_list.append(v2.Resize(size=(resize_height, resize_width), interpolation=3, antialias=True))
djdameln marked this conversation as resolved.
Show resolved Hide resolved

# add center crop transform
# Add center crop transform
if center_crop is not None:
crop_height, crop_width = get_image_height_and_width(center_crop)
if crop_height > resize_height or crop_width > resize_width:
msg = f"Crop size may not be larger than image size. Found {image_size} and {center_crop}"
raise ValueError(msg)
transforms_list.append(A.CenterCrop(height=crop_height, width=crop_width, always_apply=True))

# add normalize transform
if normalization == InputNormalizationMethod.IMAGENET:
transforms_list.append(A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)))
elif normalization == InputNormalizationMethod.CLIP:
transforms_list.append(
A.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
)
elif normalization == InputNormalizationMethod.NONE:
transforms_list.append(A.ToFloat(max_value=255))
else:
transforms_list.append(v2.CenterCrop(size=(crop_height, crop_width)))

# Add convert-to-float transform
transforms_list.append(v2.ToDtype(dtype=torch.float32, scale=True))

# Add normalize transform
if normalization in [InputNormalizationMethod.IMAGENET, InputNormalizationMethod.CLIP]:
transforms_list.append(v2.Normalize(**NORMALIZATION_STATS[normalization]))
elif normalization != InputNormalizationMethod.NONE:
msg = f"Unknown normalization method: {normalization}"
raise ValueError(msg)

# add tensor conversion
if to_tensor:
transforms_list.append(ToTensorV2())

transforms = A.Compose(transforms_list, additional_targets={"image": "image", "depth_image": "image"})
transforms = v2.Compose(transforms_list)

return transforms
2 changes: 1 addition & 1 deletion src/anomalib/data/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_item(self, idx: int) -> dict[str, Any]:
clip_pts = self.clips[video_idx][clip_idx]

return {
"image": clip,
"image": clip.permute(0, 3, 1, 2) / 255,
"mask": self.get_mask(idx),
"video_path": video_path,
"frames": clip_pts,
Expand Down
Loading
Loading