def forward()

in common/nets/loss.py [0:0]


    def forward(self, template_pose, mesh_v, joint_out, geo_out):
        """Compute the weighted rigid + non-rigid penetration loss.

        Rigid term: every pair of skeleton bones produced by
        ``self.make_combination`` that do not share a joint is compared.
        Points of each bone are taken from ``self.get_bone_from_joint``.
        A pair is penalised (hinge on ``thr - dist``) wherever two points are
        closer than the sum of their ``bone['min_dist']`` radii.

        Non-rigid term: for each joint in ``self.non_rigid_joint_idx``, the
        ``geo_out`` vertices whose ``self.segmentation`` label equals that
        joint are compared against the last bone of every skeleton path that
        does not involve a thumb joint. For each bone point, the minimum
        distance to that skin is taken. If any point is closer than its
        ``min_dist`` threshold, the loss for that sample is
        ``|mean(dist - thr)|`` over the points from the first colliding index
        onward; otherwise it is 0.

        Each term is averaged over the number of bone pairs / bones that
        contributed. If nothing contributed, that term is an all-zero tensor
        instead of raising ``ZeroDivisionError``.

        Args:
            template_pose: forwarded unchanged to ``self.get_bone_data``.
            mesh_v: forwarded unchanged to ``self.get_bone_data``.
            joint_out: predicted joints; ``shape[0]`` is the batch size.
                Forwarded to ``self.get_bone_from_joint``.
            geo_out: predicted mesh vertices, indexed as
                ``geo_out[:, vertex_mask, :]`` (assumed (batch, vertex, coord) —
                TODO confirm).

        Returns:
            Tensor of shape ``(batch_size,)``:
            ``cfg.loss_penet_r_weight * rigid + cfg.loss_penet_nr_weight * non_rigid``.
        """
        batch_size = joint_out.shape[0]

        bone = self.get_bone_data(template_pose, mesh_v)
        # Indexed as (batch, bone point, coord); presumably points sampled
        # along each bone — confirm in get_bone_from_joint.
        bone_out_from_joint = self.get_bone_from_joint(joint_out, bone)

        def bone_mask_for(start_joint_idx, end_joint_idx):
            # Mask over bone points that belong to bone (start -> end).
            # `==` already yields a bool mask; uint8 (`.byte()`) masks are
            # deprecated for indexing.
            return (bone['start_joint_idx'] == start_joint_idx) & (bone['end_joint_idx'] == end_joint_idx)

        def pairwise_dist(a, b):
            # (B, N, C) x (B, M, C) -> (B, N, M) Euclidean distances.
            # Broadcasting gives the same values/gradients as explicit
            # `.repeat` without materialising the repeated copies.
            return torch.sqrt(torch.sum((a[:, :, None, :] - b[:, None, :, :]) ** 2, 3))

        # Collect every consecutive (start, end) joint pair along every path
        # returned by traverse_skeleton, then build all pairs of those bones.
        skeleton_path = []
        self.traverse_skeleton(self.root_joint_idx, [], skeleton_path)
        skeleton_part = []
        for path in skeleton_path:
            for pid in range(len(path) - 1):
                skeleton_part.append([path[pid], path[pid + 1]])
        # (combination num x 2 (part_1, part_2) x 2 (start, end joint idx))
        skeleton_part = self.make_combination(skeleton_part)

        # rigid part
        # Accumulate into zero tensors (not Python 0) so the result is always
        # a (batch_size,) tensor, even when no pair contributes.
        loss_penetration_rigid = joint_out.new_zeros(batch_size)
        loss_penetration_rigid_cnt = 0
        for part_1, part_2 in skeleton_part:
            start_joint_idx_1, end_joint_idx_1 = part_1[0], part_1[1]
            start_joint_idx_2, end_joint_idx_2 = part_2[0], part_2[1]

            # exclude adjacent parts (they share a joint and always touch)
            if (start_joint_idx_1 == start_joint_idx_2
                    or start_joint_idx_1 == end_joint_idx_2
                    or end_joint_idx_1 == start_joint_idx_2):
                continue

            # first part: points and distance thr (radius of sphere)
            bone_mask_1 = bone_mask_for(start_joint_idx_1, end_joint_idx_1)
            if torch.sum(bone_mask_1) == 0:
                continue
            bone_1 = bone_out_from_joint[:, bone_mask_1, :]
            dist_thr_1 = bone['min_dist'][bone_mask_1].permute(1, 0)

            # second part: points and distance thr (radius of sphere)
            bone_mask_2 = bone_mask_for(start_joint_idx_2, end_joint_idx_2)
            if torch.sum(bone_mask_2) == 0:
                continue
            bone_2 = bone_out_from_joint[:, bone_mask_2, :]
            dist_thr_2 = bone['min_dist'][bone_mask_2].permute(1, 0)

            # hinge on (sum of radii - distance) for every point pair, averaged per sample
            dist = pairwise_dist(bone_1, bone_2)
            dist_thr = dist_thr_1[:, :, None] + dist_thr_2[:, None, :]
            loss_penetration_rigid = loss_penetration_rigid + torch.clamp(dist_thr - dist, min=0).mean((1, 2))
            loss_penetration_rigid_cnt += 1

        # non-rigid joint
        loss_penetration_non_rigid = geo_out.new_zeros(batch_size)
        loss_penetration_non_rigid_cnt = 0
        for nr_jid in self.non_rigid_joint_idx:
            nr_skin = geo_out[:, self.segmentation == nr_jid, :]
            if nr_skin.shape[1] == 0:
                # No vertex carries this label: nothing can collide, and the
                # min-over-skin reduction below would fail on an empty dim.
                continue
            for path in skeleton_path:
                # Host-side flags: avoids a CUDA-only allocation and a
                # device sync on every per-sample read.
                # NOTE(review): reset per path, and only one bone per path is
                # visited below, so the "already penetrating" branch cannot fire.
                is_penetrating = [False] * batch_size

                # only consider finger tips (last bone of the path)
                for pid in range(len(path) - 2, len(path) - 1):
                    start_joint_idx = path[pid]
                    end_joint_idx = path[pid + 1]

                    # exclude path from root through the thumb path
                    if 'thumb' in self.skeleton[start_joint_idx]['name'] or 'thumb' in self.skeleton[end_joint_idx]['name']:
                        continue

                    bone_mask = bone_mask_for(start_joint_idx, end_joint_idx)
                    bone_pid = bone_out_from_joint[:, bone_mask, :]
                    # use minimum distance from each bone point to the skin
                    dist = torch.min(pairwise_dist(nr_skin, bone_pid), 1)[0]
                    dist_thr = bone['min_dist'][bone_mask].permute(1, 0)

                    loss_per_batch = []
                    for bid in range(batch_size):
                        collision_idx = torch.nonzero(dist[bid] < dist_thr[bid])
                        if len(collision_idx) > 0:  # collision occurs
                            is_penetrating[bid] = True
                            # bone points go from parent to child -> pick the first colliding one
                            bone_idx = int(torch.min(collision_idx))
                            loss = torch.abs((dist[bid][bone_idx:] - dist_thr[bid][bone_idx:]).mean()).view(1)
                        elif is_penetrating[bid]:
                            loss = torch.abs((dist[bid] - dist_thr[bid]).mean()).view(1)
                        else:
                            # same device/dtype as the other entries so torch.cat works anywhere
                            loss = dist.new_zeros(1)
                        loss_per_batch.append(loss)
                    loss_penetration_non_rigid = loss_penetration_non_rigid + torch.cat(loss_per_batch)
                    loss_penetration_non_rigid_cnt += 1

        # Average only when something contributed; otherwise keep the zero
        # tensor instead of dividing by zero.
        if loss_penetration_rigid_cnt > 0:
            loss_penetration_rigid = loss_penetration_rigid / loss_penetration_rigid_cnt
        if loss_penetration_non_rigid_cnt > 0:
            loss_penetration_non_rigid = loss_penetration_non_rigid / loss_penetration_non_rigid_cnt
        loss = cfg.loss_penet_r_weight * loss_penetration_rigid + cfg.loss_penet_nr_weight * loss_penetration_non_rigid
        return loss