in pytorch/sagemakercv/detection/roi_heads/box_head/loss.py [0:0]
def __call__(self, class_logits, box_regression):
    """
    Computes the loss for Faster R-CNN.
    This requires that the subsample method has been called beforehand,
    since the per-proposal targets are read from ``self._proposals``.

    Arguments:
        class_logits (list[Tensor])
        box_regression (list[Tensor])

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
        loss_carl: returned as a third element only when ``self.carl`` is set

    Raises:
        RuntimeError: if ``self._proposals`` has not been set.
        ValueError: if ``self.loss`` is neither "SmoothL1Loss" nor "GIoULoss".
    """
    # Fail fast, before any tensor work is done.
    if not hasattr(self, "_proposals"):
        raise RuntimeError("subsample needs to be called before")
    if self.loss not in ("SmoothL1Loss", "GIoULoss"):
        # Without this guard an unknown loss name would only surface as an
        # UnboundLocalError on `box_loss` at the return statement.
        raise ValueError(
            "Unsupported box loss {!r}; expected 'SmoothL1Loss' or "
            "'GIoULoss'".format(self.loss)
        )

    class_logits = cat(class_logits, dim=0)
    box_regression = cat(box_regression, dim=0)
    device = class_logits.device

    # Gather the per-proposal fields and flatten them across the list of
    # proposals along dim 0.
    proposals = self._proposals
    labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0)
    regression_targets = cat(
        [proposal.get_field("regression_targets") for proposal in proposals], dim=0
    )
    label_weights = cat(
        [proposal.get_field("label_weights") for proposal in proposals], dim=0
    )
    target_weights = cat(
        [proposal.get_field("target_weights") for proposal in proposals], dim=0
    )
    pos_matched_idxs = [
        proposal.get_field("pos_matched_idxs") for proposal in proposals
    ]
    rois = torch.cat([proposal.bbox for proposal in proposals], dim=0)

    # Rows with a label > 0 are the positives; only those contribute to the
    # box regression terms below.
    pos_label_inds = torch.nonzero(labels > 0).squeeze(1)
    pos_labels = labels.index_select(0, pos_label_inds)

    # map_inds[i] holds the 4 column indices of box_regression that are read
    # for positive i: always columns 4..7 in the class-agnostic case,
    # otherwise columns 4*label .. 4*label+3.
    if self.cls_agnostic_bbox_reg:
        map_inds = torch.tensor([4, 5, 6, 7], device=device).repeat(
            pos_labels.shape[0], 1
        )
    else:
        map_inds = 4 * pos_labels[:, None] + torch.tensor(
            [0, 1, 2, 3], device=device
        )

    # box_regression is read through a flattened view, so element
    # (row, col) lives at row * box_regression.size(1) + col.
    flat_inds = (
        pos_label_inds[:, None] * box_regression.size(1) + map_inds
    ).view(-1)
    pos_box_pred_delta = (
        box_regression.view(-1).index_select(0, flat_inds).view(map_inds.shape)
    )
    pos_box_target_delta = regression_targets.index_select(0, pos_label_inds)
    pos_rois = rois.index_select(0, pos_label_inds)

    # With GIoU and self.decode set, the stored regression targets are used
    # as-is; in every other configuration they are decoded against the rois.
    if self.loss == "GIoULoss" and self.decode:
        pos_box_target = pos_box_target_delta
    else:
        pos_box_target = self.box_coder.decode(pos_box_target_delta, pos_rois)
    pos_box_pred = self.box_coder.decode(pos_box_pred_delta, pos_rois)

    if self.use_isr_p:
        # isr_p returns replacement labels/targets and their weights.
        bbox_inputs = [
            labels,
            label_weights,
            regression_targets,
            target_weights,
            pos_box_pred,
            pos_box_target,
            pos_label_inds,
            pos_labels,
        ]
        labels, label_weights, regression_targets, target_weights = isr_p(
            class_logits,
            bbox_inputs,
            pos_matched_idxs,
            self.cls_loss,
        )

    # Normalise by the number of samples with a positive label weight,
    # clamped to 1 so the divisor is never zero.
    avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
    classification_loss = self.cls_loss(
        class_logits,
        labels,
        weight=label_weights,
        avg_factor=avg_factor,
    )

    if self.loss == "SmoothL1Loss":
        # NOTE(review): target_weights is passed unfiltered here, while the
        # GIoU branch index-selects it with pos_label_inds -- confirm that
        # smooth_l1_loss expects the full-length weights.
        box_loss = smooth_l1_loss(
            pos_box_pred_delta,
            pos_box_target_delta,
            weight=target_weights,
            size_average=False,
            beta=1,
        )
        box_loss = box_loss / labels.numel()
        if self.carl:
            # NOTE(review): num_class is hard-coded to 80 -- confirm it
            # matches the number of classes in class_logits.
            loss_carl = carl_loss(
                class_logits,
                pos_label_inds,
                pos_labels,
                pos_box_pred_delta,
                pos_box_target_delta,
                smooth_l1_loss,
                k=1,
                bias=0.2,
                avg_factor=regression_targets.size(0),
                num_class=80,
            )
    else:  # "GIoULoss", guaranteed by the validation at the top
        if pos_box_pred.size(0) > 0:
            box_loss = self.giou_loss(
                pos_box_pred,
                pos_box_target,
                weight=target_weights.index_select(0, pos_label_inds),
                avg_factor=labels.numel(),
            )
        else:
            # No positives: produce a zero loss that is still derived from
            # box_regression rather than a detached constant.
            box_loss = box_regression.sum() * 0
        if self.carl:
            # NOTE(review): num_class is hard-coded to 80 -- confirm it
            # matches the number of classes in class_logits.
            loss_carl = carl_loss(
                class_logits,
                pos_label_inds,
                pos_labels,
                pos_box_pred,
                pos_box_target,
                self.giou_loss,
                k=1,
                bias=0.3,
                avg_factor=regression_targets.size(0),
                num_class=80,
            )

    if self.carl:
        return classification_loss, box_loss, loss_carl
    return classification_loss, box_loss