def __call__()

in pytorch/sagemakercv/detection/roi_heads/box_head/loss.py [0:0]


    def __call__(self, class_logits, box_regression):
        """
        Computes the loss for Faster R-CNN.
        This requires that the subsample method has been called beforehand.

        Arguments:
            class_logits (list[Tensor]): classification outputs, concatenated
                along dim 0 so that row ``i`` lines up with the ``i``-th
                proposal in ``self._proposals`` (assumed same ordering —
                confirm against the caller).
            box_regression (list[Tensor]): box-regression outputs,
                concatenated the same way. Class ``c`` occupies columns
                ``4*c .. 4*c+3``; the class-agnostic head uses columns 4..7.

        Returns:
            classification_loss (Tensor)
            box_loss (Tensor)
            loss_carl (Tensor): returned as a third element only when
                ``self.carl`` is true.

        Raises:
            RuntimeError: if ``subsample`` has not been called first.
            ValueError: if ``self.loss`` is neither ``"SmoothL1Loss"`` nor
                ``"GIoULoss"``.
        """

        class_logits = cat(class_logits, dim=0)
        box_regression = cat(box_regression, dim=0)
        device = class_logits.device

        if not hasattr(self, "_proposals"):
            raise RuntimeError("subsample needs to be called before")

        proposals = self._proposals

        # Fields stored on the sampled proposals, flattened across the list
        # so row i matches row i of class_logits / box_regression.
        labels = cat([p.get_field("labels") for p in proposals], dim=0)
        regression_targets = cat(
            [p.get_field("regression_targets") for p in proposals], dim=0
        )
        label_weights = cat(
            [p.get_field("label_weights") for p in proposals], dim=0
        )
        target_weights = cat(
            [p.get_field("target_weights") for p in proposals], dim=0
        )
        # Kept as a per-proposal-list list (not concatenated); only used by
        # ISR-P below.
        pos_matched_idxs = [p.get_field("pos_matched_idxs") for p in proposals]

        rois = torch.cat([p.bbox for p in proposals], dim=0)

        # Rows with label > 0 are positives and take part in box regression.
        pos_label_inds = torch.nonzero(labels > 0).squeeze(1)
        pos_labels = labels.index_select(0, pos_label_inds)

        # Column indices (4 per positive) of the deltas to read from
        # box_regression.
        if self.cls_agnostic_bbox_reg:
            map_inds = torch.tensor([4, 5, 6, 7], device=device).repeat(
                1, pos_labels.shape[0]
            ).view(-1, 4)
        else:
            map_inds = 4 * pos_labels[:, None] + torch.tensor(
                [0, 1, 2, 3], device=device
            )

        # Gather through a flattened view: element (r, c) of box_regression
        # sits at flat index r * box_regression.size(1) + c.
        index_select_indices = (
            pos_label_inds[:, None] * box_regression.size(1) + map_inds
        ).view(-1)
        pos_box_pred_delta = (
            box_regression.view(-1)
            .index_select(0, index_select_indices)
            .view(map_inds.shape[0], map_inds.shape[1])
        )
        pos_box_target_delta = regression_targets.index_select(0, pos_label_inds)
        pos_rois = rois.index_select(0, pos_label_inds)

        # With GIoU and self.decode set, the stored regression targets are used
        # as-is (presumably already absolute boxes — confirm where targets are
        # built); otherwise they are decoded against the positive RoIs.
        if self.loss == "GIoULoss" and self.decode:
            pos_box_target = pos_box_target_delta
        else:
            pos_box_target = self.box_coder.decode(pos_box_target_delta, pos_rois)
        pos_box_pred = self.box_coder.decode(pos_box_pred_delta, pos_rois)

        bbox_inputs = [
            labels, label_weights, regression_targets, target_weights,
            pos_box_pred, pos_box_target, pos_label_inds, pos_labels,
        ]

        if self.use_isr_p:
            # isr_p returns replacement labels / weights / targets that the
            # losses below use instead of the originals.
            labels, label_weights, regression_targets, target_weights = isr_p(
                class_logits,
                bbox_inputs,
                pos_matched_idxs,
                self.cls_loss)

        # Number of proposals with a positive label weight, floored at 1.
        avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
        classification_loss = self.cls_loss(
            class_logits,
            labels,
            weight=label_weights,
            avg_factor=avg_factor,
        )

        # target_weights has one row per proposal; restrict it to the positive
        # rows so it lines up with the positive predictions/targets. Taken
        # after ISR-P because that step may replace target_weights.
        pos_target_weights = target_weights.index_select(0, pos_label_inds)

        # NOTE(review): num_class=80 below is hard-coded (80 foreground
        # classes); confirm it matches the dataset / class_logits width.
        if self.loss == "SmoothL1Loss":
            box_loss = smooth_l1_loss(
                pos_box_pred_delta,
                pos_box_target_delta,
                weight=pos_target_weights,
                size_average=False,
                beta=1,
            )
            box_loss = box_loss / labels.numel()
            if self.carl:
                loss_carl = carl_loss(
                    class_logits,
                    pos_label_inds,
                    pos_labels,
                    pos_box_pred_delta,
                    pos_box_target_delta,
                    smooth_l1_loss,
                    k=1,
                    bias=0.2,
                    avg_factor=regression_targets.size(0),
                    num_class=80)
        elif self.loss == "GIoULoss":
            if pos_box_pred.size(0) > 0:
                box_loss = self.giou_loss(
                    pos_box_pred,
                    pos_box_target,
                    weight=pos_target_weights,
                    avg_factor=labels.numel()
                )
            else:
                # No positives: a zero loss that is still tied to
                # box_regression, so it remains part of the autograd graph.
                box_loss = box_regression.sum() * 0
            if self.carl:
                loss_carl = carl_loss(
                    class_logits,
                    pos_label_inds,
                    pos_labels,
                    pos_box_pred,
                    pos_box_target,
                    self.giou_loss,
                    k=1,
                    bias=0.3,
                    avg_factor=regression_targets.size(0),
                    num_class=80)
        else:
            raise ValueError(
                f"Unsupported box loss {self.loss!r}; "
                "expected 'SmoothL1Loss' or 'GIoULoss'"
            )

        if self.carl:
            return classification_loss, box_loss, loss_carl
        return classification_loss, box_loss