-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathtest_view_interp.py
93 lines (77 loc) · 3.49 KB
/
test_view_interp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import print_function, division
import argparse
import logging
import numpy as np
import cv2
import os
from pathlib import Path
from tqdm import tqdm
from lib.human_loader import StereoHumanDataset
from lib.network import RtStereoHumanModel
from config.stereo_human_config import ConfigStereoHuman as config
from lib.utils import get_novel_calib
from lib.GaussianRender import pts2render
import torch
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
class StereoHumanRender:
    """Render interpolated novel views from a pair of source views.

    Wraps an ``RtStereoHumanModel`` (built with ``with_gs_render=True``) and a
    ``StereoHumanDataset``; for every test sample it renders a configurable
    number of novel views and writes each one to ``cfg.test_out_path`` as JPEG.
    """

    def __init__(self, cfg_file, phase):
        """Build the model and dataset and optionally restore a checkpoint.

        Args:
            cfg_file: Configuration object. This class reads ``batch_size``,
                ``dataset``, ``restore_ckpt`` and ``test_out_path`` from it.
            phase: Dataset phase forwarded to ``StereoHumanDataset``.
        """
        self.cfg = cfg_file
        self.bs = self.cfg.batch_size

        self.model = RtStereoHumanModel(self.cfg, with_gs_render=True)
        self.dataset = StereoHumanDataset(self.cfg.dataset, phase=phase)
        self.model.cuda()
        if self.cfg.restore_ckpt:
            self.load_ckpt(self.cfg.restore_ckpt)
        self.model.eval()

    def infer_static(self, view_select, novel_view_nums):
        """Render ``novel_view_nums`` novel views for every test sample.

        The number of samples is the number of entries in the ``img`` folder
        under ``cfg.dataset.test_data_root``. Each rendered view is saved as
        ``<test_out_path>/<name>_novel<NN>.jpg`` where ``NN`` is the zero-padded
        view index.

        Args:
            view_select: Source view ids, passed to the dataset as ``source_id``.
            novel_view_nums: How many novel views to render per sample.
        """
        img_dir = os.path.join(self.cfg.dataset.test_data_root, 'img')
        total_samples = len(os.listdir(img_dir))

        for idx in tqdm(range(total_samples)):
            item = self.dataset.get_test_item(idx, source_id=view_select)
            data = self.fetch_data(item)

            for i in range(novel_view_nums):
                # Ratios are the midpoints of novel_view_nums equal slices of
                # [0, 1]; presumably the interpolation position between the
                # two source cameras -- confirm against get_novel_calib.
                ratio_tmp = (i+0.5)*(1/novel_view_nums)
                data_i = get_novel_calib(data, self.cfg.dataset, ratio=ratio_tmp,
                                         intr_key='intr_ori', extr_key='extr_ori')
                with torch.no_grad():
                    data_i, _, _ = self.model(data_i, is_train=False)
                    data_i = pts2render(data_i, bg_color=self.cfg.dataset.bg_color)

                # Read the prediction from the dict returned by pts2render
                # instead of relying on `data` having been mutated in place.
                render_novel = self.tensor2np(data_i['novel_view']['img_pred'])
                out_path = self.cfg.test_out_path + '/%s_novel%s.jpg' % (data_i['name'], str(i).zfill(2))
                # cv2.imwrite reports failure through its return value only.
                if not cv2.imwrite(out_path, render_novel):
                    logging.warning("Failed to write image to %s", out_path)

    def tensor2np(self, img_tensor):
        """Convert the first image of a batched tensor to a uint8 array.

        The tensor is permuted with ``(0, 2, 3, 1)`` (channels moved last), the
        first batch element is taken, values are scaled by 255 and the channel
        axis is reversed before the cast (cv2 writes images in BGR order, so
        the input is presumably RGB in [0, 1] -- confirm with the renderer).

        Values are clipped to [0, 255] before the cast so that slightly
        out-of-range predictions saturate instead of wrapping around.

        Args:
            img_tensor: 4-D tensor indexed as (batch, channel, row, column).

        Returns:
            ``np.uint8`` array indexed as (row, column, channel).
        """
        img_np = img_tensor.permute(0, 2, 3, 1)[0].detach().cpu().numpy()
        img_np = np.clip(img_np * 255, 0, 255)
        img_np = img_np[:, :, ::-1].astype(np.uint8)
        return img_np

    def fetch_data(self, data):
        """Move both source views to the GPU and add a leading batch axis.

        Every entry of ``data['lmain']`` and ``data['rmain']`` is replaced in
        place by ``entry.cuda().unsqueeze(0)``.

        Args:
            data: Sample dict containing ``'lmain'`` and ``'rmain'`` sub-dicts
                whose values support ``.cuda()`` (presumably tensors).

        Returns:
            The same ``data`` dict, modified in place.
        """
        for view in ['lmain', 'rmain']:
            view_data = data[view]
            # Only values are reassigned, so iterating the dict is safe.
            for key in view_data:
                view_data[key] = view_data[key].cuda().unsqueeze(0)
        return data

    def load_ckpt(self, load_path):
        """Load model weights from the ``'network'`` entry of a checkpoint.

        Args:
            load_path: Path to the checkpoint file.

        Raises:
            FileNotFoundError: If ``load_path`` does not exist. An explicit
                raise is used so the check survives ``python -O``.
        """
        if not os.path.exists(load_path):
            raise FileNotFoundError(f"Checkpoint not found: {load_path}")
        logging.info("Loading checkpoint from %s ...", load_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        ckpt = torch.load(load_path, map_location='cuda')
        self.model.load_state_dict(ckpt['network'], strict=True)
        logging.info("Parameter loading done")
if __name__ == '__main__':
    # Log lines carry a timestamp, the level and the emitting file/line.
    log_format = '%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)

    # Command line: test data location, checkpoint to restore, and how many
    # novel views to render per sample.
    cli = argparse.ArgumentParser()
    cli.add_argument('--test_data_root', type=str, required=True)
    cli.add_argument('--ckpt_path', type=str, required=True)
    cli.add_argument('--novel_view_nums', type=int, default=5)
    options = cli.parse_args()

    # Start from the stage-2 YAML config, then override it for inference.
    cfg_builder = config()
    cfg_builder.load(os.path.join('./config', 'stage2.yaml'))
    cfg = cfg_builder.get_cfg()

    # The config has to be unfrozen while the overrides are applied.
    cfg.defrost()
    cfg.batch_size = 1
    cfg.dataset.test_data_root = options.test_data_root
    cfg.dataset.use_processed_data = False
    cfg.restore_ckpt = options.ckpt_path
    cfg.test_out_path = './interp_out'
    # Ensure the output folder exists before any image is written.
    Path(cfg.test_out_path).mkdir(exist_ok=True, parents=True)
    cfg.freeze()

    # Render interpolated views using source views 0 and 1.
    render = StereoHumanRender(cfg, phase='test')
    render.infer_static(view_select=[0, 1], novel_view_nums=options.novel_view_nums)