in shapenet/modeling/heads/mesh_loss.py [0:0]
def forward(self, voxel_scores, meshes_pred, voxels_gt, meshes_gt):
    """
    Combine the voxel loss and the per-stage mesh losses into one training loss.

    Args:
        voxel_scores: Voxel logits, or None to skip the voxel term.
        meshes_pred: A single Meshes, a list of Meshes (one per stage), or None.
        voxels_gt: Ground-truth voxels (cast to float before use), or None to
            skip the voxel term.
        meshes_gt: Either Meshes, or a tuple (points_gt, normals_gt) that was
            already sampled from the ground-truth meshes.

    Returns:
        loss: Torch scalar giving the total loss, or None if any stage's mesh
            loss came back as None and this loss should be skipped.
        losses (dict): A dictionary mapping loss names to the individual loss
            values. The voxel term is stored under "voxel" (before weighting);
            each entry returned by ``_mesh_loss`` is stored under
            "<name>_<stage index>".
    """
    # Ground truth may arrive pre-sampled; otherwise sample points and normals now.
    if isinstance(meshes_gt, tuple):
        gt_samples = meshes_gt
    else:
        gt_samples = sample_points_from_meshes(
            meshes_gt, num_samples=self.gt_num_samples, return_normals=True
        )
    points_gt, normals_gt = gt_samples

    losses = {}
    # Start the accumulator with the same dtype/device as the ground-truth points.
    total_loss = torch.tensor(0.0).to(points_gt)

    have_voxels = voxel_scores is not None and voxels_gt is not None
    if have_voxels and self.voxel_weight > 0:
        voxel_loss = F.binary_cross_entropy_with_logits(
            voxel_scores, voxels_gt.float()
        )
        losses["voxel"] = voxel_loss
        total_loss = total_loss + self.voxel_weight * voxel_loss

    # Normalise the predictions to a sequence of per-stage meshes.
    if meshes_pred is None:
        stages = []
    elif isinstance(meshes_pred, Meshes):
        stages = [meshes_pred]
    else:
        stages = meshes_pred

    if self.skip_mesh_loss:
        return total_loss, losses

    for stage_idx, stage_meshes in enumerate(stages):
        stage_loss, stage_losses = self._mesh_loss(
            stage_meshes, points_gt, normals_gt
        )
        if total_loss is not None and stage_loss is not None:
            # Each stage contributes equally: average over the number of stages.
            total_loss = total_loss + stage_loss / len(stages)
        else:
            # A single failed stage invalidates the total for this call.
            total_loss = None
        # Individual terms are still recorded, tagged with the stage index.
        losses.update(
            ("%s_%d" % (name, stage_idx), value)
            for name, value in stage_losses.items()
        )

    return total_loss, losses