forked from tengyu1998/SCI
-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathtest.py
62 lines (48 loc) · 1.93 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import sys
import numpy as np
import torch
import argparse
import torch.utils
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.autograd import Variable
from model import Finetunemodel
from multi_read_data import MemoryFriendlyLoader
# Command-line configuration for running the SCI model over a test image folder.
parser = argparse.ArgumentParser("SCI")
parser.add_argument('--data_path', type=str, default='./data/medium',
                    help='location of the data corpus')
parser.add_argument('--save_path', type=str, default='./results/medium',
                    help='directory where the enhanced PNG images are written')
parser.add_argument('--model', type=str, default='./weights/medium.pt',
                    help='path to the model weights file loaded by Finetunemodel')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--seed', type=int, default=2, help='random seed')
args = parser.parse_args()

save_path = args.save_path
# Create the output directory up front so every save in main() can write into it.
os.makedirs(save_path, exist_ok=True)

TestDataset = MemoryFriendlyLoader(img_dir=args.data_path, task='test')
# batch_size must stay 1: main() only reads the first name of each batch and
# save_images() only writes tensor[0], so larger batches would drop images.
test_queue = torch.utils.data.DataLoader(
    TestDataset, batch_size=1,
    pin_memory=True, num_workers=0)
def save_images(tensor, path):
    """Write the first image of a batch to ``path`` as a PNG file.

    Args:
        tensor: batched image tensor; ``tensor[0]`` is treated as (C, H, W).
            Values are assumed to lie in [0, 1] (they are scaled by 255 and
            clipped) — confirm against the model's output range.
        path: destination file path; the image is always encoded as PNG.
    """
    # detach() lets this work on tensors that still track gradients
    # (numpy() refuses those), e.g. when called outside torch.no_grad().
    image_numpy = tensor[0].detach().cpu().float().numpy()
    # CHW -> HWC, the layout PIL expects.
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    if image_numpy.shape[2] == 1:
        # Pillow's fromarray rejects (H, W, 1) uint8 arrays; use (H, W) for grayscale.
        image_numpy = image_numpy[:, :, 0]
    # Scale to [0, 255] and clip; astype truncates (no rounding), as before.
    im = Image.fromarray(np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8'))
    im.save(path, 'png')
def _image_stem(raw_name):
    """Return ``raw_name`` without its directory part and final extension.

    Both '/' and '\\' are treated as separators, so Windows- and POSIX-style
    paths give the same result, e.g. './data/medium/a.b.png' -> 'a.b'.
    """
    base = os.path.basename(raw_name.replace('\\', '/'))
    return os.path.splitext(base)[0]


def main():
    """Run the model over every test image and save its output as PNG.

    Reads the module-level ``args``, ``test_queue`` and ``save_path``. Prints a
    message and exits with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)
    # Apply the --gpu and --seed options so they are not silently ignored.
    torch.cuda.set_device(args.gpu)
    torch.manual_seed(args.seed)

    model = Finetunemodel(args.model)
    model = model.cuda()
    model.eval()

    with torch.no_grad():
        for images, names in test_queue:
            # no_grad() above already disables autograd; Variable(volatile=True)
            # is obsolete, so the batch is just moved to the GPU.
            images = images.cuda()
            image_name = _image_stem(names[0])
            # The model returns a pair; only the second element is saved.
            # (Presumably illumination and enhanced image — confirm in model.py.)
            _, r = model(images)
            u_name = '%s.png' % (image_name)
            print('processing {}'.format(u_name))
            save_images(r, os.path.join(save_path, u_name))


if __name__ == '__main__':
    main()