in criterion.py [0:0]
def loss_angle(self, outputs, targets, assignments):
    """Compute the angle-bin classification and residual-regression losses.

    The heading angle is predicted as a bin classification plus a normalized
    residual per bin. Two terms are produced:

    * ``loss_angle_cls``: per-proposal cross-entropy between
      ``outputs["angle_logits"]`` and the ground-truth bin label gathered for
      each proposal via ``assignments["per_prop_gt_inds"]``.
    * ``loss_angle_reg``: per-proposal Huber loss (``delta=1.0``) between the
      predicted residual at the ground-truth bin and the ground-truth residual
      divided by ``np.pi / self.dataset_config.num_angle_bin``.

    Each term is multiplied by ``assignments["proposal_matched_mask"]``,
    summed, and divided in place by ``targets["num_boxes"]``.

    Args:
        outputs: dict holding ``"angle_logits"`` and
            ``"angle_residual_normalized"``. Both are rank-3 with angle bins on
            the last axis. Reading the axes as (batch, proposal, bin) is an
            assumption; confirm it against the model head.
        targets: dict holding ``"num_boxes_replica"``, ``"num_boxes"``,
            ``"gt_angle_class_label"`` and ``"gt_angle_residual_label"``. The
            two GT tensors are indexed along dim 1 by ``per_prop_gt_inds``.
        assignments: dict holding ``"per_prop_gt_inds"`` (indices into dim 1
            of the GT tensors, one per proposal) and ``"proposal_matched_mask"``
            (multiplicative weight per proposal).

    Returns:
        dict with 0-dim tensors ``"loss_angle_cls"`` and ``"loss_angle_reg"``.
        Both are zeros on ``angle_logits``' device when
        ``targets["num_boxes_replica"]`` is not greater than zero.
    """
    pred_logits = outputs["angle_logits"]
    pred_residual = outputs["angle_residual_normalized"]

    if not (targets["num_boxes_replica"] > 0):
        # Nothing to supervise on this replica: two independent zero scalars.
        return {
            "loss_angle_cls": torch.zeros(1, device=pred_logits.device).squeeze(),
            "loss_angle_reg": torch.zeros(1, device=pred_logits.device).squeeze(),
        }

    # Residual targets are expressed in units of pi / num_angle_bin.
    bin_width = np.pi / self.dataset_config.num_angle_bin
    gt_label_all = targets["gt_angle_class_label"]
    gt_residual_all = targets["gt_angle_residual_label"] / bin_width

    gt_index = assignments["per_prop_gt_inds"]
    matched_mask = assignments["proposal_matched_mask"]

    # Pick, for every proposal, the GT entry it was assigned to.
    gt_label = torch.gather(gt_label_all, 1, gt_index)
    gt_residual = torch.gather(gt_residual_all, 1, gt_index)

    # cross_entropy wants the class axis at dim 1, so move the bin axis there.
    per_prop_cls = F.cross_entropy(
        pred_logits.transpose(2, 1), gt_label, reduction="none"
    )
    cls_loss = torch.sum(per_prop_cls * matched_mask)

    # A float32 one-hot over the bin axis keeps only the residual predicted
    # for the ground-truth bin once multiplied and summed.
    bin_selector = torch.zeros_like(pred_residual, dtype=torch.float32).scatter_(
        2, gt_label.unsqueeze(-1), 1
    )
    residual_at_gt_bin = (pred_residual * bin_selector).sum(-1)
    per_prop_reg = huber_loss(residual_at_gt_bin - gt_residual, delta=1.0)
    reg_loss = torch.sum(per_prop_reg * matched_mask)

    num_boxes = targets["num_boxes"]
    cls_loss /= num_boxes
    reg_loss /= num_boxes
    return {"loss_angle_cls": cls_loss, "loss_angle_reg": reg_loss}