def filter_results()

in siammot/modelling/box_head/inference.py [0:0]


    def filter_results(self, boxlist, num_classes):
        """Return detection results after score thresholding and per-class NMS.

        For every foreground class ``j`` (class 0 is the background and is
        skipped), candidates whose score exceeds ``self.score_thresh`` are
        split by their ``ids`` field:

        * entries with ``ids < 0`` are suppressed with ``boxlist_nms`` using
          ``self.nms`` as the threshold;
        * entries with ``ids >= 0`` (track boxes) are appended as-is, without
          NMS.

        Both groups are concatenated and tagged with an int64 ``"labels"``
        field equal to ``j``.

        Args:
            boxlist: BoxList whose ``bbox`` can be reshaped to
                ``(-1, num_classes * 4)``. It must carry a ``"scores"`` field
                reshapeable to ``(-1, num_classes)`` and an ``"ids"`` field
                with one entry per row.
            num_classes: Number of classes, including the background at
                index 0.

        Returns:
            A single BoxList (mode ``"xyxy"``) with the surviving boxes of all
            foreground classes, carrying ``"scores"``, ``"ids"`` and
            ``"labels"`` fields. If there is no foreground class
            (``num_classes <= 1``), an empty boxlist from
            ``self.create_empty_boxlist`` is returned instead.

        Raises:
            AssertionError: if ``boxlist`` has no ``"ids"`` field.
        """
        # Unwrap the boxlist to avoid additional overhead.
        # If we had multi-class NMS, we could perform this directly on the boxlist.
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)
        device = scores.device
        image_size = boxlist.size

        assert boxlist.has_field('ids'), "filter_results requires an 'ids' field"
        ids = boxlist.get_field('ids')

        # Apply the score threshold once for all classes.
        above_thresh = scores > self.score_thresh

        result = []
        # Skip j = 0, because it's the background class.
        for j in range(1, num_classes):
            inds = above_thresh[:, j].nonzero().squeeze(1)
            scores_j = scores[inds, j]
            boxes_j = boxes[inds, j * 4: (j + 1) * 4]
            ids_j = ids[inds]

            # Entries with negative ids are suppressed with NMS.
            det_idx = ids_j < 0
            det_boxlist = BoxList(boxes_j[det_idx, :], image_size, mode="xyxy")
            det_boxlist.add_field("scores", scores_j[det_idx])
            det_boxlist.add_field("ids", ids_j[det_idx])
            det_boxlist = boxlist_nms(det_boxlist, self.nms)

            # Track boxes (non-negative ids) bypass NMS and are appended as-is.
            track_idx = ids_j >= 0
            if track_idx.any():
                track_boxlist = BoxList(boxes_j[track_idx, :], image_size, mode="xyxy")
                track_boxlist.add_field("scores", scores_j[track_idx])
                track_boxlist.add_field("ids", ids_j[track_idx])
                det_boxlist = cat_boxlist([det_boxlist, track_boxlist])

            num_labels = len(det_boxlist)
            det_boxlist.add_field(
                "labels", torch.full((num_labels,), j, dtype=torch.int64, device=device)
            )
            result.append(det_boxlist)

        # With no foreground class there is nothing to concatenate; avoid
        # handing an empty list to cat_boxlist.
        if not result:
            return self.create_empty_boxlist(device=device)

        return cat_boxlist(result)