def __call__()

in evaluation/tiny_benchmark/maskrcnn_benchmark/modeling/rpn/locnet/loss.py [0:0]


    def __call__(self, locations, box_cls, box_regression, centerness, targets):
        """
        Compute the classification, box-regression and (optional) centerness losses.

        Arguments:
            locations (list[BoxList]): forwarded unchanged to ``self.prepare_targets``.
            box_cls (list[Tensor]): per-level class predictions; each tensor is
                permuted with (0, 2, 3, 1) and flattened to (-1, num_classes),
                i.e. it is treated as (N, num_classes, H, W).
            box_regression (list[Tensor]): per-level box predictions, flattened
                to (-1, 4) the same way.
            centerness (list[Tensor] or None): per-level centerness predictions;
                when ``None`` no centerness loss is produced.
            targets (list[BoxList]): forwarded unchanged to ``self.prepare_targets``.

        Returns:
            dict[str, Tensor]: always contains ``"loss_reg"``; contains
            ``"loss_cls"`` (or ``"loss_cls0"``, ``"loss_cls1"``, ... when
            ``self.cls_loss_func`` returns a list), each scaled by
            ``self.cls_loss_weight``; contains ``"loss_centerness"`` only when
            ``centerness`` is not ``None``.
        """
        N = box_cls[0].size(0)
        num_classes = box_cls[0].size(1)
        labels, reg_targets = self.prepare_targets(locations, targets)

        if self.debug_vis_labels:
            show_label_map(labels, box_cls)

        # Flatten every level to (num_locations, C) and concatenate the levels.
        # Iteration is bounded by `labels`, exactly like the per-level targets.
        box_cls_flatten = []
        box_regression_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for level_labels, level_reg_targets, level_cls, level_reg in zip(
                labels, reg_targets, box_cls, box_regression):
            box_cls_flatten.append(level_cls.permute(0, 2, 3, 1).reshape(-1, num_classes))
            box_regression_flatten.append(level_reg.permute(0, 2, 3, 1).reshape(-1, 4))
            # labels carry one column per class, not a single class index
            labels_flatten.append(level_labels.reshape(-1, num_classes))
            reg_targets_flatten.append(level_reg_targets.reshape(-1, 4))
        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        # class loss
        # A location is positive when any of its per-class labels is > 0.
        label_flatten_max = labels_flatten.max(dim=1)[0]
        pos_inds = torch.nonzero(label_flatten_max > 0).squeeze(1)
        num_pos = pos_inds.numel()
        cls_losses = self.cls_loss_func(
            box_cls_flatten,
            labels_flatten
        )

        # Optional normalisation; N is added to avoid dividing by zero.
        if self.cls_divide_pos_num:
            normalizer = num_pos + N
        elif self.cls_divide_pos_sum:
            normalizer = labels_flatten.sum() + N
        else:
            normalizer = None

        multi_cls_loss = isinstance(cls_losses, list)
        if normalizer is not None:
            if multi_cls_loss:
                for i in range(len(cls_losses)):
                    cls_losses[i] /= normalizer
            else:
                cls_losses /= normalizer

        # reg loss
        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]

        # Centerness targets are needed as regression weights and/or as the
        # target of the centerness branch. Compute them once whenever either
        # consumer exists, so the centerness loss below never reads an
        # unbound name when `centerness_weight_reg` is disabled.
        centerness_targets = None
        if num_pos > 0 and (self.centerness_weight_reg or centerness is not None):
            centerness_targets = self.prepare_targets.compute_centerness_targets(
                reg_targets_flatten)

        if num_pos > 0:
            if self.centerness_weight_reg:
                reg_weights = centerness_targets
            else:
                reg_weights = label_flatten_max[pos_inds]
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten,
                reg_targets_flatten,
                reg_weights
            )
        else:
            # No positives: sum of an empty selection, i.e. a zero that is
            # still derived from the predictions.
            reg_loss = box_regression_flatten.sum()

        if multi_cls_loss:
            losses = {"loss_cls{}".format(i): cls_loss * self.cls_loss_weight
                      for i, cls_loss in enumerate(cls_losses)}
            losses['loss_reg'] = reg_loss
        else:
            losses = {
                "loss_cls": cls_losses * self.cls_loss_weight,
                "loss_reg": reg_loss
            }

        # centerness loss
        if centerness is not None:
            centerness_flatten = torch.cat(
                [level_centerness.reshape(-1) for level_centerness in centerness], dim=0)
            centerness_flatten = centerness_flatten[pos_inds]

            if num_pos > 0:
                centerness_loss = self.centerness_loss_func(
                    centerness_flatten,
                    centerness_targets
                )
            else:
                centerness_loss = centerness_flatten.sum()

            losses["loss_centerness"] = centerness_loss
        return losses