in pytorch/sagemakercv/detection/rpn/inference.py [0:0]
def forward(self, anchors, objectness, box_regression, image_shapes_cat, targets=None):
"""
Arguments:
anchors: list[list[BoxList]]
objectness: list[tensor]
box_regression: list[tensor]
Returns:
boxlists (list[BoxList]): the post-processed anchors, after
applying box decoding and NMS
"""
device = anchors[0].device  # anchors[0] is the batched anchor tensor
(N, A, H_max, W_max, num_fmaps, num_anchors_per_level, fmap_size_cat,
 num_max_proposals, num_max_props_tensor) = self.get_constant_tensors(anchors, objectness)
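# H_max/W_max are the spatial dims of the largest feature map, so every
# level can be packed into one padded (num_fmaps, N, ...) tensor below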
# initialize batched objectness, regression tensors and then form them
batched_objectness_tensor = -1e6 * torch.ones(
    [num_fmaps, N, A * H_max * W_max],
    dtype=objectness[0].dtype, device=objectness[0].device)
batched_regression_tensor = -1 * torch.ones(
    [num_fmaps, N, 4 * A * H_max * W_max],
    dtype=objectness[0].dtype, device=objectness[0].device)
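# copy each level's flattened maps into its padded row; slots beyond
# A * H * W keep the -1e6 / -1 fill values, so padded entries can never
# win the objectness top-k below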
for i in range(num_fmaps):
H, W = objectness[i].shape[2], objectness[i].shape[3]
batched_objectness_tensor[i,:,:(A * H * W)] = objectness[i].reshape(N, -1)
batched_regression_tensor[i,:,:(4 * A * H * W)] = box_regression[i].reshape(N, -1)
batched_objectness_tensor = batched_objectness_tensor.reshape(num_fmaps * N, -1)
batched_objectness_tensor = batched_objectness_tensor.sigmoid()
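# top-k is taken independently per (fmap, image) row of the flattened
# (num_fmaps * N, A * H_max * W_max) tensor; e.g. with 5 FPN levels and
# N=2 this yields 10 rows of pre_nms_top_n candidates each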
batched_objectness_topk, topk_idx = batched_objectness_tensor.topk(self.pre_nms_top_n, dim=1, sorted=True)
batched_anchor_tensor, image_shapes = anchors[0], anchors[2]
# generate proposals using a batched kernel
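# the fused CUDA kernel decodes the selected top-k anchors into xyxy
# proposals, clips them to each image and drops boxes below min_size
# (semantics inferred from the arguments; see the csrc extension for the
# authoritative behavior)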
proposals_gen, objectness_gen, keep_gen = _C.GeneratePreNMSUprightBoxesBatched(
N,
A,
H_max*W_max,
A*H_max*W_max,
fmap_size_cat,
num_anchors_per_level,
topk_idx,
batched_objectness_topk.float(), # Need to cast these as kernel doesn't support fp16
batched_regression_tensor.float(),
batched_anchor_tensor,
image_shapes_cat,
self.pre_nms_top_n,
self.min_size,
self.box_coder.bbox_xform_clip,
True)
# keep is padded with 0s for (image, fmap) pairs where num_proposals < self.pre_nms_top_n
keep_gen = keep_gen.reshape(N * num_fmaps, self.pre_nms_top_n)
proposals_gen = proposals_gen.reshape(N * num_fmaps * self.pre_nms_top_n, 4)
# perform batched NMS kernel
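# one NMS run per (image, fmap) slice of pre_nms_top_n boxes;
# num_max_props_tensor supplies the per-slice valid counts so padded
# entries are ignored (assumption, based on how keep_gen is padded above)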
keep_nms_batched = _C.nms_batched(proposals_gen, num_max_proposals, num_max_props_tensor, keep_gen, self.nms_thresh).bool()
keep = keep_nms_batched.reshape(num_fmaps, N, self.pre_nms_top_n)
# switch leading two dimensions from (fmap, image) to (image, fmap)
proposals_gen = proposals_gen.reshape(num_fmaps, N, self.pre_nms_top_n, 4)
objectness_gen = objectness_gen.reshape(num_fmaps, N, self.pre_nms_top_n)
keep = keep.permute(1, 0, 2)
objectness_gen = objectness_gen.permute(1, 0, 2)
proposals_gen = proposals_gen.permute(1, 0, 2, 3)
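# shapes from here on: keep and objectness_gen are
# (N, num_fmaps, pre_nms_top_n); proposals_gen is (N, num_fmaps, pre_nms_top_n, 4)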
if not self.training:
# split batched results back into boxlists
keep = keep.split(1)
objectness_gen = objectness_gen.split(1)
proposals_gen = proposals_gen.split(1)
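# keep[i] is a boolean mask, so indexing with it drops both the padded
# slots and the NMS-suppressed boxes for image i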
boxlists = []
for i in range(N):
boxlist = BoxList(proposals_gen[i][keep[i]], image_shapes[i], mode="xyxy")
boxlist.add_field("objectness", objectness_gen[i][keep[i]])
boxlists.append(boxlist)
if num_fmaps > 1:
boxlists = self.select_over_all_levels(boxlists)
return boxlists
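# training path: results stay as dense padded tensors (no BoxLists),
# keeping the downstream ops batched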
if self.per_image_search:  # TODO: aren't per-image and per-batch search the same when N == 1?
# Post NMS per image search
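# suppressed entries are overwritten with -1 so they always lose the
# per-image top-k below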
objectness_gen.masked_fill_(~keep, -1)
proposals_gen.masked_fill_((~keep).unsqueeze(3), -1)
proposals_gen = proposals_gen.reshape(N, -1, 4)
objectness_gen = objectness_gen.reshape(N, -1)
_, inds_post_nms_top_n = torch.topk(objectness_gen, self.fpn_post_nms_top_n, dim=1, sorted=False)
inds_post_nms_top_n, _ = inds_post_nms_top_n.sort()
objectness = torch.gather(objectness_gen, dim=1, index=inds_post_nms_top_n)
batch_inds = torch.arange(N, device=device)[:, None]
proposals = proposals_gen[batch_inds, inds_post_nms_top_n]
else:
# Post NMS per batch search
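# zero out suppressed scores, then take a single top-k over all images
# and levels at once; the number of surviving boxes per image can differ,
# hence the pad_sequence re-batching below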
objectness_gen = objectness_gen * keep.float()
objectness_gen = objectness_gen.reshape(-1)
num_keeps = int((keep.flatten() != 0).sum())  # plain Python int so it can serve as k in topk
_, inds_post_nms_top_n = torch.topk(objectness_gen, min(self.fpn_post_nms_top_n, num_keeps), dim=0, sorted=False)
inds_post_nms_top_n, _ = inds_post_nms_top_n.sort()
objectness_kept = objectness_gen[inds_post_nms_top_n]
proposals_kept = proposals_gen.reshape(-1, 4)[inds_post_nms_top_n]
inds_mask = torch.zeros_like(objectness_gen, dtype=torch.uint8)
inds_mask[inds_post_nms_top_n] = 1
inds_mask_per_image = inds_mask.reshape(N, -1)
num_kept_per_image = inds_mask_per_image.sum(dim=1).tolist()  # plain ints, as torch.split expects
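# pad_sequence (from torch.nn.utils.rnn) re-batches the ragged per-image
# results, padding shorter images with -1 so invalid slots are easy to
# mask out later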
if N > 1:
proposals = pad_sequence(proposals_kept.split(num_kept_per_image, dim=0), batch_first=True, padding_value=-1)
objectness = pad_sequence(objectness_kept.split(num_kept_per_image, dim=0), batch_first=True, padding_value=-1)
else:
proposals = proposals_kept.unsqueeze(0)
objectness = objectness_kept.unsqueeze(0)
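# ground-truth boxes are appended to the proposals so the detection head
# always sees positive examples during training (the usual add-gt trick
# in maskrcnn-benchmark-style RPNs)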
# make a batched tensor for the targets as well
target_bboxes = [box.bbox for box in targets]
# objectness is later used as a mask to filter out invalid boxes, i.e. those with score -1
target_objectness = [torch.ones(len(gt_box), device=targets[0].bbox.device) for gt_box in targets]
if N > 1:
target_bboxes = pad_sequence(target_bboxes, batch_first=True, padding_value=-1)
target_objectness = pad_sequence(target_objectness, batch_first=True, padding_value=-1)
else:
target_bboxes = target_bboxes[0].unsqueeze(0)
target_objectness = target_objectness[0].unsqueeze(0)
proposals = torch.cat([proposals, target_bboxes], dim=1)
objectness = torch.cat([objectness, target_objectness], dim=1)
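# proposals/objectness now include the padded gt boxes; invalid slots
# carry -1 and are filtered downstream via the objectness mask (per the
# comment above)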
return [proposals, objectness, image_shapes]