Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🧪 Add tests for tools #1069

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pre-commit
pytest
pytest-cov
pytest-order
pytest-sugar
pytest-xdist
coverage[toml]
Expand Down
8 changes: 8 additions & 0 deletions tests/pre_merge/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Test entry point scripts.


Note: This might be removed when migration to CLI is complete.
"""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
56 changes: 56 additions & 0 deletions tests/pre_merge/tools/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Fixtures for the tools tests."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from tempfile import TemporaryDirectory
from typing import Generator

import albumentations as A
import cv2
import numpy as np
import pytest
from albumentations.pytorch import ToTensorV2

from anomalib.config import get_configurable_parameters


@pytest.fixture(scope="package")
def project_path() -> Generator[str, None, None]:
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
with TemporaryDirectory() as project_dir:
yield project_dir


@pytest.fixture(scope="package")
def get_config(project_path):
def get_config(
model_name: str | None = None,
config_path: str | None = None,
weight_file: str | None = None,
):
"""Gets config for testing."""
config = get_configurable_parameters(model_name, config_path, weight_file)
config.dataset.image_size = (100, 100)
config.model.input_size = (100, 100)
config.project.path = project_path
config.trainer.max_epochs = 1
config.trainer.check_val_every_n_epoch = 1
config.trainer.limit_train_batches = 1
config.trainer.limit_predict_batches = 1
return config

yield get_config


@pytest.fixture(scope="package")
def get_dummy_inference_image(project_path) -> Generator[str, None, None]:
image = np.zeros((100, 100, 3), dtype=np.uint8)
cv2.imwrite(project_path + "/dummy_image.png", image)
yield project_path + "/dummy_image.png"


@pytest.fixture(scope="package")
def transforms_config() -> dict:
"""Note: this is computed using trainer.datamodule.test_data.transform.to_dict()"""
return A.Compose([A.ToFloat(max_value=255), ToTensorV2()]).to_dict()
93 changes: 93 additions & 0 deletions tests/pre_merge/tools/test_gradio_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Test Gradio inference entrypoint script."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from __future__ import annotations

import sys
from importlib.util import find_spec

import pytest

from anomalib.data import TaskType
from anomalib.deploy import ExportMode, OpenVINOInferencer, TorchInferencer, export
from anomalib.models import get_model

sys.path.append("tools/inference")


@pytest.mark.order(5)
class TestGradioInferenceEntrypoint:
"""This tests whether the entrypoints run without errors without quantitative measure of the outputs.

