in common/nets/loss.py [0:0]
def forward(self, template_pose, mesh_v, joint_out, geo_out):
    """Compute the penetration loss for a batch of predictions.

    Two terms are averaged separately and then combined with the weights
    ``cfg.loss_penet_r_weight`` and ``cfg.loss_penet_nr_weight``:

    * rigid: for every pair of non-adjacent skeleton parts returned by
      ``self.make_combination``, bone samples of the two parts are penalized
      when they are closer than the sum of their ``bone['min_dist']`` values.
    * non-rigid: for every joint in ``self.non_rigid_joint_idx``, the skin
      vertices assigned to that joint by ``self.segmentation`` are tested
      against the last segment of every skeleton path (segments whose joint
      name contains ``'thumb'`` are skipped).

    Args:
        template_pose: forwarded unchanged to ``self.get_bone_data``.
        mesh_v: forwarded unchanged to ``self.get_bone_data``.
        joint_out: tensor whose first dimension is the batch; forwarded to
            ``self.get_bone_from_joint``.
        geo_out: tensor indexed as ``geo_out[:, vertex_mask, :]`` --
            presumably (batch, vertex, 3); TODO confirm against the caller.

    Returns:
        The weighted sum of both terms, one value per batch element. A term
        with nothing to average (no valid part pair / no non-rigid segment)
        contributes zeros instead of raising ``ZeroDivisionError``.
    """
    batch_size = joint_out.shape[0]
    bone = self.get_bone_data(template_pose, mesh_v)
    bone_out_from_joint = self.get_bone_from_joint(joint_out, bone)

    def select_bone(start_joint_idx, end_joint_idx):
        # Mask over bone samples lying on the (start -> end) segment. The
        # comparison result is used directly: casting masks to uint8 for
        # indexing is deprecated in PyTorch.
        return (bone['start_joint_idx'] == start_joint_idx) & (bone['end_joint_idx'] == end_joint_idx)

    skeleton_path = []
    self.traverse_skeleton(self.root_joint_idx, [], skeleton_path)
    # every consecutive joint pair along every path is one skeleton part
    skeleton_part = [
        [start_joint_idx, end_joint_idx]
        for path in skeleton_path
        for start_joint_idx, end_joint_idx in zip(path[:-1], path[1:])
    ]
    skeleton_part = self.make_combination(skeleton_part)  # (combination num x 2 (part_1, part_2) x 2 (start, end joint idx))

    # rigid part
    loss_penetration_rigid = 0
    loss_penetration_rigid_cnt = 0
    for part_pair in skeleton_part:
        # first part index
        start_joint_idx_1 = part_pair[0][0]
        end_joint_idx_1 = part_pair[0][1]
        # second part index
        start_joint_idx_2 = part_pair[1][0]
        end_joint_idx_2 = part_pair[1][1]

        # exclude adjacent parts (parts sharing a joint)
        if (start_joint_idx_1 == start_joint_idx_2
                or start_joint_idx_1 == end_joint_idx_2
                or end_joint_idx_1 == start_joint_idx_2):
            continue

        # skip pairs for which either part has no bone sample
        bone_mask_1 = select_bone(start_joint_idx_1, end_joint_idx_1)
        if not bone_mask_1.any():
            continue
        bone_mask_2 = select_bone(start_joint_idx_2, end_joint_idx_2)
        if not bone_mask_2.any():
            continue

        bone_1 = bone_out_from_joint[:, bone_mask_1, :]
        bone_2 = bone_out_from_joint[:, bone_mask_2, :]
        # distance thr (radius of sphere) of each part, transposed so the
        # selected-bone axis comes second
        dist_thr_1 = bone['min_dist'][bone_mask_1].permute(1, 0)
        dist_thr_2 = bone['min_dist'][bone_mask_2].permute(1, 0)

        # pairwise distances / thresholds via broadcasting (same values as
        # explicitly repeating both operands, without the copies)
        dist = torch.sqrt(torch.sum((bone_1[:, :, None, :] - bone_2[:, None, :, :]) ** 2, 3))
        dist_thr = dist_thr_1[:, :, None] + dist_thr_2[:, None, :]
        loss_penetration_rigid = loss_penetration_rigid + torch.clamp(dist_thr - dist, min=0).mean((1, 2))
        loss_penetration_rigid_cnt += 1

    # non-rigid joint
    loss_penetration_non_rigid = 0
    loss_penetration_non_rigid_cnt = 0
    for nr_jid in self.non_rigid_joint_idx:
        nr_skin = geo_out[:, self.segmentation == nr_jid, :]
        for path in skeleton_path:
            # allocated on the input's device instead of hard-coding CUDA
            is_penetrating = torch.zeros(batch_size, dtype=torch.float32, device=joint_out.device)
            # only consider finger tips (the last segment of the path)
            for pid in range(len(path) - 2, len(path) - 1):
                start_joint_idx = path[pid]
                end_joint_idx = path[pid + 1]

                # exclude path from root through the thumb path
                if 'thumb' in self.skeleton[start_joint_idx]['name'] or 'thumb' in self.skeleton[end_joint_idx]['name']:
                    continue

                bone_mask = select_bone(start_joint_idx, end_joint_idx)
                bone_pid = bone_out_from_joint[:, bone_mask, :]
                dist = torch.sqrt(torch.sum((nr_skin[:, :, None, :] - bone_pid[:, None, :, :]) ** 2, 3))
                dist = torch.min(dist, 1)[0]  # use minimum distance from a bone to skin
                dist_thr = bone['min_dist'][bone_mask].permute(1, 0)

                loss_per_batch = []
                for bid in range(batch_size):
                    collision_idx = torch.nonzero(dist[bid] < dist_thr[bid])
                    if collision_idx.numel() > 0:  # collision occur
                        is_penetrating[bid] = 1
                        # bone start from parent to child -> just pick min idx
                        bone_idx = int(collision_idx.min())
                        loss = torch.abs((dist[bid][bone_idx:] - dist_thr[bid][bone_idx:]).mean()).view(1)
                    elif is_penetrating[bid] == 1:
                        # Cannot trigger while the range above yields a single
                        # segment per path; presumably kept for the case where
                        # more segments are considered -- TODO confirm.
                        loss = torch.abs((dist[bid] - dist_thr[bid]).mean()).view(1)
                    else:
                        loss = dist.new_zeros(1)
                    loss_per_batch.append(loss)
                loss_penetration_non_rigid = loss_penetration_non_rigid + torch.cat(loss_per_batch)
                loss_penetration_non_rigid_cnt += 1

    # Average each term; a term without any contribution is zero rather than 0 / 0.
    if loss_penetration_rigid_cnt > 0:
        loss_penetration_rigid = loss_penetration_rigid / loss_penetration_rigid_cnt
    else:
        loss_penetration_rigid = joint_out.new_zeros(batch_size)
    if loss_penetration_non_rigid_cnt > 0:
        loss_penetration_non_rigid = loss_penetration_non_rigid / loss_penetration_non_rigid_cnt
    else:
        loss_penetration_non_rigid = joint_out.new_zeros(batch_size)

    loss = cfg.loss_penet_r_weight * loss_penetration_rigid + cfg.loss_penet_nr_weight * loss_penetration_non_rigid
    return loss