def forward()

in shapenet/modeling/heads/mesh_loss.py [0:0]


    def forward(self, voxel_scores, meshes_pred, voxels_gt, meshes_gt):
        """
        Compute the combined voxel and mesh loss.

        Args:
          voxel_scores: Voxel logits, or None. Used together with ``voxels_gt``
                as the input of ``F.binary_cross_entropy_with_logits``.
          meshes_pred: A single Meshes, a list of Meshes, or None (treated as an
                empty list). Unless ``self.skip_mesh_loss`` is set, every entry
                is scored by ``self._mesh_loss``.
          voxels_gt: Voxel targets (cast to float before use), or None.
          meshes_gt: Either Meshes, from which ``self.gt_num_samples`` points and
                normals are sampled, or an already-sampled tuple
                (points_gt, normals_gt).

        Returns:
          total_loss: Torch scalar equal to ``self.voxel_weight`` times the voxel
                loss (only when both voxel inputs are given and the weight is
                positive) plus, unless ``self.skip_mesh_loss`` is set, the mean of
                the per-prediction mesh losses. It is None if ``self._mesh_loss``
                returned a None loss for any prediction, meaning the caller
                should skip this loss. TODO use an exception instead?
          losses (dict): Maps loss names to Torch scalars. Holds "voxel" (the
                unweighted voxel loss) when the voxel term is active, plus one
                "<name>_<i>" entry for each item of the dict that
                ``self._mesh_loss`` returned for the i-th prediction.
        """
        # Reuse pre-sampled GT points/normals when given; otherwise sample them.
        points_gt, normals_gt = (
            meshes_gt
            if isinstance(meshes_gt, tuple)
            else sample_points_from_meshes(
                meshes_gt, num_samples=self.gt_num_samples, return_normals=True
            )
        )

        # Zero scalar moved to match points_gt's dtype/device.
        total_loss = torch.tensor(0.0).to(points_gt)
        losses = {}

        use_voxel_loss = (
            voxel_scores is not None and voxels_gt is not None and self.voxel_weight > 0
        )
        if use_voxel_loss:
            voxel_loss = F.binary_cross_entropy_with_logits(
                voxel_scores, voxels_gt.float()
            )
            total_loss = total_loss + self.voxel_weight * voxel_loss
            losses["voxel"] = voxel_loss

        # Normalize meshes_pred to a list of predictions.
        if isinstance(meshes_pred, Meshes):
            meshes_pred = [meshes_pred]
        elif meshes_pred is None:
            meshes_pred = []

        if self.skip_mesh_loss:
            return total_loss, losses

        for idx, mesh_pred in enumerate(meshes_pred):
            mesh_loss, mesh_losses = self._mesh_loss(mesh_pred, points_gt, normals_gt)
            # Each prediction contributes 1/len of its loss; a single None loss
            # turns the total into None for the rest of the call.
            if total_loss is not None and mesh_loss is not None:
                total_loss = total_loss + mesh_loss / len(meshes_pred)
            else:
                total_loss = None
            losses.update(
                ("%s_%d" % (name, idx), value) for name, value in mesh_losses.items()
            )

        return total_loss, losses