evaluation_CC.py
import torch
import numpy as np
import random
import yaml
import os
from data.xia_preprocess import generate_data
from arguments import parse_args_test
from model.mst import MST


def init_seed(seed=123):
    """Fix all RNG seeds and force deterministic cuDNN for reproducible evaluation."""
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


init_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_bvh_files(directory):
    """Return the sorted .bvh files in `directory`, excluding the rest pose file."""
    return [os.path.join(directory, f) for f in sorted(os.listdir(directory))
            if os.path.isfile(os.path.join(directory, f))
            and f.endswith('.bvh') and f != 'rest.bvh']


def normalize(x, mean, std):
    return (x - mean) / std


def denormalize(x, mean, std):
    return x * std + mean


if __name__ == '__main__':
    eval_datapath = 'data/preprocessed_xia_test'
    args = parse_args_test()
    with open('xia_dataset.yml', "r") as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    # Load model
    model = MST(False, cfg, args)
    model = model.to(device)
    model.load_checkpoint()
    model.eval()

    # Dataset statistics used for normalization / denormalization
    data_dist = np.load(args.dist_datapath)
    Xmean = data_dist['Xmean']
    Xstd = data_dist['Xstd']

    bvh_files = get_bvh_files(eval_datapath)
    content_full_namedict = [full_name.split('_')[0] for full_name in cfg["content_full_names"]]

    # Accumulators: index 0 = reference has the same content label, index 1 = different content
    euc_dist_sum = [0, 0]
    num_test = [0, 0]
    for i, item in enumerate(bvh_files):
        filename = item.split('/')[-1]
        style, content_num, _ = filename.split('_')
        content = content_full_namedict[int(content_num) - 1]
        print('.')  # progress indicator
        for i_ref, item_ref in enumerate(bvh_files):
            filename_ref = item_ref.split('/')[-1]
            style_ref, content_num_ref, _ = filename_ref.split('_')
            content_ref = content_full_namedict[int(content_num_ref) - 1]
            # Only pair clips that share the same style label
            if style_ref == style:
                cnt_path = eval_datapath + '/' + filename
                sty_path = eval_datapath + '/' + filename_ref
                cnt_clip_raw, _ = generate_data(cnt_path, selected_joints=cfg["selected_joints"],
                                                njoints=cfg["njoints"], downsample=2)
                sty_clip_raw, _ = generate_data(sty_path, selected_joints=cfg["selected_joints"],
                                                njoints=cfg["njoints"], downsample=2)
                cnt_clip = normalize(cnt_clip_raw, Xmean, Xstd)
                sty_clip = normalize(sty_clip_raw, Xmean, Xstd)
                cnt_clip = torch.tensor(cnt_clip, dtype=torch.float).unsqueeze(0).to(device)
                sty_clip = torch.tensor(sty_clip, dtype=torch.float).unsqueeze(0).to(device)

                # Generate temporal masks for the motion sequences & change NaN padding to 0.0
                cnt_m = cnt_clip[:, 1, :, 0]
                cnt_length = sum(~torch.isnan(cnt_m[0])).cpu().numpy()
                cnt_mask = ~torch.isnan(cnt_m).unsqueeze(1).repeat(1, cnt_m.size(1), 1).unsqueeze(1)
                cnt_clip[torch.isnan(cnt_clip)] = 0.0
                sty_m = sty_clip[:, 1, :, 0]
                sty_mask = ~torch.isnan(sty_m).unsqueeze(1).repeat(1, sty_m.size(1), 1).unsqueeze(1)
                sty_clip[torch.isnan(sty_clip)] = 0.0

                # Perform style transfer
                gen = model.generator(cnt_clip, sty_clip, cnt_mask, sty_mask)

                # Our model generates a global translation & a local pose sequence
                gen_traj = gen[0, cfg["joint_dims"]:, :cnt_length, :].cpu().detach().numpy()
                gen_body = gen[0, :cfg["joint_dims"], :cnt_length, :].cpu().detach().numpy()

                # Calculate Content Consistency (CC).
                # The local pose sequence is used for evaluation, because other existing
                # methods cannot generate global translation.
                gt = cnt_clip_raw[:3, :cnt_length, :]
                gen_body_denorm = denormalize(gen_body, Xmean[:7], Xstd[:7])
                pred = gen_body_denorm[:3, :, :]
                euc_dist = np.sum(np.linalg.norm(gt - pred, axis=0), axis=(0, 1)) / cnt_length
                if content_ref == content:
                    euc_dist_sum[0] += euc_dist
                    num_test[0] += 1
                else:
                    euc_dist_sum[1] += euc_dist
                    num_test[1] += 1

    print('CC_same_cnt: ', euc_dist_sum[0] / num_test[0])
    print('CC_diff_cnt: ', euc_dist_sum[1] / num_test[1])
    print('CC_total: ', np.sum(euc_dist_sum) / np.sum(num_test))
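

# ---------------------------------------------------------------------------
# Illustrative sketch (an addition for clarity, not part of the original
# evaluation code): the per-pair CC term accumulated above is the mean
# per-frame Euclidean distance between the first three channels of the raw
# content clip and of the denormalized generated pose. The helper below
# restates that reduction on plain arrays; its name and the toy shapes in the
# usage note are assumptions made purely for illustration.
def _cc_distance_sketch(gt, pred):
    """Mean per-frame Euclidean distance between two (3, T, J) arrays."""
    assert gt.shape == pred.shape and gt.shape[0] == 3
    # L2 norm over the leading channel axis (assumed to hold xyz positions),
    # summed over frames and joints, then averaged over the T frames.
    return np.sum(np.linalg.norm(gt - pred, axis=0)) / gt.shape[1]

# Toy usage: _cc_distance_sketch(np.zeros((3, 32, 21)), np.ones((3, 32, 21)))
# evaluates to 21 * sqrt(3), since each of the 21 joints contributes sqrt(3) per frame.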