def compute_interaction_loss()

in phosa/global_opt.py [0:0]


    def compute_interaction_loss(self, verts_person, verts_object):
        """
        Computes interaction loss.

        Steps:
            1. For every (person index, object index) pair returned by
               ``self.assign_interaction_pairs``, take the mean over dim 0
               of that person's vertices and of that object's vertices.
            2. Compare the two centroids with ``self.mse`` (presumably a
               mean-squared-error criterion — confirm where it is set).
            3. Sum the per-pair errors and divide by the number of pairs.

        Args:
            verts_person: Person vertices, indexed by person index. Assumed
                to be a tensor (it must expose ``.device``) whose items are
                (V, 3) vertex arrays — TODO confirm shape against callers.
            verts_object: Object vertices, indexed by object index. The same
                assumptions apply as for ``verts_person``.

        Returns:
            dict: ``{"loss_inter": loss}``. ``loss`` is a float32 scalar
            tensor on the same device as ``verts_person``. It is zero when
            no interaction pairs are assigned.
        """
        # Allocate the accumulator on the inputs' device instead of calling
        # .cuda(). The old call crashed on CPU-only machines and caused a
        # device mismatch when the inputs lived on a non-default GPU.
        loss_interaction = torch.zeros(
            (), dtype=torch.float32, device=verts_person.device
        )
        interaction_pairs = self.assign_interaction_pairs(verts_person, verts_object)
        for i_person, i_object in interaction_pairs:
            v_person = verts_person[i_person]
            v_object = verts_object[i_object]
            centroid_error = self.mse(v_person.mean(0), v_object.mean(0))
            loss_interaction += centroid_error
        # Avoid division by zero when no pairs were assigned.
        num_interactions = max(len(interaction_pairs), 1)
        return {"loss_inter": loss_interaction / num_interactions}