in meshrcnn/modeling/roi_heads/voxel_head.py [0:0]
def voxel_rcnn_loss(pred_voxel_logits, instances, loss_weight=1.0):
    """
    Compute the voxel prediction loss defined in the Mesh R-CNN paper.

    Args:
        pred_voxel_logits (Tensor): A tensor of shape (B, C, D, H, W) or (B, 1, D, H, W)
            for class-specific or class-agnostic, where B is the total number of predicted
            voxels in all images, C is the number of foreground classes, and D, H, W are
            the depth, height and width of the voxel predictions. The values are logits.
            D, H and W must all be equal.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the
            pred_voxel_logits. The ground-truth labels (class, box, mask, ...) associated
            with each instance are stored in fields. The fields read here are
            ``gt_classes`` (class-specific case only), ``gt_voxels``, ``gt_K`` and
            ``proposal_boxes``.
        loss_weight (float): A float to multiply the loss with.

    Returns:
        tuple: ``(voxel_loss, gt_voxel_logits)`` where

        voxel_loss (Tensor): A scalar tensor containing the weighted loss. When no image
            contains any instance this is ``pred_voxel_logits.sum() * 0``, i.e. a zero
            that is still connected to the prediction graph.
        gt_voxel_logits (Tensor or list): The ground-truth voxel targets cropped by
            ``batch_crop_voxels_within_box`` and concatenated over all images, or an
            empty list when no image contains any instance.
    """
    cls_agnostic_voxel = pred_voxel_logits.size(1) == 1
    total_num_voxels = pred_voxel_logits.size(0)
    voxel_side_len = pred_voxel_logits.size(2)
    assert pred_voxel_logits.size(2) == pred_voxel_logits.size(
        3
    ), "Voxel prediction must be square!"
    assert pred_voxel_logits.size(2) == pred_voxel_logits.size(
        4
    ), "Voxel prediction must be square!"

    gt_classes = []
    gt_voxel_logits = []
    for instances_per_image in instances:
        # Images without any instance contribute neither targets nor classes.
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_voxel:
            gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
            gt_classes.append(gt_classes_per_image)

        gt_voxels = instances_per_image.gt_voxels
        gt_K = instances_per_image.gt_K
        # Crop the ground-truth voxels to each proposal box at the prediction resolution.
        gt_voxel_logits_per_image = batch_crop_voxels_within_box(
            gt_voxels, instances_per_image.proposal_boxes.tensor, gt_K, voxel_side_len
        ).to(device=pred_voxel_logits.device)
        gt_voxel_logits.append(gt_voxel_logits_per_image)

    if len(gt_voxel_logits) == 0:
        # Nothing to supervise: return a zero loss that keeps the autograd graph intact.
        return pred_voxel_logits.sum() * 0, gt_voxel_logits

    gt_voxel_logits = cat(gt_voxel_logits, dim=0)
    assert gt_voxel_logits.numel() > 0, gt_voxel_logits.shape

    if cls_agnostic_voxel:
        pred_voxel_logits = pred_voxel_logits[:, 0]
    else:
        # Pick, for every instance, the prediction channel of its ground-truth class.
        indices = torch.arange(total_num_voxels)
        gt_classes = cat(gt_classes, dim=0)
        pred_voxel_logits = pred_voxel_logits[indices, gt_classes]

    # Log the training accuracy (using gt classes and a 0.5 probability threshold).
    # The predictions are logits, so "probability > 0.5" is "logit > 0"
    # (sigmoid(0) == 0.5); thresholding the logits at 0.5 would misclassify
    # every prediction whose logit lies in (0, 0.5].
    # Note that here we allow gt_voxel_logits to be float as well
    # (depend on the implementation of rasterize())
    voxel_accurate = (pred_voxel_logits > 0.0) == (gt_voxel_logits > 0.5)
    voxel_accuracy = voxel_accurate.nonzero().size(0) / voxel_accurate.numel()
    get_event_storage().put_scalar("voxel_rcnn/accuracy", voxel_accuracy)

    voxel_loss = F.binary_cross_entropy_with_logits(
        pred_voxel_logits, gt_voxel_logits, reduction="mean"
    )
    voxel_loss = voxel_loss * loss_weight
    return voxel_loss, gt_voxel_logits