in pytorch/sagemakercv/core/balanced_positive_negative_sampler.py [0:0]
def __call__(self, matched_idxs, is_rpn=0, objectness=None):
"""
Arguments:
matched idxs: list of tensors containing -1, 0 or positive values.
Each tensor corresponds to a specific image.
-1 values are ignored, 0 are considered as negatives and > 0 as
positives.
Returns:
pos_idx (list[tensor])
neg_idx (list[tensor])
Returns two lists of binary masks for each image.
The first list contains the positive elements that were selected,
and the second list the negative example.
"""
num_images = len(matched_idxs)
if num_images == 1:
pos_idx = []
neg_idx = []
matched_idxs = [matched_idxs.view(-1)]
# there is only one iteration of this loop, but it is kept for completeness
for matched_idxs_per_image in matched_idxs:
if objectness is not None:
objectness = objectness.view(-1)
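# entries with an objectness of -1 are treated as invalid (e.g. padding)
# and excluded from both the positive and the negative pool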
positive = torch.nonzero((matched_idxs_per_image >= 1) & (objectness > -1)).squeeze(1)
negative = torch.nonzero((matched_idxs_per_image == 0) & (objectness > -1)).squeeze(1)
else:
positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
num_pos = int(self.batch_size_per_image * self.positive_fraction)
# protect against not enough positive examples
num_pos = min(positive.numel(), num_pos)
num_neg = self.batch_size_per_image - num_pos
# protect against not enough negative examples
num_neg = min(negative.numel(), num_neg)
# randomly select positive and negative examples
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
pos_idx_per_image = positive.index_select(0, perm1)
neg_idx_per_image = negative.index_select(0, perm2)
# create binary mask from indices
pos_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.bool
)
neg_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.bool
)
pos_idx_per_image_mask.index_fill_(0, pos_idx_per_image, 1)
neg_idx_per_image_mask.index_fill_(0, neg_idx_per_image, 1)
pos_idx.append(pos_idx_per_image_mask)
neg_idx.append(neg_idx_per_image_mask)
return pos_idx, neg_idx
## this implements a batched random subsampling using a tensor of random numbers and sorting
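## the idea: for each image, draw a uniform random number per candidate,
## argsort each row, and keep the candidates whose position falls below that
## image's quota; rows are padded to a common width and padded slots are
## filled with 2 (larger than any rand() value) so they always sort last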
if is_rpn:
num_anchors_per_im = matched_idxs[0].size(0)
num_images = len(matched_idxs)
matched_idxs_cat = matched_idxs
device = matched_idxs_cat.device
## generate a mask for positive samples
pos_samples_mask = matched_idxs_cat >= 1
num_pos_samples = pos_samples_mask.sum(dim=1)
num_pos_samples_cum = num_pos_samples.cumsum(dim=0)
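## cumulative counts give, for each image, the offset of its candidates
## inside the flattened (batch-wide) list of positive positions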
max_pos_samples = torch.max(num_pos_samples)
## we are generating a tensor of consecutive numbers for each row.
consec = torch.arange(max_pos_samples, device=device).repeat(num_images, 1)
mask_to_hide = consec >= num_pos_samples.view(num_images,1)
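## columns beyond an image's actual number of positives are padding in the
## rectangular [num_images, max_pos_samples] layout and must be ignored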
## generate a tensor of random numbers, then fill the invalid (padded) elements with 2 so
## they end up at the end when sorted
rand_nums_batched = torch.rand([num_images, max_pos_samples], device=device)
rand_nums_batched.masked_fill_(mask_to_hide, 2)
rand_perm = rand_nums_batched.argsort(dim=1)
max_pos_allowed = int(self.batch_size_per_image * self.positive_fraction)
num_pos_subsamples = num_pos_samples.clamp(max=max_pos_allowed)
subsampling_mask = rand_perm < num_pos_subsamples.view(num_images,1)
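## rand_perm holds argsort indices; comparing them against the per-image
## quota marks, in each row, a uniformly random subset of the valid
## positives of size min(num_pos_samples, max_pos_allowed)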
if num_images > 1:
consec[1:, :] = consec[1:, :] + num_pos_samples_cum[:-1].view(num_images - 1, 1)
sampling_inds = consec.masked_select(subsampling_mask)
pos_samples_inds = pos_samples_mask.view(-1).nonzero().squeeze(1)
pos_subsampled_inds = pos_samples_inds[sampling_inds]
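## sampling_inds index into the flattened list of positive candidates
## across the whole batch, so pos_subsampled_inds are flat anchor indices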
## do the same for negative samples as well
neg_samples_mask = matched_idxs_cat == 0
num_neg_samples = neg_samples_mask.sum(dim=1)
num_neg_samples_cum = num_neg_samples.cumsum(dim=0)
max_neg_samples = torch.max(num_neg_samples)
consec = torch.arange(max_neg_samples, device=device).repeat(num_images, 1)
mask_to_hide = consec >= num_neg_samples.view(num_images,1)
rand_nums_batched = torch.rand([num_images, max_neg_samples], device=device)
rand_nums_batched.masked_fill_(mask_to_hide, 2)
rand_perm = rand_nums_batched.argsort(dim=1)
num_subsamples = torch.min(num_neg_samples, self.batch_size_per_image - num_pos_subsamples)
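## each image's negative quota is whatever remains of batch_size_per_image
## after its positive quota, capped by the number of available negatives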
subsampling_mask = rand_perm < num_subsamples.view(num_images,1)
if num_images > 1:
consec[1:, :] = consec[1:, :] + num_neg_samples_cum[:-1].view(num_images - 1, 1)
sampling_inds = consec.masked_select(subsampling_mask)
neg_samples_inds = neg_samples_mask.view(-1).nonzero().squeeze(1)
neg_subsampled_inds = neg_samples_inds[sampling_inds]
return pos_subsampled_inds, neg_subsampled_inds
else:
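## non-RPN (ROI head) path: same batched sorting trick, but candidates are
## additionally filtered by the objectness validity mask when provided, and
## the per-image positive/negative sample counts are returned as well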
matched_idxs_cat = matched_idxs
device = matched_idxs_cat.device
if objectness is not None:
pos_samples_mask = (matched_idxs_cat >= 1) & (objectness > -1)
else:
pos_samples_mask = matched_idxs_cat >= 1
num_pos_samples = pos_samples_mask.sum(dim=1)
num_pos_samples_cum = num_pos_samples.cumsum(dim=0)
max_pos_samples = torch.max(num_pos_samples)
consec = torch.arange(max_pos_samples, device=device).repeat(num_images, 1)
mask_to_hide = consec >= num_pos_samples.view(num_images,1)
rand_nums_batched = torch.rand([num_images, max_pos_samples], device=device)
rand_nums_batched.masked_fill_(mask_to_hide, 2)
rand_perm = rand_nums_batched.argsort(dim=1)
max_pos_allowed = int(self.batch_size_per_image * self.positive_fraction)
num_pos_subsamples = num_pos_samples.clamp(max=max_pos_allowed)
subsampling_mask = rand_perm < num_pos_subsamples.view(num_images,1)
if num_images > 1:
consec[1:, :] = consec[1:, :] + num_pos_samples_cum[:-1].view(num_images - 1, 1)
sampling_inds = consec.masked_select(subsampling_mask)
pos_samples_inds = pos_samples_mask.view(-1).nonzero().squeeze(1)
pos_subsampled_inds = pos_samples_inds[sampling_inds]
if objectness is not None:
neg_samples_mask = (matched_idxs_cat == 0) & (objectness > -1)
else:
neg_samples_mask = matched_idxs_cat == 0
num_neg_samples = neg_samples_mask.sum(dim=1)
num_neg_samples_cum = num_neg_samples.cumsum(dim=0)
max_neg_samples = torch.max(num_neg_samples)
consec = torch.arange(max_neg_samples, device=device).repeat(num_images, 1)
mask_to_hide = consec >= num_neg_samples.view(num_images,1)
rand_nums_batched = torch.rand([num_images, max_neg_samples], device=device)
rand_nums_batched.masked_fill_(mask_to_hide, 2)
rand_perm = rand_nums_batched.argsort(dim=1)
num_subsamples = torch.min(num_neg_samples, self.batch_size_per_image - num_pos_subsamples)
subsampling_mask = rand_perm < num_subsamples.view(num_images,1)
if num_images > 1:
consec[1:, :] = consec[1:, :] + num_neg_samples_cum[:-1].view(num_images - 1, 1)
sampling_inds = consec.masked_select(subsampling_mask)
neg_samples_inds = neg_samples_mask.view(-1).nonzero().squeeze(1)
neg_subsampled_inds = neg_samples_inds[sampling_inds]
return pos_subsampled_inds, neg_subsampled_inds, num_pos_subsamples, num_subsamples
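# A minimal usage sketch (assumed call pattern; the sampler class name and
# constructor arguments below are inferred from this file, not shown here):
#
#   sampler = BalancedPositiveNegativeSampler(batch_size_per_image=256,
#                                             positive_fraction=0.5)
#   # matched_idxs: [num_images, num_anchors] tensor of -1 / 0 / positive labels
#   pos_inds, neg_inds = sampler(matched_idxs, is_rpn=1)
#
# The returned flat indices can then be concatenated (e.g. with torch.cat) to
# select the anchors that contribute to the classification loss.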