def process_prediction()

in seamseg/utils/coco_ap.py [0:0]


def process_prediction(bbx_pred, cls_pred, obj_pred, msk_pred, img_size, idx, original_size):
    """Convert the predictions for one image into a list of per-detection result dicts.

    Bounding boxes are rescaled from the network input resolution (``img_size``) to the
    original image resolution (``original_size``) and clamped to the image extent. When masks
    are given they are pasted into a full ``img_size`` canvas, binarized at 0.5, resized to
    ``original_size`` with nearest-neighbour interpolation and passed through ``mask_encode``.

    The input tensors are never modified, regardless of the device they live on.

    Parameters
    ----------
    bbx_pred : torch.Tensor
        ``N x 4`` boxes. Entries 0 and 2 are coordinates along the first image axis, entries
        1 and 3 along the second one (``img_size[0]`` / ``img_size[1]`` respectively).
    cls_pred : torch.Tensor
        ``N`` class ids, emitted as ``category_id``.
    obj_pred : torch.Tensor
        ``N`` scores, emitted as ``score``.
    msk_pred : torch.Tensor, PackedSequence or None
        Either a tensor of per-box mask logits (a sigmoid is applied here), or a
        ``PackedSequence`` of per-detection mask patches which are thresholded directly and
        pasted at the position given by the first two box coordinates, or ``None`` when no
        masks are available.
    img_size : sequence of int
        Size of the image the predictions refer to, indexed like ``original_size``.
    idx
        Image identifier, emitted unchanged as ``image_id``.
    original_size : sequence of int
        Size the results are rescaled to. It is reversed before being handed to
        ``PIL.Image.resize``, which expects ``(width, height)`` -- so presumably
        ``(height, width)``.

    Returns
    -------
    list of dict
        One dict per detection with keys ``image_id``, ``category_id``, ``score`` and ``bbox``
        (``[box[1], box[0], size along axis 1, size along axis 0]``), plus ``segmentation``
        (output of ``mask_encode`` with ``counts`` decoded to ``str``) when masks are given.
    """
    # Move everything to CPU
    bbx_pred, cls_pred, obj_pred = (t.cpu() for t in (bbx_pred, cls_pred, obj_pred))
    msk_pred = msk_pred.cpu() if msk_pred is not None else None

    # NOTE: the mask handling below must see the boxes in `img_size` coordinates, i.e. it has to
    # run before the boxes are rescaled to `original_size`.
    if msk_pred is not None:
        if isinstance(msk_pred, torch.Tensor):
            # ROI-style prediction
            bbx_inv = invert_roi_bbx(bbx_pred, list(msk_pred.shape[-2:]), list(img_size))
            bbx_idx = torch.arange(0, msk_pred.size(0), dtype=torch.long)
            msk_pred = roi_sampling(msk_pred.unsqueeze(1).sigmoid(), bbx_inv, bbx_idx, list(img_size), padding="zero")
            msk_pred = msk_pred.squeeze(1) > 0.5
        elif isinstance(msk_pred, PackedSequence):
            # Seeds-style prediction
            msk_pred.data = msk_pred.data > 0.5
            msk_pred_exp = msk_pred.data.new_zeros(len(msk_pred), img_size[0], img_size[1])

            # Paste each patch into the canvas with its top-left corner at the box origin
            for it, (msk_pred_i, bbx_pred_i) in enumerate(zip(msk_pred, bbx_pred)):
                i, j = int(bbx_pred_i[0].item()), int(bbx_pred_i[1].item())
                msk_pred_exp[it, i:i + msk_pred_i.size(0), j:j + msk_pred_i.size(1)] = msk_pred_i

            msk_pred = msk_pred_exp

    # Convert bbx and redo clamping.
    # `Tensor.cpu()` returns the very same object when the tensor already lives on the CPU, so
    # the rescaled coordinates are written to a copy to avoid clobbering the caller's tensor.
    bbx_scaled = bbx_pred.clone()
    bbx_scaled[:, [0, 2]] = (bbx_pred[:, [0, 2]] / img_size[0] * original_size[0]).clamp(min=0, max=original_size[0])
    bbx_scaled[:, [1, 3]] = (bbx_pred[:, [1, 3]] / img_size[1] * original_size[1]).clamp(min=0, max=original_size[1])
    bbx_scaled_size = bbx_scaled[:, 2:] - bbx_scaled[:, :2]

    outs = []
    for i, (bbx_pred_i, bbx_pred_size_i, cls_pred_i, obj_pred_i) in \
            enumerate(zip(bbx_scaled, bbx_scaled_size, cls_pred, obj_pred)):
        out = dict(image_id=idx, category_id=int(cls_pred_i.item()), score=float(obj_pred_i.item()))

        # Swap the axis order: second-axis coordinate first, then first-axis coordinate
        out["bbox"] = [
            float(bbx_pred_i[1].item()),
            float(bbx_pred_i[0].item()),
            float(bbx_pred_size_i[1].item()),
            float(bbx_pred_size_i[0].item()),
        ]

        # Expand and convert mask if present
        if msk_pred is not None:
            # The comparison against 0.5 yields a bool tensor on recent PyTorch versions; cast to
            # uint8 so that PIL builds an 8-bit image and the encoder receives a byte mask
            # (presumably required by `mask_encode` -- confirm against its implementation).
            msk_i = msk_pred[i].numpy().astype(np.uint8, copy=False)
            segmentation = Image.fromarray(msk_i).resize(original_size[::-1], Image.NEAREST)

            out["segmentation"] = mask_encode(np.asfortranarray(np.array(segmentation)))
            out["segmentation"]["counts"] = str(out["segmentation"]["counts"], "utf-8")

        outs.append(out)

    return outs