def __call__()

in siammot/modelling/track_head/EMM/track_loss.py [0:0]


    def __call__(self, locations, box_cls, box_regression, centerness, src, targets):
        """Compute the classification, box-regression and centerness losses.

        Regression and centerness losses are evaluated only at positive
        locations (those whose label from ``self.prepare_targets`` is > 0).
        The classification loss is evaluated with ``select_cross_entropy_loss``
        on the log-softmax of ``box_cls`` and the full flattened label map.

        Args:
            locations: passed through to ``self.prepare_targets``.
            box_cls: classification predictions; normalised with
                ``log_softmax`` before ``select_cross_entropy_loss``.
            box_regression: regression predictions. They are permuted with
                ``(0, 2, 3, 1)`` and flattened into rows of 4, so presumably
                shaped (N, 4, H, W) -- TODO confirm.
            centerness: centerness predictions, flattened in their native
                layout. Presumably (N, 1, H, W), so that the flattened order
                matches the permuted ``box_regression`` -- TODO confirm.
            src: passed through to ``self.prepare_targets``.
            targets: passed through to ``self.prepare_targets``.

        Returns:
            Tuple ``(cls_loss, reg_loss, centerness_loss)``, each multiplied
            by ``self.loss_weight``.
        """
        cls_labels, reg_targets = self.prepare_targets(locations, src, targets)

        # Move the 4 regression channels last so each row is one location's
        # box; reshape (unlike view) also copes with the permuted,
        # non-contiguous layout.
        box_regression_flatten = box_regression.permute(0, 2, 3, 1).reshape(-1, 4)
        reg_targets_flatten = reg_targets.reshape(-1, 4)
        cls_labels_flatten = cls_labels.reshape(-1)
        centerness_flatten = centerness.reshape(-1)

        # Keep only positive locations for the regression / centerness terms.
        in_box_inds = torch.nonzero(cls_labels_flatten > 0).squeeze(1)
        box_regression_flatten = box_regression_flatten[in_box_inds]
        reg_targets_flatten = reg_targets_flatten[in_box_inds]
        centerness_flatten = centerness_flatten[in_box_inds]

        box_cls = log_softmax(box_cls)
        cls_loss = select_cross_entropy_loss(box_cls, cls_labels_flatten)

        if in_box_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(reg_targets_flatten)
            # Centerness targets are also handed to the regression loss as
            # its third argument (presumably per-sample weights -- confirm).
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten,
                reg_targets_flatten,
                centerness_targets
            )
            centerness_loss = self.centerness_loss_func(
                centerness_flatten,
                centerness_targets
            )
        else:
            # No positives: produce zero losses that are still tensors derived
            # from the (empty) predictions, so they stay attached to the
            # autograd graph instead of being plain Python floats.
            reg_loss = 0. * box_regression_flatten.sum()
            centerness_loss = 0. * centerness_flatten.sum()

        return self.loss_weight * cls_loss, self.loss_weight * reg_loss, self.loss_weight * centerness_loss