-
Notifications
You must be signed in to change notification settings - Fork 47
/
inference.py
105 lines (83 loc) · 3.93 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
import argparse
import cv2
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from src.model import SODModel
from src.dataloader import InfDataloader, SODLoader
def parse_arguments():
parser = argparse.ArgumentParser(description='Parameters to train your model.')
parser.add_argument('--imgs_folder', default='./data/DUTS/DUTS-TE/DUTS-TE-Image', help='Path to folder containing images', type=str)
parser.add_argument('--model_path', default='/home/tarasha/Projects/sairajk/saliency/SOD_2/models/0.7_wbce_w0-1_w1-1.12/best_epoch-138_acc-0.9107_loss-0.1300.pt', help='Path to model', type=str)
parser.add_argument('--use_gpu', default=True, help='Whether to use GPU or not', type=bool)
parser.add_argument('--img_size', default=256, help='Image size to be used', type=int)
parser.add_argument('--bs', default=24, help='Batch Size for testing', type=int)
return parser.parse_args()
def run_inference(args):
    """Run saliency prediction on a folder of images and display results.

    Shows, per image: the input, the raw predicted saliency mask, and the
    mask rounded to {0, 1}. Press any key to advance, 'q' to quit.

    Args:
        args: Namespace from ``parse_arguments`` — uses ``use_gpu``,
            ``model_path``, ``imgs_folder`` and ``img_size``.
    """
    # Select device: CUDA only when both requested and actually available.
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model weights; map_location keeps CPU-only hosts working even
    # for checkpoints saved on a GPU machine.
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
    # Since the images would be displayed to the user, the batch_size is set to 1
    # Code at later point is also written assuming batch_size = 1, so do not change
    inf_dataloader = DataLoader(inf_data, batch_size=1, shuffle=True, num_workers=2)

    print("Press 'q' to quit.")
    with torch.no_grad():
        for batch_idx, (img_np, img_tor) in enumerate(inf_dataloader, start=1):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # Assuming batch_size = 1: drop the batch axis for display.
            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            # Dataloader yields RGB; OpenCV display expects BGR.
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))
            pred_masks_round = np.squeeze(pred_masks.round().cpu().numpy(), axis=(0, 1))

            print('Image :', batch_idx)
            cv2.imshow('Input Image', img_np)
            cv2.imshow('Generated Saliency Mask', pred_masks_raw)
            cv2.imshow('Rounded-off Saliency Mask', pred_masks_round)

            key = cv2.waitKey(0)
            if key == ord('q'):
                break

    # Bug fix: release the HighGUI windows once the loop finishes or the
    # user quits; the original leaked them until process exit.
    cv2.destroyAllWindows()
def calculate_mae(args):
    """Compute and print the mean absolute error (MAE) over the test set.

    Per-image MAE is averaged over the mask's channel/height/width axes,
    then the list of per-image values is averaged for the final score.

    Args:
        args: Namespace from ``parse_arguments`` — uses ``use_gpu``,
            ``model_path``, ``img_size`` and ``bs``.
    """
    # Select device: CUDA only when both requested and actually available.
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model weights; map_location keeps CPU-only hosts working.
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    test_data = SODLoader(mode='test', augment_data=False, target_size=args.img_size)
    test_dataloader = DataLoader(test_data, batch_size=args.bs, shuffle=False, num_workers=2)

    # List to save mean absolute error of each image
    mae_list = []
    with torch.no_grad():
        for batch_idx, (inp_imgs, gt_masks) in enumerate(tqdm.tqdm(test_dataloader), start=1):
            inp_imgs = inp_imgs.to(device)
            gt_masks = gt_masks.to(device)
            pred_masks, _ = model(inp_imgs)

            # Reduce over (C, H, W) to get one MAE value per image in the batch.
            mae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
            mae_list.extend(mae)

    # Robustness: np.mean([]) is NaN and emits a RuntimeWarning, so guard
    # against an empty test set.
    if mae_list:
        print('MAE for the test set is :', np.mean(mae_list))
    else:
        print('MAE for the test set is : undefined (no test samples found)')
if __name__ == '__main__':
    # Parse CLI arguments once, then run the quantitative evaluation
    # (MAE over the test set) followed by the interactive visual inference.
    rt_args = parse_arguments()
    calculate_mae(rt_args)
    run_inference(rt_args)