-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_inference.py
executable file
·101 lines (82 loc) · 4.22 KB
/
run_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse
from path import Path
import torch
import torch.backends.cudnn as cudnn
import models
from tqdm import tqdm
import torchvision.transforms as transforms
import flow_transforms
from scipy.ndimage import imread
from scipy.misc import imsave
import numpy as np
from main import flow2rgb
import sys
partial
# Sorted list of every all-lowercase, non-dunder attribute name exposed by the
# `models` module (presumably the available architecture factories — confirm
# against the `models` package).
model_names = sorted(
    attr
    for attr in vars(models)
    if attr.islower() and not attr.startswith("__")
)
# Command-line interface of the script: two positional arguments (image folder
# and checkpoint path) followed by optional output / visualisation settings.
# Defaults are shown in --help thanks to ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(description='PyTorch FlowNet inference on a folder of img pairs',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('data', metavar='DIR',
                    help='path to images folder, image names must match \'[name]0.[ext]\' and \'[name]1.[ext]\'')
parser.add_argument('pretrained', metavar='PTH', help='path to pre-trained model')
parser.add_argument('--output', metavar='DIR', default=None,
                    help='path to output folder. If not set, will be created in data folder')
parser.add_argument('--div-flow', default=20, type=float,
                    help='value by which flow will be divided. overwritten if stored in pretrained file')
parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob")
parser.add_argument('--max_flow', default=None, type=float,
                    help='max flow value. Flow map color is saturated above this value. If not set, will use flow map\'s max value')
# The two help fragments are concatenated by the parser; the second one starts
# with a space so the rendered help does not read "input,which".
parser.add_argument('--upsampling', '-u', choices=['nearest', 'bilinear'], default=None,
                    help='if not set, will output FlowNet raw input,'
                         ' which is 4 times downsampled. If set, will output full resolution flow map, with selected upsampling')
parser.add_argument('--bidirectional', action='store_true',
                    help='if set, will output invert flow (from 1 to 0) along with regular flow')
@torch.no_grad()
def main():
    """Run a pre-trained FlowNet on every image pair of a folder and save flow maps as PNGs.

    Parses the command line with the module-level ``parser``, collects the
    pairs ``[name]0.[ext]`` / ``[name]1.[ext]`` found in the data folder,
    loads the checkpoint given by ``pretrained`` and, for each pair, writes
    ``[name]flow.png`` (plus ``[name]inv_flow.png`` when ``--bidirectional``
    is set) into ``<output or data folder>/flow``.

    Side effects: sets the module globals ``args`` and ``save_path``, creates
    the output folder, and enables ``cudnn.benchmark``.

    Runs on CUDA when available and falls back to CPU otherwise.
    """
    global args, save_path
    args = parser.parse_args()

    data_dir = Path(args.data)
    print("=> fetching img pairs in '{}'".format(args.data))
    if args.output is not None:
        save_path = Path(args.output)/'flow'
    else:
        save_path = data_dir/'flow'
    print('=> will save everything to {}'.format(save_path))
    save_path.makedirs_p()

    # Data loading code: scale pixel values by 1/255, then subtract a fixed
    # per-channel mean.
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1])
    ])

    # A pair is kept only when the matching '...1.<ext>' file exists next to
    # the '...0.<ext>' file.
    img_pairs = []
    for ext in args.img_exts:
        for img0_file in data_dir.files('*0.{}'.format(ext)):
            img1_candidate = img0_file.parent / (img0_file.namebase[:-1] + '1.{}'.format(ext))
            if img1_candidate.isfile():
                img_pairs.append([img0_file, img1_candidate])

    print('{} samples found'.format(len(img_pairs)))

    # Use the GPU when there is one; otherwise run on CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create model
    # SECURITY NOTE: torch.load unpickles the file, which can execute
    # arbitrary code — only pass checkpoints from a trusted source.
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    network_data = torch.load(args.pretrained, map_location=device)
    print("=> using pre-trained model '{}'".format(network_data['arch']))
    model = models.__dict__[network_data['arch']](network_data).to(device)
    model.eval()
    cudnn.benchmark = True

    # A div_flow value stored in the checkpoint takes precedence over the CLI.
    if 'div_flow' in network_data.keys():
        args.div_flow = network_data['div_flow']

    # NOTE(review): imread / imsave are imported at file top from SciPy
    # locations that recent SciPy releases no longer ship — confirm the
    # pinned SciPy version or switch image I/O library at file level.
    for (img1_file, img2_file) in tqdm(img_pairs):

        img1 = input_transform(imread(img1_file))
        img2 = input_transform(imread(img2_file))
        # Concatenate both images along dim 0 and add a batch dimension
        # (presumably ArrayToTensor yields C x H x W — TODO confirm).
        input_var = torch.cat([img1, img2]).unsqueeze(0)

        if args.bidirectional:
            # feed inverted pair along with normal pair
            inverted_input_var = torch.cat([img2, img1]).unsqueeze(0)
            input_var = torch.cat([input_var, inverted_input_var])

        input_var = input_var.to(device)

        # compute output
        output = model(input_var)
        if args.upsampling is not None:
            # `interpolate` supersedes the deprecated `upsample`.
            # align_corners may only be given for interpolating modes; False
            # is what the deprecated call defaulted to for 'bilinear'.
            align_corners = False if args.upsampling == 'bilinear' else None
            output = torch.nn.functional.interpolate(
                output, size=img1.size()[-2:], mode=args.upsampling,
                align_corners=align_corners)

        # zip stops at the batch size: only 'flow' for a single pair, and
        # 'flow' + 'inv_flow' when the inverted pair was appended.
        for suffix, flow_output in zip(['flow', 'inv_flow'], output.detach().cpu()):
            rgb_flow = flow2rgb(args.div_flow * flow_output, max_value=args.max_flow)
            to_save = (rgb_flow * 255).astype(np.uint8).transpose(1, 2, 0)
            imsave(save_path/'{}{}.png'.format(img1_file.namebase[:-1], suffix), to_save)
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()