in phosa/global_opt.py [0:0]
def compute_interaction_loss(self, verts_person, verts_object):
    """
    Computes the interaction loss between paired people and objects.

    For every ``(i_person, i_object)`` pair returned by
    ``self.assign_interaction_pairs``, the centroid (mean over axis 0) of
    the person's vertices is compared with the centroid of the object's
    vertices using ``self.mse``. The per-pair errors are summed and divided
    by the number of pairs.

    Args:
        verts_person: Indexable collection of per-person vertex tensors
            (presumably a (N_person, V, 3) tensor -- confirm with caller).
        verts_object: Indexable collection of per-object vertex tensors
            (presumably a (N_object, V, 3) tensor -- confirm with caller).

    Returns:
        dict: ``{"loss_inter": scalar tensor}``. The value is zero when
        there are no interaction pairs.
    """
    interaction_pairs = self.assign_interaction_pairs(verts_person, verts_object)

    # Allocate the accumulator on the same device as the vertices instead of
    # hard-coding CUDA, so the loss also works on CPU-only machines and when
    # the inputs are on a non-default GPU.
    device = getattr(verts_person, "device", None)
    if device is None:
        # ``verts_person`` is not a tensor (e.g. a list of tensors): borrow
        # the device from the first paired person. With no pairs at all the
        # device stays None and torch's default device is used.
        for i_person, _ in interaction_pairs:
            device = verts_person[i_person].device
            break
    loss_interaction = torch.zeros((), dtype=torch.float32, device=device)

    for i_person, i_object in interaction_pairs:
        v_person = verts_person[i_person]
        v_object = verts_object[i_object]
        centroid_error = self.mse(v_person.mean(0), v_object.mean(0))
        loss_interaction += centroid_error

    # Guard against division by zero when nothing interacts.
    num_interactions = max(len(interaction_pairs), 1)
    return {"loss_inter": loss_interaction / num_interactions}