in phosa/global_opt.py [0:0]
def forward(self, loss_weights=None):
    """Evaluate the enabled loss terms for the current object/person state.

    A term's loss helper is only called when that term is enabled, which
    avoids unnecessary compute for terms whose weight is zero (or below).

    Args:
        loss_weights: Optional mapping read with the keys ``"lw_sil"``,
            ``"lw_inter"``, ``"lw_inter_part"``, ``"lw_scale"``,
            ``"lw_scale_person"`` and ``"lw_depth"``. Each key is indexed
            directly, so all six must be present when a mapping is given.
            A term is enabled when its weight is ``> 0``. ``None`` enables
            every term.

    Returns:
        dict: The merged loss terms. ``"loss_scale"`` and
        ``"loss_scale_person"`` are set here; the silhouette, interaction,
        part-interaction and ordinal-depth helpers contribute whatever keys
        their returned mappings contain.
    """

    def is_enabled(weight_key):
        # No weights supplied means "compute everything"; otherwise the
        # weight is looked up lazily, right before its term would run.
        return loss_weights is None or loss_weights[weight_key] > 0

    obj_verts = self.get_verts_object()
    person_verts = self.get_verts_person()
    loss_terms = {}

    if is_enabled("lw_sil"):
        # One reference to the object faces per entry of obj_verts.
        repeated_faces = [self.faces_object] * len(obj_verts)
        sil_terms = self.losses.compute_sil_loss(
            verts=obj_verts, faces=repeated_faces
        )
        loss_terms.update(sil_terms)

    if is_enabled("lw_inter"):
        inter_terms = self.losses.compute_interaction_loss(
            verts_person=person_verts, verts_object=obj_verts
        )
        loss_terms.update(inter_terms)

    if is_enabled("lw_inter_part"):
        part_terms = self.losses.compute_interaction_loss_parts(
            verts_person=person_verts, verts_object=obj_verts
        )
        loss_terms.update(part_terms)

    if is_enabled("lw_scale"):
        loss_terms["loss_scale"] = self.losses.compute_intrinsic_scale_prior(
            intrinsic_scales=self.int_scales_object,
            intrinsic_mean=self.int_scale_object_mean,
        )

    if is_enabled("lw_scale_person"):
        loss_terms["loss_scale_person"] = (
            self.losses.compute_intrinsic_scale_prior(
                intrinsic_scales=self.int_scales_person,
                intrinsic_mean=self.int_scale_person_mean,
            )
        )

    if is_enabled("lw_depth"):
        loss_terms.update(self.compute_ordinal_depth_loss())

    return loss_terms