Merge remote-tracking branch 'origin/Add-unit-testing-to-KIN' into main
Kaminyou committed Jul 7, 2022
2 parents 0991c8c + 6344364 commit d002e57
Showing 4 changed files with 148 additions and 84 deletions.
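In short: the inline USAGE docstring and __main__ smoke test are dropped from models/kin.py, a pytest suite for KernelizedInstanceNorm is added under models/tests/, and pytest joins requirements.txt.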
88 changes: 5 additions & 83 deletions models/kin.py
@@ -1,4 +1,3 @@
import numpy as np
import torch
import torch.nn as nn
from utils.util import get_kernel
@@ -26,6 +25,7 @@ def __init__(self, out_channels=None, affine=True, device="cuda"):
).to(device)

def init_collection(self, y_anchor_num, x_anchor_num):
# TODO: y_anchor_num => grid_height, x_anchor_num => grid_width
self.y_anchor_num = y_anchor_num
self.x_anchor_num = x_anchor_num
self.mean_table = torch.zeros(
@@ -40,12 +40,15 @@ def init_collection(self, y_anchor_num, x_anchor_num):
)

def init_kernel(self, kernel_padding, kernel_mode):
# TODO: 1. Consider to use strategy pattern
# TODO: 2. padding => kernel_size, and raise an error for even number
kernel = get_kernel(padding=kernel_padding, mode=kernel_mode)
self.kernel = kernel.to(self.device)

def pad_table(self, padding):
# modify
# padded table shape inconsistency
# TODO: Don't permute the dimensions
pad_func = nn.ReplicationPad2d((padding, padding, padding, padding))
self.padded_mean_table = pad_func(
self.mean_table.permute(2, 0, 1).unsqueeze(0)
Expand All @@ -60,6 +63,7 @@ def forward_normal(self, x):
return x

def forward(self, x, y_anchor=None, x_anchor=None, padding=1):
# TODO: Do not rely on self.training
if self.training or self.normal_instance_normalization:
return self.forward_normal(x)

@@ -133,85 +137,3 @@ def use_kernelized_instance_norm(model, padding=1):
layer.pad_table(padding=padding)
layer.collection_mode = False
layer.normal_instance_normalization = False


"""
USAGE
suppose a dataset whose dataloader returns
(x, y_anchor, x_anchor) each time
kin = KernelizedInstanceNorm()
[TRAIN] anchors are not used during training
kin.train()
for (x, _, _) in dataloader:
kin(x)
[COLLECT] anchors are required and any batch size is allowed
kin.eval()
init_kernelized_instance_norm(
kin, y_anchor_num=$y_anchor_num,
x_anchor_num=$x_anchor_num,
kernel_padding=$kernel_padding,
kernel_mode=$kernel_mode,
)
for (x, y_anchor, x_anchor) in dataloader:
kin(x, y_anchor=y_anchor, x_anchor=x_anchor)
[INFERENCE] anchors are required and batch size is limited to 1 !!
kin.eval()
use_kernelized_instance_norm(kin, kernel_padding=$kernel_padding)
for (x, y_anchor, x_anchor) in dataloader:
kin(x, y_anchor=y_anchor, x_anchor=x_anchor, padding=$padding)
[INFERENCE WITH NORMAL INSTANCE NORMALIZATION] anchors are not required
kin.eval()
not_use_kernelized_instance_norm(kin)
for (x, _, _) in dataloader:
kin(x)
"""

if __name__ == "__main__":
import itertools

from torch.utils.data import DataLoader, Dataset

class TestDataset(Dataset):
def __init__(self, y_anchor_num=10, x_anchor_num=10):
self.y_anchor_num = y_anchor_num
self.x_anchor_num = x_anchor_num
self.anchors = list(
itertools.product(
np.arange(0, y_anchor_num), np.arange(0, x_anchor_num)
)
)

def __len__(self):
return len(self.anchors)

def __getitem__(self, idx):
x = torch.randn(3, 512, 512)
y_anchor, x_anchor = self.anchors[idx]
return (x, y_anchor, x_anchor)

test_dataset = TestDataset()
test_dataloader = DataLoader(test_dataset, batch_size=5)

kin = KernelizedInstanceNorm(out_channels=3, device="cpu")
kin.eval()
init_kernelized_instance_norm(
kin,
y_anchor_num=10,
x_anchor_num=10,
kernel_padding=1,
kernel_mode="constant",
)

for (x, y_anchor, x_anchor) in test_dataloader:
kin(x, y_anchor=y_anchor, x_anchor=x_anchor)

use_kernelized_instance_norm(kin, kernel_padding=1)
test_dataloader = DataLoader(test_dataset, batch_size=1)
for (x, y_anchor, x_anchor) in test_dataloader:
x = kin(x, y_anchor=y_anchor, x_anchor=x_anchor, padding=1)
print(x.shape)
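The eval-time kernelized branch of forward is collapsed in this diff, but the new tests below pin down its behavior: the layer slices a kernel-sized window of cached per-anchor statistics out of the padded tables and normalizes with the kernel-weighted mean and std. A minimal sketch of that step, assuming (1, C, H_pad, W_pad) padded tables and a (k, k) kernel as prepared by pad_table and init_kernel; the helper name and exact indexing are illustrative, not the repository's code:

import torch

def kernelized_normalize(x, padded_mean, padded_std, kernel, y_anchor, x_anchor):
    # Slice the kernel-sized neighborhood around the anchor from the padded
    # statistic tables (the padding already offsets the anchor indices).
    k = kernel.shape[0]
    mean_patch = padded_mean[:, :, y_anchor:y_anchor + k, x_anchor:x_anchor + k]
    std_patch = padded_std[:, :, y_anchor:y_anchor + k, x_anchor:x_anchor + k]
    # Kernel-weighted aggregation: (k, k) broadcasts over (1, C, k, k).
    mean = (mean_patch * kernel).sum(dim=(2, 3), keepdim=True)
    std = (std_patch * kernel).sum(dim=(2, 3), keepdim=True)
    return (x - mean) / std

With a single collected anchor and the 'constant' kernel, the weighted statistics collapse to mean / 9 and std / 9, which is exactly what test_forward_with_kernelized asserts.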
Empty file added models/tests/__init__.py
141 changes: 141 additions & 0 deletions models/tests/test_kin.py
@@ -0,0 +1,141 @@
import numpy as np
import pytest
import torch

from ..kin import KernelizedInstanceNorm


def normalize(x):
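    # Reference: plain instance normalization, i.e. per-sample, per-channel
    # statistics over the spatial dimensions (unbiased std via torch.std_mean).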
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True)
return (x - mean) / std


def test_forward_normal():
layer = KernelizedInstanceNorm(out_channels=3, device='cpu')
x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
x = torch.FloatTensor(x)

expected = normalize(x)

    check = layer.forward_normal(x)

    assert check.numpy() == pytest.approx(expected.numpy(), abs=1e-6)


def test_init_kernel():
layer = KernelizedInstanceNorm(out_channels=3, device='cpu')
layer.init_kernel(kernel_padding=1, kernel_mode='constant')

expected = np.ones(shape=(3, 3), dtype=np.float32) / 9
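    # kernel_padding=1 implies a (2 * 1 + 1) x (2 * 1 + 1) = 3x3 box filter
    # whose weights sum to 1.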

assert layer.kernel.numpy() == pytest.approx(expected)


def test_init_collection():
layer = KernelizedInstanceNorm(out_channels=3, device='cpu')
layer.init_collection(y_anchor_num=10, x_anchor_num=9)

expected_mean_table = np.zeros(shape=(10, 9, 3))
expected_std_table = np.zeros(shape=(10, 9, 3))

np.testing.assert_array_equal(layer.mean_table.numpy(), expected_mean_table)
np.testing.assert_array_equal(layer.std_table.numpy(), expected_std_table)


def test_pad_table():
layer = KernelizedInstanceNorm(out_channels=1, device='cpu')

table = np.array(
[
[0, 1],
[2, 3],
],
dtype=np.float32
).reshape(2, 2, 1)

expected_table = np.array(
[
[0, 0, 1, 1],
[0, 0, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3],
],
dtype=np.float32
).reshape(4, 4, 1)
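    # Replication padding (padding=1) repeats each border anchor of the
    # 2x2 table once in every direction, giving the 4x4 table above.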

layer.mean_table = torch.FloatTensor(table)
layer.std_table = torch.FloatTensor(table)

layer.pad_table(padding=1)

expected_padded_mean_table = expected_table.transpose(2, 0, 1).reshape(1, 1, 4, 4)
expected_padded_std_table = expected_table.transpose(2, 0, 1).reshape(1, 1, 4, 4)

np.testing.assert_array_equal(layer.padded_mean_table.numpy(), expected_padded_mean_table)
np.testing.assert_array_equal(layer.padded_std_table.numpy(), expected_padded_std_table)


def test_forward_with_normal_instance_normalization():
layer = KernelizedInstanceNorm(out_channels=3, device='cpu')
layer.normal_instance_normalization = True
x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
x = torch.FloatTensor(x)

expected = normalize(x)

    # Go through forward() in eval mode so the normal_instance_normalization
    # flag is actually exercised rather than calling forward_normal directly.
    check = layer.eval().forward(x)

    assert check.numpy() == pytest.approx(expected.numpy(), abs=1e-6)


def test_forward_with_collection_mode():
layer = KernelizedInstanceNorm(out_channels=3, device='cpu').eval()
layer.collection_mode = True
layer.normal_instance_normalization = False

layer.init_collection(y_anchor_num=3, x_anchor_num=3)
layer.init_kernel(kernel_padding=1, kernel_mode='constant')

x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
x = torch.FloatTensor(x)

std, mean = torch.std_mean(x, dim=(2, 3))

expected_mean_table = np.zeros(shape=(3, 3, 3), dtype=np.float32)
expected_std_table = np.zeros(shape=(3, 3, 3), dtype=np.float32)

expected_mean_table[0, 0] = mean
expected_std_table[0, 0] = std
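    # In collection mode, forward records the batch statistics at the given
    # anchor and still returns the plain instance-normalized output.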

check = layer.forward(x, x_anchor=0, y_anchor=0, padding=1)

assert check.detach().numpy() == pytest.approx(normalize(x).numpy(), abs=1e-6)
assert layer.mean_table.numpy() == pytest.approx(expected_mean_table)
assert layer.std_table.numpy() == pytest.approx(expected_std_table)


def test_forward_with_kernelized():
layer = KernelizedInstanceNorm(out_channels=3, device='cpu').eval()
layer.collection_mode = True
layer.normal_instance_normalization = False

layer.init_collection(y_anchor_num=3, x_anchor_num=3)
layer.init_kernel(kernel_padding=1, kernel_mode='constant')

x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
x = torch.FloatTensor(x)

layer.forward(x, x_anchor=1, y_anchor=1, padding=1)

layer.collection_mode = False
layer.pad_table(1)

check = layer.forward(x, x_anchor=1, y_anchor=1, padding=1)
std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True)

mean /= 9
std /= 9
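    # Only anchor (1, 1) holds nonzero statistics, and the constant 3x3 kernel
    # weights each table entry by 1/9, so the kernel-weighted mean and std
    # reduce to mean / 9 and std / 9.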

expected = (x - mean) / std

    assert check.detach().numpy() == pytest.approx(expected.numpy())
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,4 +3,5 @@ numpy==1.19.5
Pillow==8.1.2
PyYAML==5.4.1
torch==1.7.0
torchvision==0.8.0
torchvision==0.8.0
pytest
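With pytest added (version-unpinned), the new suite can be run from the repository root, e.g.:

python -m pytest models/tests -q

assuming the models directory resolves as a package so that the relative import from ..kin in test_kin.py works.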