Note: This does not launch the gradio server. It only checks if the right inferencer is called.
"""

@pytest.fixture
def get_functions(self):
"""Get functions from Gradio_inference.py"""
if find_spec("gradio_inference") is not None:
from tools.inference.gradio_inference import get_inferencer, get_parser
else:
raise Exception("Unable to import gradio_inference.py for testing")
return get_parser, get_inferencer

def test_torch_inference(self, get_functions, project_path, get_config, transforms_config):
"""Test gradio_inference.py"""
parser, inferencer = get_functions
model = get_model(get_config("padim"))

# export torch model
export(
task=TaskType.SEGMENTATION,
transform=transforms_config,
input_size=(100, 100),
model=model,
export_mode=ExportMode.TORCH,
export_root=project_path,
)

arguments = parser().parse_args(
[
"--weights",
project_path + "/weights/torch/model.pt",
]
)
assert isinstance(inferencer(arguments.weights, arguments.metadata), TorchInferencer)

def test_openvino_inference(self, get_functions, project_path, get_config, transforms_config):
"""Test gradio_inference.py"""
parser, inferencer = get_functions
model = get_model(get_config("padim"))

# export OpenVINO model
export(
task=TaskType.SEGMENTATION,
transform=transforms_config,
input_size=(100, 100),
model=model,
export_mode=ExportMode.OPENVINO,
export_root=project_path,
)

arguments = parser().parse_args(
[
"--weights",
project_path + "/weights/openvino/model.bin",
"--metadata",
project_path + "/weights/openvino/metadata.json",
]
)
assert isinstance(inferencer(arguments.weights, arguments.metadata), OpenVINOInferencer)

# test error is raised when metadata is not provided to openvino model
with pytest.raises(ValueError):
arguments = parser().parse_args(
[
"--weights",
project_path + "/weights/openvino/model.bin",
]
)
inferencer(arguments.weights, arguments.metadata)
47 changes: 47 additions & 0 deletions tests/pre_merge/tools/test_lightning_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test lightning inference entrypoint script."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from __future__ import annotations

import sys
from importlib.util import find_spec

import pytest

sys.path.append("tools/inference")
from unittest.mock import patch


@pytest.mark.order(3)
class TestLightningInferenceEntrypoint:
"""This tests whether the entrypoints run without errors without quantitative measure of the outputs."""

@pytest.fixture
def get_functions(self):
"""Get functions from lightning_inference.py"""
if find_spec("lightning_inference") is not None:
from tools.inference.lightning_inference import get_parser, infer
else:
raise Exception("Unable to import lightning_inference.py for testing")
return get_parser, infer

def test_lightning_inference(self, get_functions, get_config, project_path, get_dummy_inference_image):
"""Test lightning_inferenc.py"""
get_parser, infer = get_functions
with patch("tools.inference.lightning_inference.get_configurable_parameters", side_effect=get_config):
arguments = get_parser().parse_args(
[
"--config",
"src/anomalib/models/padim/config.yaml",
"--weights",
project_path + "/weights/lightning/model.ckpt",
"--input",
get_dummy_inference_image,
"--output",
project_path + "/output",
]
)
infer(arguments)
66 changes: 66 additions & 0 deletions tests/pre_merge/tools/test_openvino_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Test OpenVINO inference entrypoint script."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from __future__ import annotations

import sys
from importlib.util import find_spec

import pytest

from anomalib.data import TaskType
from anomalib.deploy import ExportMode, export
from anomalib.models import get_model

sys.path.append("tools/inference")


@pytest.mark.order(4)
class TestOpenVINOInferenceEntrypoint:
"""This tests whether the entrypoints run without errors without quantitative measure of the outputs."""

@pytest.fixture
def get_functions(self):
"""Get functions from openvino_inference.py"""
if find_spec("openvino_inference") is not None:
from tools.inference.openvino_inference import get_parser, infer
else:
raise Exception("Unable to import openvino_inference.py for testing")
return get_parser, infer

def test_openvino_inference(
self, get_functions, get_config, project_path, get_dummy_inference_image, transforms_config
):
"""Test openvino_inference.py"""
get_parser, infer = get_functions

model = get_model(get_config("padim"))

# export OpenVINO model
export(
task=TaskType.SEGMENTATION,
transform=transforms_config,
input_size=(100, 100),
model=model,
export_mode=ExportMode.OPENVINO,
export_root=project_path,
)

arguments = get_parser().parse_args(
[
"--config",
"src/anomalib/models/padim/config.yaml",
"--weights",
project_path + "/weights/openvino/model.bin",
"--metadata",
project_path + "/weights/openvino/metadata.json",
"--input",
get_dummy_inference_image,
"--output",
project_path + "/output",
]
)
infer(arguments)
54 changes: 54 additions & 0 deletions tests/pre_merge/tools/test_test_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test test.py entrypoint script."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import sys
from importlib.util import find_spec

import pytest

sys.path.append("tools")
from unittest.mock import patch


@pytest.mark.order(2)
class TestTestEntrypoint:
"""This tests whether the entrypoints run without errors without quantitative measure of the outputs."""

@pytest.fixture
def get_functions(self):
"""Get functions from test.py"""
if find_spec("test") is not None:
from tools.test import get_parser, test
else:
raise Exception("Unable to import test.py for testing")
return get_parser, test

def test_test(self, get_functions, get_config, project_path):
"""Test test.py"""
get_parser, test = get_functions
with patch("tools.test.get_configurable_parameters", side_effect=get_config):
# Test when model key is passed
arguments = get_parser().parse_args(
["--model", "padim", "--weight_file", project_path + "/weights/lightning/model.ckpt"]
)
test(arguments)

# Test when weight file is incorrect
arguments = get_parser().parse_args(["--model", "padim"])
with pytest.raises(FileNotFoundError):
test(arguments)

# Test when config key is passed
arguments = get_parser().parse_args(
[
"--config",
"src/anomalib/models/padim/config.yaml",
"--weight_file",
project_path + "/weights/lightning/model.ckpt",
]
)
test(arguments)
58 changes: 58 additions & 0 deletions tests/pre_merge/tools/test_torch_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Test torch inference entrypoint script."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from __future__ import annotations

import sys
from importlib.util import find_spec

import pytest

from anomalib.data import TaskType
from anomalib.deploy import ExportMode, export
from anomalib.models import get_model

sys.path.append("tools/inference")


@pytest.mark.order(4)
class TestTorchInferenceEntrypoint:
"""This tests whether the entrypoints run without errors without quantitative measure of the outputs."""

@pytest.fixture
def get_functions(self):
"""Get functions from torch_inference.py"""
if find_spec("torch_inference") is not None:
from tools.inference.torch_inference import get_parser, infer
else:
raise Exception("Unable to import torch_inference.py for testing")
return get_parser, infer

def test_torch_inference(
self, get_functions, get_config, project_path, get_dummy_inference_image, transforms_config
):
"""Test torch_inference.py"""
get_parser, infer = get_functions
model = get_model(get_config("padim"))
export(
task=TaskType.SEGMENTATION,
transform=transforms_config,
input_size=(100, 100),
model=model,
export_mode=ExportMode.TORCH,
export_root=project_path,
)
arguments = get_parser().parse_args(
[
"--weights",
project_path + "/weights/torch/model.pt",
"--input",
get_dummy_inference_image,
"--output",
project_path + "/output",
]
)
infer(arguments)
Loading