def forward()

in pytorch/sagemakercv/detection/rpn/inference.py


    def forward(self, anchors, objectness, box_regression, image_shapes_cat, targets=None):
        """
        Arguments:
            anchors: list[list[BoxList]]
            objectness: list[tensor]
            box_regression: list[tensor]
        Returns:
            boxlists (list[BoxList]): the post-processed anchors, after
                applying box decoding and NMS
        """

        device = anchors[0].device  # anchors[0] is the batched anchor tensor

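        # get_constant_tensors gathers the batch size, anchors per cell, the
        # largest feature-map extent, per-level anchor counts, and the NMS
        # bookkeeping tensors used below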
        (N, A, H_max, W_max, num_fmaps, num_anchors_per_level, fmap_size_cat,
            num_max_proposals, num_max_props_tensor) = self.get_constant_tensors(anchors, objectness)

        # initialize batched objectness, regression tensors and then form them
        batched_objectness_tensor = -1e6 * torch.ones(
            [num_fmaps, N, A * H_max * W_max],
            dtype=objectness[0].dtype, device=objectness[0].device)
        batched_regression_tensor = -1 * torch.ones(
            [num_fmaps, N, 4 * A * H_max * W_max],
            dtype=objectness[0].dtype, device=objectness[0].device)
        for i in range(num_fmaps):
            H, W = objectness[i].shape[2], objectness[i].shape[3]
            batched_objectness_tensor[i,:,:(A * H * W)] = objectness[i].reshape(N, -1)
            batched_regression_tensor[i,:,:(4 * A * H * W)] = box_regression[i].reshape(N, -1)
          
        batched_objectness_tensor = batched_objectness_tensor.reshape(num_fmaps * N, -1)
        batched_objectness_tensor = batched_objectness_tensor.sigmoid()
        batched_objectness_topk, topk_idx = batched_objectness_tensor.topk(self.pre_nms_top_n, dim=1, sorted=True)
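        # padded slots hold -1e6 (sigmoid ~0); a level with fewer than
        # pre_nms_top_n real anchors can still surface padding here, which is
        # why the generation kernel below also emits a keep mask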

        batched_anchor_tensor, image_shapes = anchors[0], anchors[2]

        # generate proposals using a batched kernel
        proposals_gen, objectness_gen, keep_gen = _C.GeneratePreNMSUprightBoxesBatched(
                                N,
                                A,
                                H_max*W_max,
                                A*H_max*W_max,
                                fmap_size_cat,
                                num_anchors_per_level,
                                topk_idx,
                                batched_objectness_topk.float(),    # Need to cast these as kernel doesn't support fp16
                                batched_regression_tensor.float(),
                                batched_anchor_tensor,
                                image_shapes_cat,
                                self.pre_nms_top_n,
                                self.min_size,
                                self.box_coder.bbox_xform_clip,
                                True)
    
        # keep is padded with zeros for (image, fmap) pairs where num_proposals < self.pre_nms_top_n
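        # the kernel also returns the decoded proposals and their scores, one
        # pre_nms_top_n slice per (feature map, image) pair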

        keep_gen = keep_gen.reshape(N * num_fmaps, self.pre_nms_top_n)
        proposals_gen = proposals_gen.reshape(N * num_fmaps * self.pre_nms_top_n, 4)

        # perform batched NMS kernel
        keep_nms_batched = _C.nms_batched(proposals_gen, num_max_proposals, num_max_props_tensor, keep_gen, self.nms_thresh).bool()
        keep = keep_nms_batched.reshape(num_fmaps, N, self.pre_nms_top_n)
        proposals_gen = proposals_gen.reshape(num_fmaps, N, self.pre_nms_top_n, 4)
        objectness_gen = objectness_gen.reshape(num_fmaps, N, self.pre_nms_top_n)
        # switch leading two dimensions from (fmap, image) to (image, fmap)
        keep = keep.permute(1, 0, 2)
        objectness_gen = objectness_gen.permute(1, 0, 2)
        proposals_gen = proposals_gen.permute(1, 0, 2, 3)
        if not self.training:
            # split batched results back into boxlists
            keep = keep.split(1)
            objectness_gen = objectness_gen.split(1)
            proposals_gen = proposals_gen.split(1)
            boxlists = []
            for i in range(N):
                boxlist = BoxList(proposals_gen[i][keep[i]], image_shapes[i], mode="xyxy")
                boxlist.add_field("objectness", objectness_gen[i][keep[i]])
                boxlists.append(boxlist)
            if num_fmaps > 1:
                boxlists = self.select_over_all_levels(boxlists)
            return boxlists

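        # training path: results stay batched as tensors; ground-truth boxes
        # are appended to the proposals before returning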
        if self.per_image_search:  # TODO: per-image and per-batch search should be equivalent when N == 1
            # Post NMS per image search
            objectness_gen.masked_fill_(~keep, -1)
            proposals_gen.masked_fill_((~keep).unsqueeze(3), -1)
            proposals_gen = proposals_gen.reshape(N, -1, 4)
            objectness_gen = objectness_gen.reshape(N, -1)
            objectness = objectness_gen

            _, inds_post_nms_top_n = torch.topk(objectness, self.fpn_post_nms_top_n, dim=1, sorted=False)
            inds_post_nms_top_n, _ = inds_post_nms_top_n.sort()
            objectness = torch.gather(objectness_gen, dim=1, index=inds_post_nms_top_n)
            batch_inds = torch.arange(N, device=device)[:,None]
            proposals = proposals_gen[batch_inds, inds_post_nms_top_n]
        else:
            # Post NMS per batch search
            objectness_gen = objectness_gen * keep.float()
            objectness_gen = objectness_gen.reshape(-1)
            objectness_kept = objectness_gen
            num_keeps = int((keep.flatten() != 0).sum())  # topk expects a Python int for k

            _, inds_post_nms_top_n = torch.topk(objectness_kept, min(self.fpn_post_nms_top_n, num_keeps), dim=0, sorted=False)
            inds_post_nms_top_n, _ = inds_post_nms_top_n.sort()
            objectness_kept = objectness_gen[inds_post_nms_top_n]
            proposals_kept = proposals_gen.reshape(-1, 4)[inds_post_nms_top_n]

            inds_mask = torch.zeros_like(objectness_gen, dtype=torch.bool)
            inds_mask[inds_post_nms_top_n] = 1
            inds_mask_per_image = inds_mask.reshape(N, -1)

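            # count survivors per image so the ragged per-image results can be
            # re-batched with -1 padding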
            num_kept_per_image = inds_mask_per_image.sum(dim=1).tolist()
            if N > 1:
                proposals = pad_sequence(proposals_kept.split(num_kept_per_image, dim=0), batch_first=True, padding_value=-1)
                objectness = pad_sequence(objectness_kept.split(num_kept_per_image, dim=0), batch_first=True, padding_value=-1)
            else:
                proposals = proposals_kept.unsqueeze(0)
                objectness = objectness_kept.unsqueeze(0)

        # batch the targets the same way as the proposals
        target_bboxes = [box.bbox for box in targets]
        # objectness doubles as a validity mask downstream: padded slots carry
        # score -1, while ground-truth boxes are given score 1
        target_objectness = [torch.ones(len(gt_box), device=targets[0].bbox.device) for gt_box in targets]
        if N > 1:
            target_bboxes = pad_sequence(target_bboxes, batch_first=True, padding_value=-1)
            target_objectness = pad_sequence(target_objectness, batch_first=True, padding_value=-1)
        else:
            target_bboxes = target_bboxes[0].unsqueeze(0)
            target_objectness = target_objectness[0].unsqueeze(0)
        proposals = torch.cat([proposals, target_bboxes], dim=1)
        objectness = torch.cat([objectness, target_objectness], dim=1)
        return [proposals, objectness, image_shapes]
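
A minimal, self-contained sketch of the two batching idioms used above: padding every level's logits to a common width with -1e6 before a single top-k, and re-batching ragged per-image results with -1 padding via pad_sequence. All shapes and values below are illustrative assumptions, not taken from this module.

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # illustrative sizes (assumptions)
    N, A, pre_nms_top_n = 2, 3, 4
    level_hw = [(8, 8), (4, 4)]              # (H, W) of each feature map
    num_fmaps = len(level_hw)
    H_max, W_max = level_hw[0]

    # pad each level's logits to the widest level with -1e6, as forward() does;
    # padded slots sigmoid to ~0 and lose the per-level top-k to real anchors
    objectness = [torch.randn(N, A, H, W) for H, W in level_hw]
    batched = -1e6 * torch.ones(num_fmaps, N, A * H_max * W_max)
    for i, (H, W) in enumerate(level_hw):
        batched[i, :, :A * H * W] = objectness[i].reshape(N, -1)
    scores, topk_idx = batched.reshape(num_fmaps * N, -1).sigmoid().topk(pre_nms_top_n, dim=1)

    # re-batch ragged per-image results with -1 padding, as the per-batch
    # search branch does before appending ground-truth boxes
    kept = [torch.randn(3, 4), torch.randn(5, 4)]   # unequal proposal counts
    proposals = pad_sequence(kept, batch_first=True, padding_value=-1)
    print(scores.shape, proposals.shape)            # torch.Size([4, 4]) torch.Size([2, 5, 4])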