in seamseg/algos/rpn.py [0:0]
def __call__(self, anchors, bbx, iscrowd, valid_size):
    """Match anchors to ground truth boxes

    Parameters
    ----------
    anchors : torch.Tensor
        Tensor of anchor bounding boxes with shape M x 4
    bbx : sequence of torch.Tensor
        Sequence of N tensors of ground truth bounding boxes with shapes M_i x 4, entries can be None
    iscrowd : sequence of torch.Tensor
        Sequence of N tensors of ground truth crowd regions (shapes H_i x W_i), or ground truth crowd
        bounding boxes (shapes K_i x 4), entries can be None
    valid_size : list of tuple of int
        List of N valid image sizes in input coordinates

    Returns
    -------
    match : torch.Tensor
        Tensor of matching results with shape N x M, with the following semantics:
            - match[i, j] == -2: the j-th anchor in image i is void
            - match[i, j] == -1: the j-th anchor in image i is negative
            - match[i, j] == k, k >= 0: the j-th anchor in image i is matched to bbx[i][k]
    """
    match = []
    for bbx_i, iscrowd_i, valid_size_i in zip(bbx, iscrowd, valid_size):
        # Default labels: everything is void
        match_i = anchors.new_full((anchors.size(0),), -2, dtype=torch.long)

        try:
            # Find anchors that are entirely within the original image area
            valid = self._is_inside(anchors, valid_size_i)

            # Check overlap with crowd
            if self.void_threshold != 0 and iscrowd_i is not None:
                if iscrowd_i.dtype == torch.uint8:
                    overlap = mask_overlap(anchors, iscrowd_i)
                else:
                    overlap = bbx_overlap(anchors, iscrowd_i)
                overlap, _ = overlap.max(dim=1)
                valid = valid & (overlap < self.void_threshold)

            if not valid.any().item():
                raise Empty
            valid_anchors = anchors[valid]
            if bbx_i is not None:
                max_a2g_iou = bbx_i.new_zeros(valid_anchors.size(0))
                max_a2g_idx = bbx_i.new_full((valid_anchors.size(0),), -1, dtype=torch.long)
                max_g2a_iou = []
                max_g2a_idx = []
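                # Running maxima accumulated chunk by chunk: best IoU and GT index per
                # anchor (anchor -> GT), and best IoU and anchor index per GT (GT -> anchor)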
                # Calculate assignments iteratively to save memory
                for j, bbx_i_j in enumerate(torch.split(bbx_i, CHUNK_SIZE, dim=0)):
                    iou = ious(valid_anchors, bbx_i_j)

                    # Anchor -> GT
                    iou_max, iou_idx = iou.max(dim=1)
                    replace_idx = iou_max > max_a2g_iou
                    max_a2g_idx[replace_idx] = iou_idx[replace_idx] + j * CHUNK_SIZE
                    max_a2g_iou[replace_idx] = iou_max[replace_idx]

                    # GT -> Anchor
                    max_g2a_iou_j, max_g2a_idx_j = iou.transpose(0, 1).max(dim=1)
                    max_g2a_iou.append(max_g2a_iou_j)
                    max_g2a_idx.append(max_g2a_idx_j)

                    del iou

                max_g2a_iou = torch.cat(max_g2a_iou, dim=0)
                max_g2a_idx = torch.cat(max_g2a_idx, dim=0)
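                # Threshold the anchor -> GT IoUs: anchors at or above pos_threshold become
                # positives, anchors below neg_threshold become negatives, anything in
                # between stays void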
                a2g_pos = max_a2g_iou >= self.pos_threshold
                a2g_neg = max_a2g_iou < self.neg_threshold
                g2a_pos = max_g2a_iou > 0

                valid_match = valid_anchors.new_full((valid_anchors.size(0),), -2, dtype=torch.long)
                valid_match[a2g_pos] = max_a2g_idx[a2g_pos]
                valid_match[a2g_neg] = -1
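                # Force-match each ground truth box to its best-overlapping anchor, so that
                # every GT box with non-zero overlap gets at least one positive anchor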
                valid_match[max_g2a_idx[g2a_pos]] = g2a_pos.nonzero().squeeze()
            else:
                # No ground truth boxes for this image: everything that is not void is negative
                valid_match = valid_anchors.new_full((valid_anchors.size(0),), -1, dtype=torch.long)

            # Subsample positives and negatives
            self._subsample(valid_match)
            match_i[valid] = valid_match
        except Empty:
            pass

        match.append(match_i)

    return torch.stack(match, dim=0)
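
# Usage sketch (illustrative only, not part of the original file). It assumes `matcher`
# is an instance of the class that owns the __call__ above, already configured with
# pos_threshold, neg_threshold and void_threshold; names and shapes are illustrative:
#
#   anchors = ...                                       # M x 4 anchor boxes, input coordinates
#   bbx = [gt_boxes_0, None]                            # per-image GT boxes, entries may be None
#   iscrowd = [crowd_mask_0, None]                      # crowd masks (uint8) or crowd boxes, or None
#   valid_size = [(512, 512), (512, 384)]               # valid area of each image
#
#   match = matcher(anchors, bbx, iscrowd, valid_size)  # N x M matching tensor
#   positive = match[0] >= 0                            # anchors matched to a GT box in image 0
#   negative = match[0] == -1                           # background anchors
#   void = match[0] == -2                               # anchors ignored during training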