-
Notifications
You must be signed in to change notification settings - Fork 4
/
main_train.py
101 lines (72 loc) · 3.62 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from get_opts import *
from nerf.provider import NeRFDataset
from nerf.gui import NeRFGUI
from nerf.trainer import *
import os
# Absolute directory containing this script; trained models are kept in
# "<pienerf_dir>/model/<dataset name>".
pienerf_dir = os.path.split(os.path.abspath(__file__))[0]

# Debug switch: uncomment to turn on autograd anomaly detection (slow).
# torch.autograd.set_detect_anomaly(True)

# Tell the Intel OpenMP runtime to tolerate a duplicate copy of itself being
# loaded. Presumably set to avoid the "OMP: Error #15" abort seen when several
# bundled libiomp copies coexist; confirm it is still needed.
os.environ.update(KMP_DUPLICATE_LIB_OK='1')
def resolve_workspace(root, data_path):
    """Return the checkpoint workspace ``<root>/model/<dataset name>`` for ``data_path``.

    The dataset name is the last component of ``data_path``. The path is
    normalized first, so trailing separators are ignored. Both ``"data/chair/"``
    and ``"data/chair"`` map to ``<root>/model/chair``. Using
    ``data_path.split("/")[-1]`` would instead give an empty name for the first
    form and put checkpoints directly into ``<root>/model``.

    Args:
        root: Directory that contains the ``model`` folder.
        data_path: Dataset path as given on the command line (``opt.path``).

    Returns:
        The workspace directory path. It is not created here.
    """
    dataset_name = os.path.basename(os.path.normpath(data_path))
    return os.path.join(root, "model", dataset_name)


def _select_network(opt):
    """Import and return the ``NeRFNetwork`` class for the backend chosen by ``opt.ff`` / ``opt.tcnn``.

    Both the ``--ff`` and ``--tcnn`` backends force ``opt.fp16 = True``, which
    mutates ``opt``. Neither supports a background model.

    Raises:
        ValueError: if ``--ff`` or ``--tcnn`` is combined with ``opt.bg_radius > 0``.
    """
    if opt.ff:
        opt.fp16 = True
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if opt.bg_radius > 0:
            raise ValueError("background model is not implemented for --ff")
        from nerf.network_ff import NeRFNetwork
    elif opt.tcnn:
        opt.fp16 = True
        if opt.bg_radius > 0:
            raise ValueError("background model is not implemented for --tcnn")
        from nerf.network_tcnn import NeRFNetwork
    else:
        from nerf.network import NeRFNetwork
    return NeRFNetwork


if __name__ == '__main__':
    # Trained models are stored under <script dir>/model/<dataset name>.
    # exist_ok avoids the exists()/mkdir() check-then-create race.
    os.makedirs(os.path.join(pienerf_dir, "model"), exist_ok=True)

    parser = argparse.ArgumentParser()
    opt = get_shared_opts(parser)
    opt.workspace = resolve_workspace(pienerf_dir, opt.path)

    NeRFNetwork = _select_network(opt)

    seed_everything(opt.seed)

    model = NeRFNetwork(  # NeRFNetwork inherits NeRFRenderer
        encoding="hashgrid",
        bound=opt.bound,
        cuda_ray=opt.cuda_ray,
        density_scale=1,
        min_near=opt.min_near,
        density_thresh=opt.density_thresh,
        bg_radius=opt.bg_radius,
    )
    print(model)

    # reduction='none' returns per-element losses; any reduction presumably
    # happens inside Trainer (TODO confirm).
    criterion = torch.nn.MSELoss(reduction='none')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Shared by both the test-only and the training paths.
    metrics = [PSNRMeter(), LPIPSMeter(device=device)]

    if opt.test:
        # Evaluation only. Weights are loaded according to --ckpt (presumably; see Trainer).
        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace,
                          criterion=criterion, fp16=opt.fp16, metrics=metrics,
                          use_checkpoint=opt.ckpt)
        if opt.gui:
            gui = NeRFGUI(opt, trainer)
            gui.render()
        else:
            test_loader = NeRFDataset(opt, device=device, type='test').dataloader()
            if test_loader.has_gt:
                trainer.evaluate(test_loader)  # blender has gt, so evaluate it.
            # Optional: trainer.test(test_loader, write_video=True) to test and save a video.
            trainer.save_mesh(resolution=256, threshold=10)
    else:
        def build_optimizer(net):
            """Create the Adam optimizer over the network's parameter groups."""
            return torch.optim.Adam(net.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

        def build_scheduler(optimizer):
            """Decay the LR exponentially to 0.1x the initial value at step opt.iters, then hold it."""
            return torch.optim.lr_scheduler.LambdaLR(
                optimizer, lambda step: 0.1 ** min(step / opt.iters, 1))

        train_loader = NeRFDataset(opt, device=device, type='train').dataloader()
        trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace,
                          optimizer=build_optimizer, criterion=criterion, ema_decay=0.95,
                          fp16=opt.fp16, lr_scheduler=build_scheduler,
                          scheduler_update_every_step=True, metrics=metrics,
                          use_checkpoint=opt.ckpt, eval_interval=50)
        if opt.gui:
            gui = NeRFGUI(opt, trainer, train_loader)
            gui.render()
        else:
            valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader()
            # Enough epochs to reach opt.iters steps, assuming one step per batch.
            # A plain int is used rather than np.int32.
            max_epoch = int(np.ceil(opt.iters / len(train_loader)))
            trainer.train(train_loader, valid_loader, max_epoch)

            # Also evaluate on the test split once training has finished.
            test_loader = NeRFDataset(opt, device=device, type='test').dataloader()
            if test_loader.has_gt:
                trainer.evaluate(test_loader)  # blender has gt, so evaluate it.
            # Optional: trainer.test(test_loader, write_video=True) or
            # trainer.save_mesh(resolution=256, threshold=10).
            trainer.save_point_cloud(resolution=256)