in seamseg/utils/coco_ap.py [0:0]
def process_prediction(bbx_pred, cls_pred, obj_pred, msk_pred, img_size, idx, original_size):
    """Convert the predictions for one image into COCO-style result dicts.

    Boxes (and masks, when given) are expressed at the network input size
    ``img_size``; they are rescaled here to ``original_size`` and emitted in the
    layout used by COCO result files. The input tensors are not modified.

    Args:
        bbx_pred: 2-D tensor with one row per detection and columns
            ``(i0, j0, i1, j1)``: columns 0/2 run along the first image
            dimension (scaled by ``img_size[0]``), columns 1/3 along the second.
        cls_pred: Per-detection class tensor; ``int`` of each element is stored
            as ``category_id`` unchanged (no category remapping happens here).
        obj_pred: Per-detection score tensor; stored as ``score``.
        msk_pred: ``None``; a ``torch.Tensor`` of per-ROI mask logits
            (ROI-style, a sigmoid is applied below); or a ``PackedSequence`` of
            per-detection masks that are thresholded directly (seeds-style).
        img_size: ``(dim0, dim1)`` size of the network input image.
        idx: Value stored as ``image_id`` in every output dict.
        original_size: ``(height, width)`` of the original image.

    Returns:
        A list with one dict per detection holding ``image_id``, ``category_id``,
        ``score``, ``bbox`` as ``[x, y, w, h]`` and, when masks are given, an
        RLE ``segmentation`` whose ``counts`` entry is a ``str``.
    """
    # Move everything to CPU
    bbx_pred, cls_pred, obj_pred = (t.cpu() for t in (bbx_pred, cls_pred, obj_pred))
    # ``.cpu()`` returns the very same tensor when it already lives on the CPU, and
    # the box rescaling below writes in place: copy so the caller's tensor survives
    bbx_pred = bbx_pred.clone()
    msk_pred = msk_pred.cpu() if msk_pred is not None else None

    if msk_pred is not None:
        if isinstance(msk_pred, torch.Tensor):
            # ROI-style prediction: resample the per-ROI mask probabilities onto an
            # img_size canvas (presumably what invert_roi_bbx + roi_sampling do),
            # then threshold at 0.5
            bbx_inv = invert_roi_bbx(bbx_pred, list(msk_pred.shape[-2:]), list(img_size))
            bbx_idx = torch.arange(0, msk_pred.size(0), dtype=torch.long)
            msk_pred = roi_sampling(msk_pred.unsqueeze(1).sigmoid(), bbx_inv, bbx_idx, list(img_size), padding="zero")
            msk_pred = msk_pred.squeeze(1) > 0.5
        elif isinstance(msk_pred, PackedSequence):
            # Seeds-style prediction: threshold each mask and paste it into a zero
            # canvas of size img_size with its top-left corner at the box's (i0, j0).
            # NOTE(review): relies on PackedSequence exposing an assignable ``.data``
            # and iterating per detection -- confirm against its definition
            msk_pred.data = msk_pred.data > 0.5
            msk_pred_exp = msk_pred.data.new_zeros(len(msk_pred), img_size[0], img_size[1])

            for it, (msk_pred_i, bbx_pred_i) in enumerate(zip(msk_pred, bbx_pred)):
                i, j = int(bbx_pred_i[0].item()), int(bbx_pred_i[1].item())
                msk_pred_exp[it, i:i + msk_pred_i.size(0), j:j + msk_pred_i.size(1)] = msk_pred_i

            msk_pred = msk_pred_exp

    # Convert boxes to original-image coordinates and redo clamping
    bbx_pred[:, [0, 2]] = (bbx_pred[:, [0, 2]] / img_size[0] * original_size[0]).clamp(min=0, max=original_size[0])
    bbx_pred[:, [1, 3]] = (bbx_pred[:, [1, 3]] / img_size[1] * original_size[1]).clamp(min=0, max=original_size[1])
    bbx_pred_size = bbx_pred[:, 2:] - bbx_pred[:, :2]

    outs = []
    for i, (bbx_pred_i, bbx_pred_size_i, cls_pred_i, obj_pred_i) in \
            enumerate(zip(bbx_pred, bbx_pred_size, cls_pred, obj_pred)):
        out = dict(image_id=idx, category_id=int(cls_pred_i.item()), score=float(obj_pred_i.item()))

        # COCO boxes are [x, y, width, height]: second image dimension first
        out["bbox"] = [
            float(bbx_pred_i[1].item()),
            float(bbx_pred_i[0].item()),
            float(bbx_pred_size_i[1].item()),
            float(bbx_pred_size_i[0].item()),
        ]

        # Expand and convert mask if present
        if msk_pred is not None:
            # Thresholded masks are bool on recent PyTorch; the RLE encoder only
            # accepts uint8, so cast to a 0/1 uint8 array before going through PIL
            mask_i = msk_pred[i].numpy().astype(np.uint8)
            # PIL sizes are (width, height), hence the reversal of original_size
            segmentation = Image.fromarray(mask_i).resize(original_size[::-1], Image.NEAREST)

            out["segmentation"] = mask_encode(np.asfortranarray(np.array(segmentation, dtype=np.uint8)))
            # RLE counts come back as bytes; decode so the dict is JSON-serializable
            out["segmentation"]["counts"] = str(out["segmentation"]["counts"], "utf-8")

        outs.append(out)

    return outs