in main/model.py [0:0]
def decode(self, joint_euler):
    """Pose the template mesh from per-joint Euler angles.

    Pipeline: refine the skeleton's local joint transforms with an
    identity-dependent translation offset, run forward kinematics per
    sample, add pose/identity vertex correctives, and skin the vertices
    with linear blend skinning (LBS).

    Args:
        joint_euler: tensor whose first dimension is the batch size. It is
            passed to ``euler2mat(..., to_4x4=True)``; the result is indexed
            per sample and batch-multiplied with the per-joint local poses
            (presumably shape ``(batch_size, joint_num, 3)`` -- TODO confirm).

    Returns:
        dict with:
            ``'joint_out'``: joint positions, shape
                ``(batch_size, joint_num, 3)``.
            ``'mesh_out_refined'``: skinned vertex positions with 3
                coordinates in the last dimension (leading dimensions follow
                ``mesh['v']`` / ``mesh['skinning_weight']``).
    """
    batch_size = joint_euler.shape[0]
    root_idx = self.root_joint_idx
    mesh = self.get_mesh_data(batch_size)
    joint_rot_mat = euler2mat(joint_euler, to_4x4=True)

    # estimate skeleton corrective: offset the translation column of every
    # local joint transform
    skeleton_corrective = self.skeleton_refine_net(self.id_code)
    local_pose_refined = mesh['local_pose'].clone()
    local_pose_refined[:, :3, 3] += skeleton_corrective
    mesh['local_pose_refined'] = local_pose_refined

    # Global transforms of the refined (unposed) skeleton.
    # NOTE(review): this result is stored in `mesh` but does not feed the
    # returned outputs; kept to preserve the original computation.
    global_pose_refined = [None for _ in range(self.joint_num)]
    global_pose_refined[root_idx] = mesh['global_pose'][root_idx].clone()
    forward_kinematics(self.skeleton, root_idx, local_pose_refined, global_pose_refined)
    mesh['global_pose_refined'] = torch.stack(global_pose_refined)

    # forward kinematics, one sample at a time
    joint_out = []
    joint_trans_mat = []
    for i in range(batch_size):
        sample_global_pose = [None for _ in range(self.joint_num)]
        # rigid transform for root joint: use identity matrix in testing stage.
        # Created with the dtype/device of the local poses it is multiplied
        # with, instead of hard-coding float32 on the current CUDA device.
        sample_global_pose[root_idx] = torch.eye(
            4, dtype=local_pose_refined.dtype, device=local_pose_refined.device
        )
        forward_kinematics(
            self.skeleton,
            root_idx,
            torch.bmm(local_pose_refined, joint_rot_mat[i]),
            sample_global_pose,
        )
        # joint position = translation column of each global transform
        joint_out.append(
            torch.stack([sample_global_pose[j][:3, 3] for j in range(self.joint_num)])
        )
        # posed global transform composed with mesh['global_pose_inv'][j]
        joint_trans_mat.append(
            torch.stack([
                torch.mm(sample_global_pose[j], mesh['global_pose_inv'][j, :, :])
                for j in range(self.joint_num)
            ])
        )
    joint_out = torch.stack(joint_out)  # (batch_size, joint_num, 3)
    # joint-major layout (joint_num, batch_size, 4, 4) so that
    # joint_trans_mat[j] is directly usable as a bmm operand below
    joint_trans_mat = torch.stack(joint_trans_mat).permute(1, 0, 2, 3)

    # estimate corrective vector (joint angles are detached before being fed
    # to the skin refinement network)
    pose_corrective, id_corrective = self.skin_refine_net(
        joint_euler.detach(), self.id_code[None, :].repeat(batch_size, 1)
    )
    mesh_refined_xyz = mesh['v'] + pose_corrective + id_corrective

    # LBS: weighted sum over joints of the transformed homogeneous vertices
    mesh_refined_xyz1 = torch.cat(
        [mesh_refined_xyz, torch.ones_like(mesh_refined_xyz[:, :, :1])], 2
    )
    xyz1_t = mesh_refined_xyz1.permute(0, 2, 1)  # loop-invariant, hoisted
    mesh_out_refined = sum([
        mesh['skinning_weight'][:, :, j, None]
        * torch.bmm(joint_trans_mat[j], xyz1_t).permute(0, 2, 1)[:, :, :3]
        for j in range(self.joint_num)
    ])

    out = {}
    out['joint_out'] = joint_out
    out['mesh_out_refined'] = mesh_out_refined
    return out