Configure reference frame for multi-frame video clips (#1023)
* make target frame configurable

* add test for gt frame handling

* update changelog

* fix shape inference in visualizer

* change shape inference

* docstring
djdameln authored Apr 24, 2023
1 parent 4fe6c74 commit 7ef59f9
Showing 4 changed files with 69 additions and 2 deletions.
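
At a glance, the change adds a VideoTargetFrame enum and a matching target_frame option on the base video dataset, so multi-frame models can choose which frame of a clip supplies the ground truth. A quick sketch of the available values (assuming anomalib at this commit is importable):

    from anomalib.data.base.video import VideoTargetFrame

    # the four reference-frame choices introduced by this commit
    print([frame.value for frame in VideoTargetFrame])  # ['first', 'last', 'mid', 'all']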
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Changed

- Configure reference frame for multi-frame video clips (<https://github.com/openvinotoolkit/anomalib/pull/1023>)
- Bump OpenVINO version to `2022.3.0` (<https://github.com/openvinotoolkit/anomalib/pull/932>)
- Remove the dependency on specific `torchvision` and `torchmetrics` packages.
- Bump PyTorch Lightning version to v.1.9.\* (<https://github.com/openvinotoolkit/anomalib/pull/870>)
49 changes: 48 additions & 1 deletion src/anomalib/data/base/video.py
@@ -3,6 +3,7 @@
from __future__ import annotations

from abc import ABC
from enum import Enum
from typing import Callable

import albumentations as A
@@ -17,6 +18,18 @@
from anomalib.data.utils.video import ClipsIndexer


class VideoTargetFrame(str, Enum):
    """Target frame for a video clip.

    Used in multi-frame models to determine which frame's ground truth information will be used.
    """

    FIRST = "first"
    LAST = "last"
    MID = "mid"
    ALL = "all"


class AnomalibVideoDataset(AnomalibDataset, ABC):
    """Base video anomalib dataset class.

@@ -25,10 +38,16 @@ class AnomalibVideoDataset(AnomalibDataset, ABC):
        transform (A.Compose): Albumentations Compose object describing the transforms that are applied to the inputs.
        clip_length_in_frames (int): Number of video frames in each clip.
        frames_between_clips (int): Number of frames between each consecutive video clip.
        target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth
            retrieval. Defaults to ``VideoTargetFrame.LAST``.
    """

    def __init__(
-        self, task: TaskType, transform: A.Compose, clip_length_in_frames: int, frames_between_clips: int
+        self,
+        task: TaskType,
+        transform: A.Compose,
+        clip_length_in_frames: int,
+        frames_between_clips: int,
+        target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
    ) -> None:
        super().__init__(task, transform)

@@ -39,6 +58,8 @@ def __init__(
        self.indexer: ClipsIndexer | None = None
        self.indexer_cls: Callable | None = None

        self.target_frame = target_frame

    def __len__(self) -> int:
        """Get length of the dataset."""
        assert isinstance(self.indexer, ClipsIndexer)
@@ -68,6 +89,28 @@ def _setup_clips(self) -> None:
            frames_between_clips=self.frames_between_clips,
        )

    def _select_targets(self, item: dict[str, str | Tensor]) -> dict[str, str | Tensor]:
        """Select the ground truth information of the target frame from the retrieved clip."""
        if self.target_frame == VideoTargetFrame.FIRST:
            idx = 0
        elif self.target_frame == VideoTargetFrame.LAST:
            idx = -1
        elif self.target_frame == VideoTargetFrame.MID:
            idx = self.clip_length_in_frames // 2
        else:
            raise ValueError(f"Unknown video target frame: {self.target_frame}")

        if item.get("mask") is not None:
            item["mask"] = item["mask"][idx, ...]
        if item.get("boxes") is not None:
            item["boxes"] = item["boxes"][idx]
        if item.get("label") is not None:
            item["label"] = item["label"][idx]
        if item.get("original_image") is not None:
            item["original_image"] = item["original_image"][idx]
        if item.get("frames") is not None:
            item["frames"] = item["frames"][idx]
        return item

    def __getitem__(self, index: int) -> dict[str, str | Tensor]:
        """Return mask, clip and file system information."""
        assert isinstance(self.indexer, ClipsIndexer)
@@ -93,6 +136,10 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
            [self.transform(image=frame.numpy())["image"] for frame in item["image"]]
        ).squeeze(0)

        # include only the target frame in the ground truth
        if self.clip_length_in_frames > 1 and self.target_frame != VideoTargetFrame.ALL:
            item = self._select_targets(item)

        if item["mask"] is None:
            item.pop("mask")

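
To make the selection logic above concrete, here is a self-contained sketch of what _select_targets does for VideoTargetFrame.MID, using fabricated tensors (shapes and values are illustrative assumptions, not anomalib data):

    import torch

    # fabricated ground truth for a clip of 4 frames (illustrative only)
    clip_length_in_frames = 4
    item = {
        "mask": torch.zeros(clip_length_in_frames, 256, 256),
        "label": torch.tensor([0, 0, 1, 1]),
        "frames": torch.tensor([120, 121, 122, 123]),
    }

    idx = clip_length_in_frames // 2  # VideoTargetFrame.MID -> index 2
    item["mask"] = item["mask"][idx, ...]
    item["label"] = item["label"][idx]
    item["frames"] = item["frames"][idx]
    print(int(item["frames"]))  # 122: only the middle frame's ground truth remains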
4 changes: 3 additions & 1 deletion src/anomalib/post_processing/visualizer.py
@@ -88,11 +88,13 @@ def visualize_batch(self, batch: dict) -> Iterator[np.ndarray]:
        Returns:
            Generator that yields a display-ready visualization for each image.
        """
-        batch_size, _num_channels, height, width = batch["image"].size()
+        batch_size = batch["image"].shape[0]
        for i in range(batch_size):
            if "image_path" in batch:
+                height, width = batch["image"].shape[-2:]
                image = read_image(path=batch["image_path"][i], image_size=(height, width))
            elif "video_path" in batch:
+                height, width = batch["original_image"].shape[1:3]
                image = batch["original_image"][i].squeeze().numpy()
                image = cv2.resize(image, dsize=(width, height), interpolation=cv2.INTER_AREA)
            else:
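
The motivation for the shape-inference fix above: the removed line unpacked exactly four dimensions from batch["image"], which fails for the 5-D tensors produced by multi-frame clips. Inferring height and width per branch handles both layouts. A minimal runnable sketch with assumed tensor shapes (not taken from anomalib):

    import torch

    image_batch = {
        "image": torch.rand(8, 3, 256, 256),           # (B, C, H, W) still images
        "image_path": ["frame.png"] * 8,
    }
    video_batch = {
        "image": torch.rand(8, 4, 3, 256, 256),        # (B, T, C, H, W) multi-frame clips
        "original_image": torch.rand(8, 256, 256, 3),  # (B, H, W, C), as indexed above
        "video_path": ["clip.avi"] * 8,
    }

    for batch in (image_batch, video_batch):
        batch_size = batch["image"].shape[0]           # valid for 4-D and 5-D tensors alike
        if "image_path" in batch:
            height, width = batch["image"].shape[-2:]
        elif "video_path" in batch:
            height, width = batch["original_image"].shape[1:3]
        print(batch_size, height, width)               # 8 256 256 in both cases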
17 changes: 17 additions & 0 deletions tests/pre_merge/datasets/test_video.py
@@ -5,6 +5,7 @@
import pytest

from anomalib.data import TaskType
from anomalib.data.base.video import VideoTargetFrame
from anomalib.data.ucsd_ped import (
    UCSDpedClipsIndexer,
    UCSDpedDataset,
@@ -80,3 +81,19 @@ def test_get_item(self, ucsd_dataset, required_keys):
        item = next(iter(ucsd_dataset))
        # confirm that all the required keys are there
        assert set(item.keys()) == set(required_keys)

    @pytest.mark.parametrize("split", [Split.TEST])
    def test_target_frame(self, ucsd_dataset):
        ucsd_dataset.target_frame = VideoTargetFrame.ALL
        all_frames_item = ucsd_dataset[0]
        ucsd_dataset.target_frame = VideoTargetFrame.FIRST
        first_frame_item = ucsd_dataset[0]
        ucsd_dataset.target_frame = VideoTargetFrame.LAST
        last_frame_item = ucsd_dataset[0]
        ucsd_dataset.target_frame = VideoTargetFrame.MID
        mid_frame_item = ucsd_dataset[0]

        # check that the correct ground truth frame is retrieved
        assert first_frame_item["frames"] == all_frames_item["frames"][0]
        assert last_frame_item["frames"] == all_frames_item["frames"][-1]
        assert mid_frame_item["frames"] == all_frames_item["frames"][len(all_frames_item["frames"]) // 2]
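
Worth noting for tests and configs alike: because VideoTargetFrame subclasses str, its members can be constructed from, and compared against, plain strings. A small sketch (assuming anomalib at this commit is importable):

    from anomalib.data.base.video import VideoTargetFrame

    assert VideoTargetFrame("mid") is VideoTargetFrame.MID  # lookup by value
    assert VideoTargetFrame.LAST == "last"                  # plain-string comparison works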
