-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_solid.py
159 lines (142 loc) · 7.42 KB
/
test_solid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import open3d as o3d
import os, glob, argparse
import numpy as np
import torch
import MinkowskiEngine as ME
import time
from model.GRNet_solid import knn_multiscale
from util import kdtree_partition,write_ply_ascii_geo,read_ply_ascii_geo,get_points_number
from tool.pc_error import pc_error
from tools import gpcc_encode,gpcc_decode
import pandas as pd
def parse_args(argv=None):
    """Parse command-line options for GRNet restoration of G-PCC(octree) solid point clouds.

    All defaults describe the R01 operating point (r01 checkpoint, final kernel 7,
    posQuantscale 0.125), so ``--rate_dir`` also defaults to ``r01/`` to keep a
    default run's results in the matching output sub-directory.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly like a plain command-line invocation.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--ckpts", default='./ckpts_pretrain/solid/r01/last_epoch.pth', help='Path to pretrained model')
    parser.add_argument("--output_dir", type=str, default='./output/solid/', help="Output dir")
    parser.add_argument("--gpcc_outdir", type=str, default='./gpcc_out/')
    # Previously defaulted to 'r05/', which mislabelled the R01 results produced by
    # the other defaults; kept consistent with them now.
    parser.add_argument("--rate_dir", type=str, default='r01/', help="G-PCC(octree) rate point.")
    parser.add_argument("--GT_dir", default='/media/ivc-18/disk/testdata/solid_all/9bit/', help='Ground truth point cloud dir')
    parser.add_argument("--last_kernel_size", type=int, default=7, help='The final layer kernel size, R01-R02:k=7 , R03-R05:k=5')
    parser.add_argument("--resolution", type=int, default=511, help='Follow MPEG CTC ,9bit:511, 10bit:1023, 11bit:2047, 12bit:4095')
    parser.add_argument("--posQuantscale", type=float, default=0.125, help='PosQuantscale of G-PCC(octree), R01:0.125, R02:0.25, R03:0.5, R04:0.75, R05:0.875')
    parser.add_argument("--max_nums", type=int, default=500000, help='Max points number for kd-tree partition')
    return parser.parse_args(argv)
# ---- Setup: parse options, list ground-truth clouds, load the pretrained GRNet model ----
args = parse_args()
# Sorted so files are processed (and written to the CSV) in a deterministic order;
# os.path.join lets --GT_dir be given with or without a trailing slash.
GT_files = sorted(glob.glob(os.path.join(args.GT_dir, '*.ply')))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = knn_multiscale(last_kernel_size=args.last_kernel_size).to(device)
# Map the checkpoint onto the device actually selected above; a hard-coded
# 'cuda:0' makes loading fail on CPU-only machines.
ckpt = torch.load(args.ckpts, map_location=device)
max_nums = args.max_nums
model.load_state_dict(ckpt['model'])
# NOTE(review): the model is never switched to eval() mode before inference. If
# knn_multiscale contains BatchNorm/Dropout or branches on self.training, outputs
# may differ from eval-mode results — confirm against the model definition.
gpcc_outdir = args.gpcc_outdir
posQuantscale = args.posQuantscale
l = len(GT_files)
save_path = os.path.join(args.output_dir, args.rate_dir)
os.makedirs(save_path, exist_ok=True)
os.makedirs(gpcc_outdir, exist_ok=True)
# ---- Evaluation loop ----
# For every ground-truth cloud:
#   1. Split it with kdtree_partition.
#   2. Compress/decompress each part with G-PCC(octree).
#   3. Restore each decoded part with GRNet.
#   4. Score both the G-PCC and the GRNet reconstructions with pc_error.
#   5. Rewrite the accumulated per-file results to a CSV after each file, so partial
#      results survive an interrupted run.
# The parameter count does not change between files, so report it once.
print('model parameter:', sum(param.numel() for param in model.parameters()))
all_results = []  # one result dict per processed file
for i, pc_gt in enumerate(GT_files):
    print("################## FILE NUMBER#################: ", i, ' / ', l)
    ####################### Read point cloud ###########
    file_name = os.path.basename(pc_gt)
    pcd_gt = o3d.io.read_point_cloud(pc_gt)
    coords_gt = np.asarray(pcd_gt.points)
    num_gt = len(coords_gt)
    ####################### Partition ##################
    parts_gt = kdtree_partition(coords_gt, max_nums)
    ####################################################
    bits_coordinates = 0  # G-PCC bitstream size summed over all parts, in bits
    run_time = 0          # GRNet time (voxelization + forward) summed over all parts
    gpcc_enc_time = 0     # G-PCC encoder time summed over all parts
    gpcc_dec_time = 0     # G-PCC decoder time summed over all parts
    out_list = []
    gpcc_list = []
    for j, part_gt in enumerate(parts_gt):
        # Honour --gpcc_outdir instead of a hard-coded './gpcc_out/' path.
        partdir = os.path.join(gpcc_outdir, 'testpart' + str(j) + '.ply')
        write_ply_ascii_geo(partdir, part_gt)
        part_T = ME.utils.batched_coordinates([part_gt])
        ####################### G-PCC(octree) lossy compression ####################
        bin_dir = os.path.join(gpcc_outdir, 'part' + str(j) + '.bin')
        dec_dir = os.path.join(gpcc_outdir, 'part' + str(j) + '_dec.ply')
        results_enc = gpcc_encode(partdir, bin_dir, posQuantscale=posQuantscale)
        results_dec = gpcc_decode(bin_dir, dec_dir)
        # 'Total bitstream size' is treated as bytes here, hence the factor 8.
        bits_coordinates += 8 * results_enc['Total bitstream size']
        # Accumulate over parts so the per-file times cover every part, not just the
        # last one. NOTE(review): assumes gpcc_encode/gpcc_decode report these times
        # as numbers — confirm in tools.
        gpcc_enc_time += results_enc['Processing time (user)']
        gpcc_dec_time += results_dec['Processing time (user)']
        part_dec = read_ply_ascii_geo(dec_dir)
        #########################################################################
        start_time = time.time()
        ####################### Voxelization #############
        c_dec = ME.utils.batched_coordinates([part_dec])
        # Constant all-ones occupancy feature, one float32 channel per point.
        f = torch.ones((c_dec.shape[0], 1), dtype=torch.float32)
        x = ME.SparseTensor(features=f, coordinates=c_dec, device=device)
        # NOTE(review): part_T is already the output of batched_coordinates, so it is
        # batched a second time here. Kept as-is because the expected layout of
        # coords_T is defined by the model — confirm this is intended.
        p2 = ME.utils.batched_coordinates([part_T])
        ####################### GRNet ####################
        with torch.no_grad():
            out, _, _, _ = model(x, coords_T=p2, device=device, prune=True)
        # CUDA kernels run asynchronously; wait for them so run_time measures the work.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        run_time += time.time() - start_time
        # Drop the batch-index column, keeping only the spatial coordinates.
        gpcc_list.append(x.C[:, 1:])
        out_list.append(out.C[:, 1:])
    bpp_coords = bits_coordinates / num_gt
    print("Point Cloud:", file_name)
    print("bpp_coordinates:", bpp_coords)
    run_time = round(run_time, 3)
    rec_pc = torch.cat(out_list, 0)
    gpcc_pc = torch.cat(gpcc_list, 0)
    print("Number of points in G-PCC compressed point cloud : ", gpcc_pc.shape[0])
    print("Number of points in GRNet restored point cloud : ", rec_pc.shape[0])
    recfile = os.path.join(save_path, file_name)
    gpcc_file = recfile + 'gpcc.ply'
    write_ply_ascii_geo(recfile, rec_pc.detach().cpu().numpy())
    write_ply_ascii_geo(gpcc_file, gpcc_pc.detach().cpu().numpy())
    print('Run Time:\t', run_time, 's')
    # pc_gt is already the full path of the ground-truth file.
    pc_error_metrics = pc_error(pc_gt, recfile, res=args.resolution, show=False)
    pc_error_metrics_gpcc = pc_error(pc_gt, gpcc_file, res=args.resolution, show=False)
    ###########################Results#############################
    grnet_d1 = pc_error_metrics["mseF,PSNR (p2point)"][0]
    grnet_d2 = pc_error_metrics["mseF,PSNR (p2plane)"][0]
    gpcc_d1 = pc_error_metrics_gpcc["mseF,PSNR (p2point)"][0]
    gpcc_d2 = pc_error_metrics_gpcc["mseF,PSNR (p2plane)"][0]
    print('----------------GRNet PSNR D1-------------------')
    print(grnet_d1)
    print('----------------GRNet PSNR D2-------------------')
    print(grnet_d2)
    print('----------------G-PCC PSNR D1-------------------')
    print(gpcc_d1)
    print('----------------G-PCC PSNR D2-------------------')
    print(gpcc_d2)
    # Insertion order fixes the CSV column order.
    results = {
        "filename": file_name,
        "bpp(coords)": bpp_coords,
        'GRNet PSNR D1': grnet_d1,
        'GRNet PSNR D2': grnet_d2,
        "run time": run_time,
        'G-PCC PSNR D1': gpcc_d1,
        'G-PCC PSNR D2': gpcc_d2,
        "gpcc enc time": gpcc_enc_time,
        "gpcc dec time": gpcc_dec_time,
    }
    ########################### Write to excel #####################
    csv_name = os.path.join(save_path, 'all_result' + str(args.resolution) + '.csv')
    # DataFrame.append was removed in pandas 2.0; rebuild the table from the list instead.
    all_results.append(results)
    results_allfile = pd.DataFrame(all_results)
    results_allfile.to_csv(csv_name, index=False)
    print('Write results to: \t', csv_name)
    torch.cuda.empty_cache()
################################################################