in main/model.py [0:0]
def forward(self, inputs, targets, meta_info, mode):
    """Run the network and return training losses or decoded test outputs.

    Args:
        inputs: dict holding the input image batch under ``'img'``.
        targets: ground-truth dict. ``'train'`` mode reads ``'joint_coord'``,
            ``'rel_root_depth'`` and ``'hand_type'``; ``'test'`` mode copies
            ``'joint_coord'`` through (as ``'target_joint'``) when present.
        meta_info: metadata dict. ``'train'`` mode reads ``'joint_valid'``,
            ``'root_valid'`` and ``'hand_type_valid'``; ``'test'`` mode copies
            ``'inv_trans'``, ``'joint_valid'`` and ``'hand_type_valid'``
            through when present.
        mode: ``'train'`` or ``'test'``.

    Returns:
        In ``'train'`` mode, a dict with keys ``'joint_heatmap'``,
        ``'rel_root_depth'`` and ``'hand_type'`` holding whatever the
        corresponding loss modules return.
        In ``'test'`` mode, a dict with ``'joint_coord'`` (float tensor of
        per-joint ``(x, y, z)`` argmax indices into the predicted heatmap),
        the raw ``'rel_root_depth'`` and ``'hand_type'`` network outputs, and
        the optional pass-through entries described above.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``
            (such calls used to silently return ``None``).
    """
    img_feat = self.backbone_net(inputs['img'])
    joint_heatmap_out, rel_root_depth_out, hand_type = self.pose_net(img_feat)

    if mode == 'train':
        target_joint_heatmap = self.render_gaussian_heatmap(targets['joint_coord'])
        loss = {}
        loss['joint_heatmap'] = self.joint_heatmap_loss(
            joint_heatmap_out, target_joint_heatmap, meta_info['joint_valid'])
        loss['rel_root_depth'] = self.rel_root_depth_loss(
            rel_root_depth_out, targets['rel_root_depth'], meta_info['root_valid'])
        loss['hand_type'] = self.hand_type_loss(
            hand_type, targets['hand_type'], meta_info['hand_type_valid'])
        return loss

    if mode == 'test':
        # joint_heatmap_out is rank 5 (the 4-D gather on idx_z below requires
        # it); the axes are assumed to be (batch, joint, z, y, x) per the
        # variable naming -- TODO confirm against pose_net.
        # Cascaded argmax: best z for every (y, x), best y for every x, then
        # the best x overall.
        val_z, idx_z = torch.max(joint_heatmap_out, 2)
        val_zy, idx_zy = torch.max(val_z, 2)
        _, joint_x = torch.max(val_zy, 2)
        joint_x = joint_x[:, :, None]
        joint_y = torch.gather(idx_zy, 2, joint_x)

        # Read the z index stored in row joint_y for every column, then pick
        # column joint_x. The index must span idx_z's last axis exactly, so
        # take that size from idx_z itself; the previous config-derived size
        # (cfg.output_hm_shape[1]) only matched on heatmaps where those two
        # sizes coincide, and otherwise either raised or indexed out of range.
        num_cols = idx_z.shape[3]
        row_index = joint_y[:, :, :, None].expand(-1, -1, -1, num_cols)
        joint_z = torch.gather(idx_z, 2, row_index)[:, :, 0, :]
        joint_z = torch.gather(joint_z, 2, joint_x)

        out = {}
        out['joint_coord'] = torch.cat((joint_x, joint_y, joint_z), 2).float()
        out['rel_root_depth'] = rel_root_depth_out
        out['hand_type'] = hand_type
        # Optional pass-through entries, in this fixed order:
        # (source dict, source key, output key).
        for source, src_key, dst_key in (
            (meta_info, 'inv_trans', 'inv_trans'),
            (targets, 'joint_coord', 'target_joint'),
            (meta_info, 'joint_valid', 'joint_valid'),
            (meta_info, 'hand_type_valid', 'hand_type_valid'),
        ):
            if src_key in source:
                out[dst_key] = source[src_key]
        return out

    raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")