forked from MarcoForte/FBA_Matting
-
Notifications
You must be signed in to change notification settings - Fork 2
/
demo.py
112 lines (82 loc) · 4.05 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Our libs
from networks.transforms import trimap_transform, groupnorm_normalise_image
from networks.models import build_model
from dataloader import PredDataset
# System libs
import os
import argparse
# External libs
import cv2
import numpy as np
import torch
# Device selection for all tensors created by np_to_torch.
# NOTE(review): hard-coded to GPU index 2 — the commented-out summary code at the
# bottom of this file says it "must be run on device:0", so confirm the intended
# GPU index before deploying on another machine.
device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
def np_to_torch(x):
    """Convert an HxWxC numpy array to a 1xCxHxW float tensor on the module device."""
    tensor = torch.from_numpy(x).float()
    # Channels-last -> channels-first, then add the batch dimension.
    return tensor.permute(2, 0, 1).unsqueeze(0).to(device)
def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
    """Resize ``x`` by ``scale``, rounding each dimension UP to a multiple of 8.

    The network requires spatial dimensions divisible by 8; ``scale_type`` is a
    cv2 interpolation flag (e.g. ``cv2.INTER_LANCZOS4``).
    """
    height, width = x.shape[:2]
    new_h = int(np.ceil(scale * height / 8) * 8)
    new_w = int(np.ceil(scale * width / 8) * 8)
    # cv2.resize takes (width, height), the reverse of numpy's shape order.
    return cv2.resize(x, (new_w, new_h), interpolation=scale_type)
def predict_fba_folder(model, args):
    """Run FBA matting over every image/trimap pair in ``args.image_dir``.

    Writes ``<name>_fg.png``, ``<name>_bg.png`` and ``<name>_alpha.png`` into the
    ``fg``/``bg``/``alpha`` subfolders of ``args.output_dir``.

    Parameters:
    model -- the FBA network; called through pred().
    args  -- parsed CLI args providing image_dir, trimap_dir and output_dir.
    """
    save_dir = args.output_dir
    # BUG FIX: cv2.imwrite silently returns False when the target directory is
    # missing, so ensure the three output subfolders exist up front.
    for sub in ('fg', 'bg', 'alpha'):
        os.makedirs(os.path.join(save_dir, sub), exist_ok=True)
    dataset_test = PredDataset(args.image_dir, args.trimap_dir)
    for item_dict in dataset_test:
        image_np = item_dict['image']
        trimap_np = item_dict['trimap']
        fg, bg, alpha = pred(image_np, trimap_np, model)
        # splitext is robust to extensions of any length (the old [:-4] slicing
        # assumed a 3-character extension such as .png/.jpg).
        stem = os.path.splitext(item_dict['name'])[0]
        # Predictions are RGB in [0, 1]; flip to BGR and rescale for cv2.imwrite.
        cv2.imwrite(os.path.join(save_dir, 'fg', stem + '_fg.png'), fg[:, :, ::-1] * 255)
        cv2.imwrite(os.path.join(save_dir, 'bg', stem + '_bg.png'), bg[:, :, ::-1] * 255)
        cv2.imwrite(os.path.join(save_dir, 'alpha', stem + '_alpha.png'), alpha * 255)
def pred(image_np: np.ndarray, trimap_np: np.ndarray, model) -> tuple:
    ''' Predict alpha, foreground and background.
    Parameters:
    image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
    Returns:
    fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
    '''
    h, w = trimap_np.shape[:2]
    # The network needs spatial dimensions divisible by 8; scale=1.0 only pads
    # up to the next multiple of 8, it does not rescale content.
    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)
    with torch.no_grad():
        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)
        # Extra encodings the FBA model expects alongside the raw inputs.
        trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
        image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw')
        output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)
        # BUG FIX: the interpolation flag was previously passed positionally,
        # which lands in cv2.resize's ``dst`` parameter slot, so Lanczos was
        # silently ignored (default bilinear was used). Pass it by keyword.
        output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h),
                            interpolation=cv2.INTER_LANCZOS4)
    # Output channel layout: [alpha | fg rgb | bg rgb].
    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]
    # Enforce the trimap's hard constraints on the predicted alpha
    # (channel 0 = definite background, channel 1 = definite foreground).
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    # Where alpha is exactly 0 or 1 the fg/bg equals the input image.
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]
    return fg, bg, alpha
if __name__ == '__main__':
    # CLI entry point: parse arguments and build the FBA model.
    parser = argparse.ArgumentParser()
    # Model related arguments
    parser.add_argument('--encoder', default='resnet50_GN_WS', help="encoder model")
    parser.add_argument('--decoder', default='fba_decoder', help="Decoder model")
    parser.add_argument('--weights', default='FBA.pth')
    parser.add_argument('--image_dir', default='./examples/images', help="")
    parser.add_argument('--trimap_dir', default='./examples/trimaps', help="")
    # Alternate dataset paths (alphamatting.com benchmark):
    # parser.add_argument('--image_dir', default='./Dataset/alphamatting.com/input_training_lowres', help="")
    # parser.add_argument('--trimap_dir', default='./Dataset/alphamatting.com/trimap_training_lowres/Trimap1', help="")
    parser.add_argument('--output_dir', default='./examples/predictions/', help="")
    args = parser.parse_args()
    print("Build Model")
    model = build_model(args)
    # NOTE(review): inference is currently disabled — the script only builds the
    # model. Re-enable model.eval() and predict_fba_folder() below to actually
    # produce predictions; confirm this toggle state is intentional.
    # print(model)
    # model.eval()
    # predict_fba_folder(model, args)
    # # Test code; must be run on device:0
    # from torchsummary import summary
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # model.to(device)
    # summary(model, [(3, 1920, 1080), (2, 1920, 1080), (3, 1920, 1080), (6, 1920, 1080)])