in main/model.py [0:0]
def forward(self, inputs, targets, meta_info, mode):
    """Run the model and return losses (training) or predictions (testing).

    Args:
        inputs: Mapping; ``inputs['img']`` is the image batch passed to
            ``self.backbone_net``. Its first dimension is used as the
            batch size.
        targets: Mapping read only when ``mode == 'train'``; the entries
            used are ``targets['joint']['world_coord']``,
            ``targets['joint']['valid']`` and ``targets['depthmap']``.
        meta_info: Mapping read only when ``mode == 'train'``; the entries
            used are ``meta_info['cam_param']`` and
            ``meta_info['affine_trans']``, both indexed per camera.
        mode: Either ``'train'`` or ``'test'``.

    Returns:
        In ``'train'`` mode, a dict of losses keyed ``'joint'``,
        ``'depthmap'``, ``'penet'`` and ``'lap'``. In ``'test'`` mode, a
        dict with ``'joint_out'`` (shape ``(batch_size, joint_num, 3)``)
        and ``'mesh_out_refined'`` (the skinned, corrected mesh vertices).

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail fast: an unknown mode would otherwise leave the root pose unset
    # and, at best, silently return None.
    if mode not in ('train', 'test'):
        raise ValueError(
            f"unsupported mode {mode!r}; expected 'train' or 'test'"
        )

    input_img = inputs['img']
    batch_size = input_img.shape[0]
    mesh = self.get_mesh_data(batch_size)

    def _blend_skinning(trans_mat, verts_h):
        """Blend per-joint transforms of homogeneous vertices; return xyz."""
        verts_h_t = verts_h.permute(0, 2, 1)
        return sum(
            mesh['skinning_weight'][:, :, j, None]
            * torch.bmm(trans_mat[j], verts_h_t).permute(0, 2, 1)[:, :, :3]
            for j in range(self.joint_num)
        )

    # extract image feature
    img_feat = self.backbone_net(input_img)

    # estimate local euler angle change for each joint
    joint_euler = self.pose_net(img_feat)
    joint_rot_mat = euler2mat(joint_euler, to_4x4=True)

    # estimate skeleton corrective and apply it to the translation part of
    # the local poses, then rebuild the global poses from the root
    skeleton_corrective = self.skeleton_refine_net(self.id_code)
    mesh['local_pose_refined'] = mesh['local_pose'].clone()
    mesh['local_pose_refined'][:, :3, 3] += skeleton_corrective
    mesh['global_pose_refined'] = [None for _ in range(self.joint_num)]
    mesh['global_pose_refined'][self.root_joint_idx] = \
        mesh['global_pose'][self.root_joint_idx].clone()
    forward_kinematics(
        self.skeleton,
        self.root_joint_idx,
        mesh['local_pose_refined'],
        mesh['global_pose_refined'],
    )
    mesh['global_pose_refined'] = torch.stack(mesh['global_pose_refined'])

    # rigid transform for root joint
    refined_root_pose = mesh['global_pose_refined'][self.root_joint_idx]
    global_pose = [
        [None for _ in range(self.joint_num)] for _ in range(batch_size)
    ]
    if mode == 'train':
        align_joint_idx = torch.as_tensor(
            self.align_joint_idx, dtype=torch.long
        )
        num_align = len(align_joint_idx)
        # Homogeneous bottom row, created with the dtype/device of the pose
        # it is multiplied with instead of a hard-coded CUDA float tensor.
        bottom_row = refined_root_pose.new_tensor([[0, 0, 0, 1]])
        for i in range(batch_size):
            cur_sample = mesh['global_pose_refined'][
                align_joint_idx, :3, 3
            ].view(num_align, 3)
            gt_sample = targets['joint']['world_coord'][i][
                align_joint_idx
            ].view(num_align, 3)
            R, t = rigid_transform_3D(cur_sample, gt_sample)
            mat = torch.cat((torch.cat((R, t), 1), bottom_row))
            global_pose[i][self.root_joint_idx] = torch.mm(
                mat, refined_root_pose
            )
    else:
        # use identity matrix in testing stage
        for i in range(batch_size):
            global_pose[i][self.root_joint_idx] = torch.eye(
                4,
                dtype=refined_root_pose.dtype,
                device=refined_root_pose.device,
            )

    # forward kinematics
    joint_out = []
    joint_trans_mat = []
    for i in range(batch_size):
        forward_kinematics(
            self.skeleton,
            self.root_joint_idx,
            torch.bmm(mesh['local_pose_refined'], joint_rot_mat[i]),
            global_pose[i],
        )
        # joint positions are the translation column of each global pose
        joint_out.append(torch.stack(
            [global_pose[i][j][:3, 3] for j in range(self.joint_num)]
        ))
        joint_trans_mat.append(torch.stack([
            torch.mm(global_pose[i][j], mesh['global_pose_inv'][j, :, :])
            for j in range(self.joint_num)
        ]))
    # (batch_size, joint_num, 3)
    joint_out = torch.stack(joint_out)
    # (joint_num, batch_size, 4, 4) so that indexing by joint yields a batch
    joint_trans_mat = torch.stack(joint_trans_mat).permute(1, 0, 2, 3)

    # estimate corrective vector
    pose_corrective, id_corrective = self.skin_refine_net(
        joint_euler.detach(),
        self.id_code[None, :].repeat(batch_size, 1),
    )
    mesh_refined_xyz = mesh['v'] + pose_corrective + id_corrective

    # LBS
    mesh_refined_xyz1 = torch.cat(
        [mesh_refined_xyz, torch.ones_like(mesh_refined_xyz[:, :, :1])], 2
    )
    mesh_out_refined = _blend_skinning(joint_trans_mat, mesh_refined_xyz1)

    # loss functions in training stage
    if mode == 'train':
        # render depthmap
        cam_param = meta_info['cam_param']
        affine_trans = meta_info['affine_trans']
        depthmap_out_refined = [
            self.renderer(
                mesh_out_refined, cam_param[cid], affine_trans[cid], mesh
            )
            for cid in range(len(cam_param))
        ]

        # zero pose template mesh with correctives (for penet loss); the
        # corrected homogeneous vertices are the same ones used for LBS
        # above, so they are reused rather than rebuilt
        zero_pose_trans_mat = torch.bmm(
            mesh['global_pose_refined'], mesh['global_pose_inv']
        )[:, None, :, :].repeat(1, batch_size, 1, 1)
        mesh_refined_v = _blend_skinning(
            zero_pose_trans_mat, mesh_refined_xyz1
        )

        loss = {}
        loss['joint'] = self.joint_loss(
            joint_out,
            targets['joint']['world_coord'],
            targets['joint']['valid'],
        )
        loss['depthmap'] = self.depthmap_loss(
            depthmap_out_refined, targets['depthmap']
        )
        loss['penet'] = self.penet_loss(
            mesh['global_pose_refined'].detach(),
            mesh_refined_v.detach(),
            joint_out,
            mesh_out_refined,
        )
        loss['lap'] = self.lap_loss(mesh_out_refined) * cfg.loss_lap_weight
        return loss

    # output in testing stage
    out = {}
    out['joint_out'] = joint_out
    out['mesh_out_refined'] = mesh_out_refined
    return out