diff --git a/CHANGELOG.md b/CHANGELOG.md index cf9807af26..b50bf09ecb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- Add `VlmAd` model by [Bepitic](https://github.com/Bepitic) and refactored by [ashwinvaidya17](https://github.com/ashwinvaidya17) in https://github.com/openvinotoolkit/anomalib/pull/2344 - Add `Datumaro` annotation format support by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2377 - Add `AUPIMO` tutorials notebooks in https://github.com/openvinotoolkit/anomalib/pull/2330 and https://github.com/openvinotoolkit/anomalib/pull/2336 - Add `AUPIMO` metric by [jpcbertoldo](https://github.com/jpcbertoldo) in https://github.com/openvinotoolkit/anomalib/pull/1726 and refactored by [ashwinvaidya17](https://github.com/ashwinvaidya17) in https://github.com/openvinotoolkit/anomalib/pull/2329 diff --git a/pyproject.toml b/pyproject.toml index 2893ad20c4..268544ad2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ core = [ "open-clip-torch>=2.23.0,<2.26.1", ] openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"] +vlm = ["ollama", "openai", "python-dotenv", "transformers"] loggers = [ "comet-ml>=3.31.7", "gradio>=4", @@ -84,7 +85,7 @@ test = [ "coverage[toml]", "tox", ] -full = ["anomalib[core,openvino,loggers,notebooks]"] +full = ["anomalib[core,openvino,loggers,notebooks,vlm]"] dev = ["anomalib[full,docs,test]"] [project.scripts] diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 5cee830dad..e09e622d41 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -78,9 +78,8 @@ def setup( elif self.task == TaskType.CLASSIFICATION: pixel_metric_names = [] logger.warning( - "Cannot perform pixel-level evaluation when task type is classification. " - "Ignoring the following pixel-level metrics: %s", - self.pixel_metric_names, + f"Cannot perform pixel-level evaluation when task type is {self.task.value}. 
" + f"Ignoring the following pixel-level metrics: {self.pixel_metric_names}", ) else: pixel_metric_names = ( diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 83b9714416..b537819729 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -32,7 +32,7 @@ from anomalib.utils.normalization import NormalizationMethod from anomalib.utils.path import create_versioned_dir from anomalib.utils.types import NORMALIZATION, THRESHOLD -from anomalib.utils.visualization import ImageVisualizer +from anomalib.utils.visualization import BaseVisualizer, ExplanationVisualizer, ImageVisualizer logger = logging.getLogger(__name__) @@ -322,7 +322,7 @@ def _setup_trainer(self, model: AnomalyModule) -> None: self._cache.update(model) # Setup anomalib callbacks to be used with the trainer - self._setup_anomalib_callbacks() + self._setup_anomalib_callbacks(model) # Temporarily set devices to 1 to avoid issues with multiple processes self._cache.args["devices"] = 1 @@ -405,7 +405,7 @@ def _setup_transform( if not getattr(dataloader.dataset, "transform", None): dataloader.dataset.transform = transform - def _setup_anomalib_callbacks(self) -> None: + def _setup_anomalib_callbacks(self, model: AnomalyModule) -> None: """Set up callbacks for the trainer.""" _callbacks: list[Callback] = [] @@ -432,9 +432,17 @@ def _setup_anomalib_callbacks(self) -> None: _callbacks.append(_ThresholdCallback(self.threshold)) _callbacks.append(_MetricsCallback(self.task, self.image_metric_names, self.pixel_metric_names)) + visualizer: BaseVisualizer + + # TODO(ashwinvaidya17): temporary # noqa: TD003 ignoring as visualizer is getting a complete overhaul + if model.__class__.__name__ == "VlmAd": + visualizer = ExplanationVisualizer() + else: + visualizer = ImageVisualizer(task=self.task, normalize=self.normalization == NormalizationMethod.NONE) + _callbacks.append( _VisualizationCallback( - visualizers=ImageVisualizer(task=self.task, normalize=self.normalization == NormalizationMethod.NONE), + visualizers=visualizer, save=True, root=self._cache.args["default_root_dir"] / "images", ), diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index b4bb36a875..ea091d1640 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -30,6 +30,7 @@ Rkde, Stfpm, Uflow, + VlmAd, WinClip, ) from .video import AiVad @@ -58,6 +59,7 @@ class UnknownModelError(ModuleNotFoundError): "Stfpm", "Uflow", "AiVad", + "VlmAd", "WinClip", ] diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index f3a5435038..b09da8b07b 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -20,6 +20,7 @@ from .rkde import Rkde from .stfpm import Stfpm from .uflow import Uflow +from .vlm_ad import VlmAd from .winclip import WinClip __all__ = [ @@ -40,5 +41,6 @@ "Rkde", "Stfpm", "Uflow", + "VlmAd", "WinClip", ] diff --git a/src/anomalib/models/image/vlm_ad/__init__.py b/src/anomalib/models/image/vlm_ad/__init__.py new file mode 100644 index 0000000000..46ab8e0fee --- /dev/null +++ b/src/anomalib/models/image/vlm_ad/__init__.py @@ -0,0 +1,8 @@ +"""Visual Anomaly Model.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import VlmAd + +__all__ = ["VlmAd"] diff --git a/src/anomalib/models/image/vlm_ad/backends/__init__.py b/src/anomalib/models/image/vlm_ad/backends/__init__.py new file mode 100644 index 0000000000..44009f8f83 --- /dev/null +++ 
b/src/anomalib/models/image/vlm_ad/backends/__init__.py @@ -0,0 +1,11 @@ +"""VLM backends.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .base import Backend +from .chat_gpt import ChatGPT +from .huggingface import Huggingface +from .ollama import Ollama + +__all__ = ["Backend", "ChatGPT", "Huggingface", "Ollama"] diff --git a/src/anomalib/models/image/vlm_ad/backends/base.py b/src/anomalib/models/image/vlm_ad/backends/base.py new file mode 100644 index 0000000000..b4aadf9a22 --- /dev/null +++ b/src/anomalib/models/image/vlm_ad/backends/base.py @@ -0,0 +1,30 @@ +"""Base backend.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from pathlib import Path + +from anomalib.models.image.vlm_ad.utils import Prompt + + +class Backend(ABC): + """Base backend.""" + + @abstractmethod + def __init__(self, model_name: str) -> None: + """Initialize the backend.""" + + @abstractmethod + def add_reference_images(self, image: str | Path) -> None: + """Add reference images for k-shot.""" + + @abstractmethod + def predict(self, image: str | Path, prompt: Prompt) -> str: + """Predict the anomaly label.""" + + @property + @abstractmethod + def num_reference_images(self) -> int: + """Get the number of reference images.""" diff --git a/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py b/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py new file mode 100644 index 0000000000..741288354f --- /dev/null +++ b/src/anomalib/models/image/vlm_ad/backends/chat_gpt.py @@ -0,0 +1,109 @@ +"""ChatGPT backend.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import logging +import os +from pathlib import Path +from typing import TYPE_CHECKING + +from dotenv import load_dotenv +from lightning_utilities.core.imports import package_available + +from anomalib.models.image.vlm_ad.utils import Prompt + +from .base import Backend + +if package_available("openai"): + from openai import OpenAI +else: + OpenAI = None + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletion + +logger = logging.getLogger(__name__) + + +class ChatGPT(Backend): + """ChatGPT backend.""" + + def __init__(self, model_name: str, api_key: str | None = None) -> None: + """Initialize the ChatGPT backend.""" + self._ref_images_encoded: list[str] = [] + self.model_name: str = model_name + self._client: OpenAI | None = None + self.api_key = self._get_api_key(api_key) + + @property + def client(self) -> OpenAI: + """Get the OpenAI client.""" + if OpenAI is None: + msg = "OpenAI is not installed. Please install it to use ChatGPT backend." 
+ raise ImportError(msg) + if self._client is None: + self._client = OpenAI(api_key=self.api_key) + return self._client + + def add_reference_images(self, image: str | Path) -> None: + """Add reference images for k-shot.""" + self._ref_images_encoded.append(self._encode_image_to_url(image)) + + @property + def num_reference_images(self) -> int: + """Get the number of reference images.""" + return len(self._ref_images_encoded) + + def predict(self, image: str | Path, prompt: Prompt) -> str: + """Predict the anomaly label.""" + image_encoded = self._encode_image_to_url(image) + messages = [] + + # few-shot + if len(self._ref_images_encoded) > 0: + messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded)) + + messages.append(self._generate_message(content=prompt.predict, images=[image_encoded])) + + response: ChatCompletion = self.client.chat.completions.create(messages=messages, model=self.model_name) + return response.choices[0].message.content + + @staticmethod + def _generate_message(content: str, images: list[str] | None) -> dict: + """Generate a message.""" + message: dict[str, list[dict] | str] = {"role": "user"} + if images is not None: + _content: list[dict[str, str | dict]] = [{"type": "text", "text": content}] + _content.extend([{"type": "image_url", "image_url": {"url": image}} for image in images]) + message["content"] = _content + else: + message["content"] = content + return message + + def _encode_image_to_url(self, image: str | Path) -> str: + """Encode the image to base64 and embed in url string.""" + image_path = Path(image) + extension = image_path.suffix.lstrip(".") + base64_encoded = self._encode_image_to_base_64(image_path) + return f"data:image/{extension};base64,{base64_encoded}" + + @staticmethod + def _encode_image_to_base_64(image: str | Path) -> str: + """Encode the image to base64.""" + image = Path(image) + return base64.b64encode(image.read_bytes()).decode("utf-8") + + def _get_api_key(self, api_key: str | None = None) -> str: + if api_key is None: + load_dotenv() + api_key = os.getenv("OPENAI_API_KEY") + if api_key is None: + msg = ( + f"OpenAI API key must be provided to use {self.model_name}." + " Please provide the API key in the constructor, or set the OPENAI_API_KEY environment variable" + " or define it in a `.env` file."
+ ) + raise ValueError(msg) + return api_key diff --git a/src/anomalib/models/image/vlm_ad/backends/huggingface.py b/src/anomalib/models/image/vlm_ad/backends/huggingface.py new file mode 100644 index 0000000000..c234ecfbc5 --- /dev/null +++ b/src/anomalib/models/image/vlm_ad/backends/huggingface.py @@ -0,0 +1,96 @@ +"""Huggingface backend.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path + +from lightning_utilities.core.imports import package_available +from PIL import Image +from transformers.modeling_utils import PreTrainedModel + +from anomalib.models.image.vlm_ad.utils import Prompt + +from .base import Backend + +if package_available("transformers"): + import transformers + from transformers.modeling_utils import PreTrainedModel + from transformers.processing_utils import ProcessorMixin +else: + transformers = None + + +logger = logging.getLogger(__name__) + + +class Huggingface(Backend): + """Huggingface backend.""" + + def __init__( + self, + model_name: str, + ) -> None: + """Initialize the Huggingface backend.""" + self.model_name: str = model_name + self._ref_images: list[str] = [] + self._processor: ProcessorMixin | None = None + self._model: PreTrainedModel | None = None + + @property + def processor(self) -> ProcessorMixin: + """Get the Huggingface processor.""" + if self._processor is None: + if transformers is None: + msg = "transformers is not installed." + raise ValueError(msg) + self._processor = transformers.LlavaNextProcessor.from_pretrained(self.model_name) + return self._processor + + @property + def model(self) -> PreTrainedModel: + """Get the Huggingface model.""" + if self._model is None: + if transformers is None: + msg = "transformers is not installed." 
+ raise ValueError(msg) + self._model = transformers.LlavaNextForConditionalGeneration.from_pretrained(self.model_name) + return self._model + + @staticmethod + def _generate_message(content: str, images: list[str] | None) -> dict: + """Generate a message.""" + message: dict[str, str | list[dict]] = {"role": "user"} + _content: list[dict[str, str]] = [{"type": "text", "text": content}] + if images is not None: + _content.extend([{"type": "image"} for _ in images]) + message["content"] = _content + return message + + def add_reference_images(self, image: str | Path) -> None: + """Add reference images for k-shot.""" + self._ref_images.append(Image.open(image)) + + @property + def num_reference_images(self) -> int: + """Get the number of reference images.""" + return len(self._ref_images) + + def predict(self, image_path: str | Path, prompt: Prompt) -> str: + """Predict the anomaly label.""" + image = Image.open(image_path) + messages: list[dict] = [] + + if len(self._ref_images) > 0: + messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images)) + + messages.append(self._generate_message(content=prompt.predict, images=[image])) + processed_prompt = [self.processor.apply_chat_template(messages, add_generation_prompt=True)] + + images = [*self._ref_images, image] + inputs = self.processor(images, processed_prompt, return_tensors="pt", padding=True).to(self.model.device) + outputs = self.model.generate(**inputs, max_new_tokens=100) + result = self.processor.decode(outputs[0], skip_special_tokens=True) + logger.debug(result) + return result diff --git a/src/anomalib/models/image/vlm_ad/backends/ollama.py b/src/anomalib/models/image/vlm_ad/backends/ollama.py new file mode 100644 index 0000000000..db5a215bb3 --- /dev/null +++ b/src/anomalib/models/image/vlm_ad/backends/ollama.py @@ -0,0 +1,73 @@ +"""Ollama backend. + +Assumes that the Ollama service is running in the background. +See: https://github.com/ollama/ollama +Ensure that ollama is running. On Linux: `ollama serve` +On Mac and Windows, ensure that the ollama service is running by launching it from the application list. +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path + +from lightning_utilities.core.imports import package_available + +from anomalib.models.image.vlm_ad.utils import Prompt + +from .base import Backend + +if package_available("ollama"): + from ollama import chat + from ollama._client import _encode_image +else: + chat = None + logger = logging.getLogger(__name__) + + class Ollama(Backend): + """Ollama backend.""" + + def __init__(self, model_name: str) -> None: + """Initialize the Ollama backend.""" + self.model_name: str = model_name + self._ref_images_encoded: list[str] = [] + + def add_reference_images(self, image: str | Path) -> None: + """Add reference images for k-shot.""" + self._ref_images_encoded.append(_encode_image(image)) + + @property + def num_reference_images(self) -> int: + """Get the number of reference images.""" + return len(self._ref_images_encoded) + + @staticmethod + def _generate_message(content: str, images: list[str] | None) -> dict: + """Generate a message.""" + message: dict[str, str | list[str]] = {"role": "user", "content": content} + if images: + message["images"] = images + return message + + def predict(self, image: str | Path, prompt: Prompt) -> str: + """Predict the anomaly label.""" + if not chat: + msg = "Ollama is not installed. Please install it using `pip install ollama`."
+ raise ImportError(msg) + image_encoded = _encode_image(image) + messages = [] + + # few-shot + if len(self._ref_images_encoded) > 0: + messages.append(self._generate_message(content=prompt.few_shot, images=self._ref_images_encoded)) + + messages.append(self._generate_message(content=prompt.predict, images=[image_encoded])) + + response = chat( + model=self.model_name, + messages=messages, + ) + return response["message"]["content"].strip() diff --git a/src/anomalib/models/image/vlm_ad/lightning_model.py b/src/anomalib/models/image/vlm_ad/lightning_model.py new file mode 100644 index 0000000000..1279f7a31e --- /dev/null +++ b/src/anomalib/models/image/vlm_ad/lightning_model.py @@ -0,0 +1,115 @@ +"""Visual Anomaly Model for Zero/Few-Shot Anomaly Classification.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import torch +from torch.utils.data import DataLoader + +from anomalib import LearningType +from anomalib.models import AnomalyModule + +from .backends import Backend, ChatGPT, Huggingface, Ollama +from .utils import ModelName, Prompt + +logger = logging.getLogger(__name__) + + +class VlmAd(AnomalyModule): + """Visual anomaly model.""" + + def __init__( + self, + model: ModelName | str = ModelName.LLAMA_OLLAMA, + api_key: str | None = None, + k_shot: int = 0, + ) -> None: + super().__init__() + self.k_shot = k_shot + model = ModelName(model) + self.vlm_backend: Backend = self._setup_vlm_backend(model, api_key) + + @staticmethod + def _setup_vlm_backend(model_name: ModelName, api_key: str | None) -> Backend: + if model_name == ModelName.LLAMA_OLLAMA: + return Ollama(model_name=model_name.value) + if model_name == ModelName.GPT_4O_MINI: + return ChatGPT(api_key=api_key, model_name=model_name.value) + if model_name in {ModelName.VICUNA_7B_HF, ModelName.VICUNA_13B_HF, ModelName.MISTRAL_7B_HF}: + return Huggingface(model_name=model_name.value) + + msg = f"Unsupported VLM model: {model_name}" + raise ValueError(msg) + + def _setup(self) -> None: + if self.k_shot > 0 and self.vlm_backend.num_reference_images != self.k_shot: + logger.info("Collecting reference images from training dataset.") + dataloader = self.trainer.datamodule.train_dataloader() + self.collect_reference_images(dataloader) + + def collect_reference_images(self, dataloader: DataLoader) -> None: + """Collect reference images for few-shot inference.""" + for batch in dataloader: + for img_path in batch["image_path"]: + self.vlm_backend.add_reference_images(img_path) + if self.vlm_backend.num_reference_images == self.k_shot: + return + + @property + def prompt(self) -> Prompt: + """Get the prompt.""" + return Prompt( + predict=( + "You are given an image. It is either normal or anomalous." + " First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n" + "Then give the reason for your decision.\n" + "For example, 'YES: The image has a crack on the wall.'" + ), + few_shot=( + "These are a few examples of normal picture without any anomalies." + " You have to use these to determine if the image I provide in the next" + " chat is normal or anomalous." + ), + ) + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict: + """Validation step.""" + del args, kwargs # These variables are not used. 
+ responses = [(self.vlm_backend.predict(img_path, self.prompt)) for img_path in batch["image_path"]] + batch["explanation"] = responses + batch["pred_scores"] = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device) + return batch + + @property + def learning_type(self) -> LearningType: + """The learning type of the model.""" + return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT + + @property + def trainer_arguments(self) -> dict[str, int | float]: + """Doesn't need training.""" + return {} + + @staticmethod + def configure_transforms(image_size: tuple[int, int] | None = None) -> None: + """This modes does not require any transforms.""" + if image_size is not None: + logger.warning("Ignoring image_size argument as each backend has its own transforms.") + + @staticmethod + def _export_not_supported_message() -> None: + logging.warning("Exporting the model is not supported for VLM-AD model. Skipping...") + + def to_torch(self, *_, **__) -> None: # type: ignore[override] + """Skip export to torch.""" + return self._export_not_supported_message() + + def to_onnx(self, *_, **__) -> None: # type: ignore[override] + """Skip export to onnx.""" + return self._export_not_supported_message() + + def to_openvino(self, *_, **__) -> None: # type: ignore[override] + """Skip export to openvino.""" + return self._export_not_supported_message() diff --git a/src/anomalib/models/image/vlm_ad/utils.py b/src/anomalib/models/image/vlm_ad/utils.py new file mode 100644 index 0000000000..ce9b9067ac --- /dev/null +++ b/src/anomalib/models/image/vlm_ad/utils.py @@ -0,0 +1,25 @@ +"""Dataclasses.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from enum import Enum + + +@dataclass +class Prompt: + """Prompt.""" + + few_shot: str + predict: str + + +class ModelName(Enum): + """List of supported models.""" + + LLAMA_OLLAMA = "llava" + GPT_4O_MINI = "gpt-4o-mini" + VICUNA_7B_HF = "llava-hf/llava-v1.6-vicuna-7b-hf" + VICUNA_13B_HF = "llava-hf/llava-v1.6-vicuna-13b-hf" + MISTRAL_7B_HF = "llava-hf/llava-v1.6-mistral-7b-hf" diff --git a/src/anomalib/utils/visualization/__init__.py b/src/anomalib/utils/visualization/__init__.py index f68036ed78..404036dfad 100644 --- a/src/anomalib/utils/visualization/__init__.py +++ b/src/anomalib/utils/visualization/__init__.py @@ -4,11 +4,13 @@ # SPDX-License-Identifier: Apache-2.0 from .base import BaseVisualizer, GeneratorResult, VisualizationStep +from .explanation import ExplanationVisualizer from .image import ImageResult, ImageVisualizer from .metrics import MetricsVisualizer __all__ = [ "BaseVisualizer", + "ExplanationVisualizer", "ImageResult", "ImageVisualizer", "GeneratorResult", diff --git a/src/anomalib/utils/visualization/explanation.py b/src/anomalib/utils/visualization/explanation.py new file mode 100644 index 0000000000..10904161e3 --- /dev/null +++ b/src/anomalib/utils/visualization/explanation.py @@ -0,0 +1,106 @@ +"""Explanation visualization generator. + +Note: This is a temporary visualizer, and will be replaced with the new visualizer in the future. 
+""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterator +from pathlib import Path + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from .base import BaseVisualizer, GeneratorResult, VisualizationStep + + +class ExplanationVisualizer(BaseVisualizer): + """Explanation visualization generator.""" + + def __init__(self) -> None: + super().__init__(visualize_on=VisualizationStep.BATCH) + self.padding = 3 + self.font = ImageFont.load_default(size=16) + + def generate(self, **kwargs) -> Iterator[GeneratorResult]: + """Generate images and return them as an iterator.""" + outputs = kwargs.get("outputs", None) + if outputs is None: + msg = "Outputs must be provided to generate images." + raise ValueError(msg) + return self._visualize_batch(outputs) + + def _visualize_batch(self, batch: dict) -> Iterator[GeneratorResult]: + """Visualize batch of images.""" + batch_size = batch["image"].shape[0] + height, width = batch["image"].shape[-2:] + for i in range(batch_size): + image = batch["image"][i] + explanation = batch["explanation"][i] + file_name = Path(batch["image_path"][i]) + image = Image.open(file_name) + image = image.resize((width, height)) + image = self._draw_image(width, height, image=image, explanation=explanation) + yield GeneratorResult(image=image, file_name=file_name) + + def _draw_image(self, width: int, height: int, image: Image, explanation: str) -> np.ndarray: + text_canvas: Image = self._get_explanation_image(width, height, image, explanation) + label_canvas: Image = self._get_label_image(explanation) + + final_width = max(text_canvas.size[0], width) + final_height = height + text_canvas.size[1] + combined_image = Image.new("RGB", (final_width, final_height), (255, 255, 255)) + combined_image.paste(image, (self.padding, 0)) + combined_image.paste(label_canvas, (10, 10)) + combined_image.paste(text_canvas, (0, height)) + return np.array(combined_image) + + def _get_label_image(self, explanation: str) -> Image: + # Draw label + # Can't use pred_labels as it is computed from the pred_scores using image_threshold. It gives incorrect value. + # So, using explanation. This will probably change with the new design. 
+ label = "Anomalous" if explanation.startswith("Y") else "Normal" + label_color = "red" if label == "Anomalous" else "green" + label_canvas = Image.new("RGB", (100, 20), color=label_color) + draw = ImageDraw.Draw(label_canvas) + draw.text((0, 0), label, font=self.font, fill="white", align="center") + return label_canvas + + def _get_explanation_image(self, width: int, height: int, image: Image, explanation: str) -> Image: + # compute wrap width + text_canvas = Image.new("RGB", (width, height), color="white") + dummy_image = ImageDraw.Draw(image) + text_bbox = dummy_image.textbbox((0, 0), explanation, font=self.font, align="center") + text_canvas_width = text_bbox[2] - text_bbox[0] + self.padding + + # split lines based on the width + lines = list(explanation.split("\n")) + line_with_max_len = max(lines, key=len) + new_width = int(width * len(line_with_max_len) // text_canvas_width) + + # wrap text based on the new width + lines = [] + current_line: list[str] = [] + for word in explanation.split(" "): + test_line = " ".join([*current_line, word]) + if len(test_line) <= new_width: + current_line.append(word) + else: + lines.append(" ".join(current_line)) + current_line = [word] + lines.append(" ".join(current_line)) + wrapped_lines = "\n".join(lines) + + # recompute height + dummy_image = Image.new("RGB", (new_width, height), color="white") + draw = ImageDraw.Draw(dummy_image) + text_bbox = draw.textbbox((0, 0), wrapped_lines, font=self.font, align="center") + new_width = int(text_bbox[2] - text_bbox[0] + self.padding) + new_height = int(text_bbox[3] - text_bbox[1] + self.padding) + + # Final text image + text_canvas = Image.new("RGB", (new_width, new_height), color="white") + draw = ImageDraw.Draw(text_canvas) + draw.text((self.padding // 2, 0), wrapped_lines, font=self.font, fill="black", align="center") + return text_canvas diff --git a/tests/integration/model/test_models.py b/tests/integration/model/test_models.py index e743cd52f2..eea3d88e66 100644 --- a/tests/integration/model/test_models.py +++ b/tests/integration/model/test_models.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path +from unittest.mock import MagicMock import pytest @@ -179,7 +180,7 @@ def _get_objects( # select task type if model_name in {"rkde", "ai_vad"}: task_type = TaskType.DETECTION - elif model_name in {"ganomaly", "dfkde"}: + elif model_name in {"ganomaly", "dfkde", "vlm_ad"}: task_type = TaskType.CLASSIFICATION else: task_type = TaskType.SEGMENTATION @@ -209,6 +210,11 @@ def _get_objects( ) model = get_model(model_name, **extra_args) + + if model_name == "vlm_ad": + model.vlm_backend = MagicMock() + model.vlm_backend.predict.return_value = "YES: Because reasons..." + engine = Engine( logger=False, default_root_dir=project_path,
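Example usage of the new `VlmAd` model, as a minimal sketch: it assumes the `vlm` extra is installed, an Ollama server is running locally with the `llava` model pulled, and the MVTec `bottle` category sits under the default dataset root. The class and argument names come from the code added above; everything else is illustrative.

from anomalib import TaskType
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import VlmAd

# Zero-shot anomaly classification with the default Ollama/LLaVA backend.
datamodule = MVTec(category="bottle")
model = VlmAd(model="llava", k_shot=0)

# VlmAd needs no training; the classification task skips pixel-level metrics,
# and the engine picks ExplanationVisualizer, which renders the VLM's textual
# explanation under each saved image.
engine = Engine(task=TaskType.CLASSIFICATION)
engine.test(model=model, datamodule=datamodule)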