in sam2/modeling/sam2_utils.py [0:0]
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Sample 1 point (along with its label) at the center of the largest error region,
    i.e. the error pixel with the largest distance to the boundary of its error region.

    Adapted from the RITM click simulator
    (https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py),
    but deterministic: the point is the first maximum (row-major order) of the
    distance map, not a random draw.

    For each batch element, the false-negative (FN) region (gt & ~pred) and the
    false-positive (FP) region (~gt & pred) are distance-transformed. If the deepest
    FN pixel is strictly farther from its boundary than the deepest FP pixel, a
    positive click is placed there; otherwise (including ties) a negative click is
    placed at the deepest FP pixel. If the prediction has no error at all, both
    distance maps are zero and the result is a negative click at (x=0, y=0).

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
      (None is treated as an all-False prediction)
    - padding: if True, pad with a zero boundary of 1 px for the distance transform,
      so pixels just outside the image count as background
    Outputs:
    - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
    Both outputs are moved to gt_masks.device.
    """
    # function-local import: cv2 is only needed when this function is called
    import cv2

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    B, _, H_im, W_im = gt_masks.shape
    device = gt_masks.device

    # false positive region: a new point sampled here should have a negative
    # label to correct the FP error
    fp_masks = (~gt_masks & pred_masks).cpu().numpy()
    # false negative region: a new point sampled here should have a positive
    # label to correct the FN error
    fn_masks = (gt_masks & ~pred_masks).cpu().numpy()

    def _distance_to_boundary(mask):
        # L2 distance of every pixel in `mask` to the nearest zero pixel
        # (pixels outside the mask get 0); result has shape (H_im, W_im).
        mask = mask.astype(np.uint8)
        if padding:
            # zero border so the image edge acts as a region boundary
            mask = np.pad(mask, ((1, 1), (1, 1)), "constant")
        dist = cv2.distanceTransform(mask, cv2.DIST_L2, 0)
        if padding:
            dist = dist[1:-1, 1:-1]
        return dist

    points = torch.zeros(B, 1, 2, dtype=torch.float)
    labels = torch.ones(B, 1, dtype=torch.int32)
    for b in range(B):
        fn_dist = _distance_to_boundary(fn_masks[b, 0])
        fp_dist = _distance_to_boundary(fp_masks[b, 0])
        # ties (including the no-error case where both maxima are 0) go to FP,
        # i.e. a negative click
        is_positive = fn_dist.max() > fp_dist.max()
        dist = fn_dist if is_positive else fp_dist
        # np.argmax returns the first maximum in row-major order as a flat index
        y, x = np.unravel_index(np.argmax(dist), (H_im, W_im))
        points[b, 0, 0] = int(x)
        points[b, 0, 1] = int(y)
        labels[b, 0] = int(is_positive)
    points = points.to(device)
    labels = labels.to(device)
    return points, labels