# net.py
import torch
from torch.nn.utils.rnn import *

import articulate as art
from articulate.utils.torch import *
from config import *
from utils import *
from dynamics import PhysicsOptimizer


class PIP(torch.nn.Module):
    name = 'PIP'
    n_hidden = 256

    def __init__(self):
        super(PIP, self).__init__()
        # Five cascaded RNN stages operating on the 72-d IMU input
        # (6 IMUs: 3-d acceleration + 3x3 rotation matrix each):
        # rnn1 predicts leaf-joint positions (hidden state initialized from
        # the initial leaf-joint positions); rnn2 predicts all-joint
        # positions; rnn3, rnn4, and rnn5 consume the full-joint estimate
        # concatenated with the IMU input to predict the reduced global 6D
        # pose, the per-joint velocities, and the two foot-contact values.
        self.rnn1 = RNNWithInit(input_size=72,
                                output_size=joint_set.n_leaf * 3,
                                hidden_size=self.n_hidden,
                                num_rnn_layer=2,
                                dropout=0.4)
        self.rnn2 = RNN(input_size=72 + joint_set.n_leaf * 3,
                        output_size=joint_set.n_full * 3,
                        hidden_size=self.n_hidden,
                        num_rnn_layer=2,
                        dropout=0.4)
        self.rnn3 = RNN(input_size=72 + joint_set.n_full * 3,
                        output_size=joint_set.n_reduced * 6,
                        hidden_size=self.n_hidden,
                        num_rnn_layer=2,
                        dropout=0.4)
        self.rnn4 = RNNWithInit(input_size=72 + joint_set.n_full * 3,
                                output_size=24 * 3,
                                hidden_size=self.n_hidden,
                                num_rnn_layer=2,
                                dropout=0.4)
        self.rnn5 = RNN(input_size=72 + joint_set.n_full * 3,
                        output_size=2,
                        hidden_size=64,
                        num_rnn_layer=2,
                        dropout=0.4)

        body_model = art.ParametricModel(paths.smpl_file)
        self.inverse_kinematics_R = body_model.inverse_kinematics_R
        self.forward_kinematics = body_model.forward_kinematics
        self.dynamics_optimizer = PhysicsOptimizer(debug=False)
        self.rnn_states = [None for _ in range(5)]

        self.load_state_dict(torch.load(paths.weights_file))
        self.eval()

    def _reduced_glb_6d_to_full_local_mat(self, root_rotation, glb_reduced_pose):
        # Recover full local joint rotations from the reduced global 6D pose:
        # fill identity for joints outside the reduced set, run inverse
        # kinematics (global -> local), then reset the ignored joints and
        # write in the root rotation.
        glb_reduced_pose = art.math.r6d_to_rotation_matrix(glb_reduced_pose).view(-1, joint_set.n_reduced, 3, 3)
        global_full_pose = torch.eye(3, device=glb_reduced_pose.device).repeat(glb_reduced_pose.shape[0], 24, 1, 1)
        global_full_pose[:, joint_set.reduced] = glb_reduced_pose
        pose = self.inverse_kinematics_R(global_full_pose).view(-1, 24, 3, 3)
        pose[:, joint_set.ignored] = torch.eye(3, device=pose.device)
        pose[:, 0] = root_rotation.view(-1, 3, 3)
        return pose
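
    # For reference: the 6D rotation representation stores two 3-d columns
    # of a rotation matrix; art.math.r6d_to_rotation_matrix recovers the
    # full matrix, in the standard formulation (Zhou et al., CVPR 2019)
    # via Gram-Schmidt orthonormalization and a cross product:
    #     b1 = normalize(a1)
    #     b2 = normalize(a2 - (b1 . a2) * b1)
    #     b3 = b1 x b2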

    def forward(self, x):
        r"""
        Forward.

        :param x: A list of length [batch_size] containing 3-tuples of
            (tensor [num_frames, 72], tensor [15], tensor [72]), i.e. the
            IMU input sequence and the initial values for rnn1 and rnn4.
        """
        x, lj_init, jvel_init = list(zip(*x))
        leaf_joint = self.rnn1(list(zip(x, lj_init)))
        full_joint = self.rnn2([torch.cat(_, dim=-1) for _ in zip(leaf_joint, x)])
        global_6d_pose = self.rnn3([torch.cat(_, dim=-1) for _ in zip(full_joint, x)])
        joint_velocity = self.rnn4(list(zip([torch.cat(_, dim=-1) for _ in zip(full_joint, x)], jvel_init)))
        contact = self.rnn5([torch.cat(_, dim=-1) for _ in zip(full_joint, x)])
        return leaf_joint, full_joint, global_6d_pose, joint_velocity, contact
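
    # A minimal sketch of assembling a batch for forward(); the names
    # `imu_seq`, `lj_init`, and `jvel_init` are illustrative placeholders
    # for per-sequence tensors of shapes [num_frames, 72], [15], and [72]:
    #
    #     batch = [(imu_seq, lj_init, jvel_init)]   # batch_size = 1
    #     leaf, full, pose6d, jvel, contact = net.forward(batch)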

    @torch.no_grad()
    def predict(self, glb_acc, glb_rot, init_pose):
        r"""
        Predict the results for evaluation.

        :param glb_acc: A tensor that can reshape to [num_frames, 6, 3].
        :param glb_rot: A tensor that can reshape to [num_frames, 6, 3, 3].
        :param init_pose: A tensor that can reshape to [1, 24, 3, 3].
        :return: Pose tensor in shape [num_frames, 24, 3, 3] and
            translation tensor in shape [num_frames, 3].
        """
        self.dynamics_optimizer.reset_states()
        # Use the initial pose (root rotation forced to identity) to compute
        # the leaf-joint positions that initialize rnn1's hidden state.
        init_pose = init_pose.view(1, 24, 3, 3)
        init_pose[0, 0] = torch.eye(3)
        lj_init = self.forward_kinematics(init_pose)[1][0, joint_set.leaf].view(-1)
        jvel_init = torch.zeros(24 * 3)
        x = (normalize_and_concat(glb_acc, glb_rot), lj_init, jvel_init)
        leaf_joint, full_joint, global_6d_pose, joint_velocity, contact = [_[0] for _ in self.forward([x])]
        # Convert the reduced global 6D pose to full local rotation matrices,
        # taking the root rotation from the last IMU.
        pose = self._reduced_glb_6d_to_full_local_mat(glb_rot.view(-1, 6, 3, 3)[:, -1], global_6d_pose)
        # Rotate the predicted joint velocities by the root IMU orientation
        # and rescale them.
        joint_velocity = joint_velocity.view(-1, 24, 3).bmm(glb_rot[:, -1].transpose(1, 2)) * vel_scale
        # Refine the kinematic estimates frame by frame with the physics
        # optimizer to obtain the final pose and global translation.
        pose_opt, tran_opt = [], []
        for p, v, c, a in zip(pose, joint_velocity, contact, glb_acc):
            p, t = self.dynamics_optimizer.optimize_frame(p, v, c, a)
            pose_opt.append(p)
            tran_opt.append(t)
        pose_opt, tran_opt = torch.stack(pose_opt), torch.stack(tran_opt)
        return pose_opt, tran_opt
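

# A minimal usage sketch: the zero/identity tensors below are placeholders
# for calibrated 6-IMU measurements, and constructing PIP() assumes the
# pretrained weights at paths.weights_file are available.
if __name__ == '__main__':
    net = PIP()
    num_frames = 60
    glb_acc = torch.zeros(num_frames, 6, 3)              # global accelerations
    glb_rot = torch.eye(3).repeat(num_frames, 6, 1, 1)   # global orientations
    init_pose = torch.eye(3).repeat(1, 24, 1, 1)         # identity initial pose
    pose, tran = net.predict(glb_acc, glb_rot, init_pose)
    print(pose.shape, tran.shape)   # [num_frames, 24, 3, 3], [num_frames, 3]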