-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
56 lines (48 loc) · 2.89 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import argparse
import os
import shutil
import torch
from PIL import Image, ImageDraw
def _load_rgb_224(path):
    """Open the image at *path*, convert it to RGB and resize it to 224x224.

    The file is opened in a ``with`` block so its handle is released as soon as
    the pixels have been converted; the returned image does not reference the file.
    """
    with Image.open(path) as img:
        return img.convert('RGB').resize((224, 224), resample=Image.BILINEAR)


def main(argv=None):
    """Retrieve the closest gallery images for one query image and save them to disk.

    Loads ``results/<data_base>`` (mapped to CPU), looks up the query's stored
    feature, ranks gallery features by ``torch.cdist`` distance and writes the
    query plus up to ``retrieval_num`` nearest matches into
    ``results/<data_base stem>/<query stem>/``.

    Each match gets a green border when its label equals the query's label and a
    red one otherwise. Its file name encodes the rank (1-based) and the distance
    with four decimals.

    Args:
        argv: Optional argument list passed to ``argparse``; ``None`` reads ``sys.argv``.

    Raises:
        FileNotFoundError: If the query image is not listed in the database's
            ``test_images``.
    """
    parser = argparse.ArgumentParser(description='Test CGD')
    parser.add_argument('--query_img_name', default='/home/data/car/uncropped/008055.jpg', type=str,
                        help='query image name')
    parser.add_argument('--data_base', default='car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth',
                        type=str, help='queried database')
    parser.add_argument('--retrieval_num', default=8, type=int, help='retrieval number')
    opt = parser.parse_args(argv)
    query_img_name, data_base_name, retrieval_num = opt.query_img_name, opt.data_base, opt.retrieval_num

    # The dataset name is the first '_'-separated token of the database file name.
    data_name = data_base_name.split('_')[0]
    data_base = torch.load(os.path.join('results', data_base_name), map_location=torch.device('cpu'))

    if query_img_name not in data_base['test_images']:
        raise FileNotFoundError('{} not found'.format(query_img_name))
    query_index = data_base['test_images'].index(query_img_name)
    query_image = _load_rgb_224(query_img_name)
    query_label = torch.tensor(data_base['test_labels'][query_index])
    query_feature = data_base['test_features'][query_index]

    # 'isc' stores a separate gallery split; every other dataset queries the test split against itself.
    gallery_split = 'test' if data_name != 'isc' else 'gallery'
    gallery_images = data_base['{}_images'.format(gallery_split)]
    gallery_labels = torch.tensor(data_base['{}_labels'.format(gallery_split)])
    gallery_features = data_base['{}_features'.format(gallery_split)]

    # Distances of the single query to every gallery item. Taking row 0 (rather than
    # squeeze()) keeps a 1-D tensor even when the gallery holds a single item.
    dist_matrix = torch.cdist(query_feature.unsqueeze(0), gallery_features)[0]
    num_candidates = dist_matrix.size(0)
    if data_name != 'isc':
        # The query itself is part of the gallery here; never retrieve it.
        dist_matrix[query_index] = float('inf')
        num_candidates -= 1

    # topk raises when k exceeds the gallery size, and when k equals it the excluded
    # query (distance inf) would be returned; clamp to the real candidates instead.
    k = min(retrieval_num, num_candidates)
    if k < retrieval_num:
        print('Only {} gallery candidates available; retrieving {} instead of {}'.format(
            num_candidates, k, retrieval_num))
    idx = dist_matrix.topk(k=k, dim=-1, largest=False)[1]

    result_dir = os.path.splitext(os.path.basename(data_base_name))[0]
    query_stem = os.path.splitext(os.path.basename(query_img_name))[0]
    result_path = os.path.join('results', result_dir, query_stem)
    print("Save test result in result path: ", result_path)
    # Start from an empty directory so results of a previous run do not linger.
    if os.path.exists(result_path):
        shutil.rmtree(result_path)
    os.makedirs(result_path, exist_ok=True)
    query_image.save(os.path.join(result_path, 'query_img.jpg'))

    for rank, gallery_index in enumerate(idx.tolist(), start=1):
        retrieval_image = _load_rgb_224(gallery_images[gallery_index])
        is_match = (gallery_labels[gallery_index] == query_label).item()
        retrieval_dist = dist_matrix[gallery_index].item()
        # Green border: same label as the query; red border: different label.
        ImageDraw.Draw(retrieval_image).rectangle(
            (0, 0, 223, 223), outline='green' if is_match else 'red', width=8)
        retrieval_image.save(os.path.join(
            result_path, 'retrieval_img_{}_{:.4f}.jpg'.format(rank, retrieval_dist)))


if __name__ == '__main__':
    main()