def panoptic_stats()

in seamseg/utils/panoptic.py [0:0]


def panoptic_stats(msk_gt, cat_gt, panoptic_pred, num_classes, _num_stuff):
    """Accumulate per-class IoU / TP / FP / FN counts for one panoptic prediction.

    Ground-truth and predicted segments are matched when they share the same
    category and their IoU (computed ignoring void pixels) exceeds 0.5.

    Parameters
    ----------
    msk_gt : torch.Tensor
        Integer map assigning each pixel a ground-truth segment index; the
        indices address entries of ``cat_gt``. Index 0 is treated as void.
    cat_gt : torch.Tensor
        Category id of every ground-truth segment (entry 0 is the void
        segment). Assumes non-void ids lie in ``[0, num_classes)`` since they
        are used to index the output tensors -- TODO confirm against caller.
    panoptic_pred : tuple
        ``(msk_pred, cat_pred, _, iscrowd_pred)``: predicted segment map,
        per-segment category ids, an unused third element and per-segment
        crowd flags (bool, or integer 0/1).
    num_classes : int
        Length of each returned tensor.
    _num_stuff : int
        Unused; kept for interface compatibility.

    Returns
    -------
    tuple of torch.Tensor
        ``(iou, tp, fp, fn)``, each a double tensor with ``num_classes``
        entries: summed IoU of the matched pairs, and the number of true
        positives, false positives and false negatives per class.
    """
    # Move gt to CPU
    msk_gt, cat_gt = msk_gt.cpu(), cat_gt.cpu()
    msk_pred, cat_pred, _, iscrowd_pred = panoptic_pred

    # Convert crowd predictions to void.
    # NOTE: `== 0` is used rather than `~` because on integer (e.g. uint8)
    # tensors `~` is a *bitwise* NOT in PyTorch >= 1.2 (0 -> 255, 1 -> 254),
    # which would keep every segment and yield a wrong segment count.
    keep = iscrowd_pred == 0
    msk_remap = msk_pred.new_zeros(cat_pred.numel())
    msk_remap[keep] = torch.arange(
        0, keep.long().sum().item(), dtype=msk_remap.dtype, device=msk_remap.device)
    msk_pred = msk_remap[msk_pred]
    cat_pred = cat_pred[keep]

    iou = msk_pred.new_zeros(num_classes, dtype=torch.double)
    tp = msk_pred.new_zeros(num_classes, dtype=torch.double)
    fp = msk_pred.new_zeros(num_classes, dtype=torch.double)
    fn = msk_pred.new_zeros(num_classes, dtype=torch.double)

    if cat_gt.numel() > 1:
        msk_gt = msk_gt.view(-1)
        msk_pred = msk_pred.view(-1)

        # Compute confusion matrix: confmat[i, j] = #pixels in gt segment i and pred segment j
        confmat = msk_pred.new_zeros(cat_gt.numel(), cat_pred.numel(), dtype=torch.double)
        confmat.view(-1).index_add_(0, msk_gt * cat_pred.numel() + msk_pred,
                                    confmat.new_ones(msk_gt.numel()))

        # track potentially valid FP, i.e. those that overlap with void_gt <= 0.5
        # (empty predicted segments give 0/0 = NaN, which compares False here)
        num_pred_pixels = confmat.sum(0)
        valid_fp = (confmat[0] / num_pred_pixels) <= 0.5

        # compute IoU without counting void pixels (both in gt and pred)
        _iou = confmat / ((num_pred_pixels - confmat[0]).unsqueeze(0) + confmat.sum(1).unsqueeze(1) - confmat)

        # flag TP matches, i.e. same class and iou > 0.5
        matches = ((cat_gt.unsqueeze(1) == cat_pred.unsqueeze(0)) & (_iou > 0.5))

        # remove potential match of void_gt against void_pred
        matches[0, 0] = False

        # IoU > 0.5 allows at most one match per non-void segment, so the row-major
        # order of `_iou[matches]` lines up with the order of `cat_gt[tp_i]` below.
        _iou = _iou[matches]
        tp_i = matches.any(1)
        fn_i = ~tp_i
        fn_i[0] = False  # remove potential fn match due to void against void
        fp_i = ~matches.any(0) & valid_fp
        fp_i[0] = False  # remove potential fp match due to void against void

        # Compute per instance classes for each tp, fp, fn
        tp_cat = cat_gt[tp_i]
        fn_cat = cat_gt[fn_i]
        fp_cat = cat_pred[fp_i]

        # Accumulate per class counts
        if tp_cat.numel() > 0:
            tp.index_add_(0, tp_cat, tp.new_ones(tp_cat.numel()))
            iou.index_add_(0, tp_cat, _iou)
        if fp_cat.numel() > 0:
            fp.index_add_(0, fp_cat, fp.new_ones(fp_cat.numel()))
        if fn_cat.numel() > 0:
            fn.index_add_(0, fn_cat, fn.new_ones(fn_cat.numel()))

    # note else branch is not needed because if cat_gt has only void we don't penalize predictions
    return iou, tp, fp, fn