in siammot/modelling/box_head/inference.py [0:0]
def filter_results(self, boxlist, num_classes):
    """Returns bounding-box detection results by thresholding on scores and
    applying non-maximum suppression (NMS).

    Every row/class pair whose score exceeds ``self.score_thresh`` is kept as
    a candidate. For each foreground class, candidates with a negative
    ``"ids"`` value are treated as fresh detections and de-duplicated with
    ``boxlist_nms`` using ``self.nms``. Candidates with a non-negative
    ``"ids"`` value (track boxes) skip NMS and are appended unchanged.

    Args:
        boxlist (BoxList): Per-class box predictions. ``bbox`` must be
            reshapeable to ``(-1, num_classes * 4)`` and is re-wrapped as
            ``"xyxy"`` without conversion, so it is assumed to already be in
            xyxy coordinates. The ``"scores"`` field must be reshapeable to
            ``(-1, num_classes)``. An ``"ids"`` field is required, presumably
            with one entry per row (TODO confirm against caller).
        num_classes (int): Number of classes including the background class
            at index 0, which is skipped.

    Returns:
        BoxList: Kept boxes from all foreground classes concatenated, carrying
        ``"scores"``, ``"ids"`` and ``"labels"`` (the class index) fields.

    Raises:
        AssertionError: If ``boxlist`` has no ``"ids"`` field.
    """
    # Unwrap the boxlist to avoid additional overhead.
    # If we had multi-class NMS, we could perform this directly on the boxlist.
    boxes = boxlist.bbox.reshape(-1, num_classes * 4)
    scores = boxlist.get_field("scores").reshape(-1, num_classes)
    device = scores.device

    assert (boxlist.has_field('ids'))
    ids = boxlist.get_field('ids')

    # Apply the score threshold for all classes at once.
    inds_all = scores > self.score_thresh

    result = []
    # Skip j = 0, because it's the background class.
    for j in range(1, num_classes):
        inds = inds_all[:, j].nonzero().squeeze(1)
        scores_j = scores[inds, j]
        boxes_j = boxes[inds, j * 4: (j + 1) * 4]
        ids_j = ids[inds]

        # Fresh detections (negative id) are de-duplicated with NMS.
        det_idx = ids_j < 0
        class_boxlist = BoxList(boxes_j[det_idx, :], boxlist.size, mode="xyxy")
        class_boxlist.add_field("scores", scores_j[det_idx])
        class_boxlist.add_field("ids", ids_j[det_idx])
        class_boxlist = boxlist_nms(class_boxlist, self.nms)

        # Track boxes (non-negative id) bypass NMS and are kept as-is.
        track_idx = ids_j >= 0
        if track_idx.any():
            track_boxlist = BoxList(boxes_j[track_idx, :], boxlist.size, mode="xyxy")
            track_boxlist.add_field("scores", scores_j[track_idx])
            track_boxlist.add_field("ids", ids_j[track_idx])
            class_boxlist = cat_boxlist([class_boxlist, track_boxlist])

        class_boxlist.add_field(
            "labels",
            torch.full((len(class_boxlist),), j, dtype=torch.int64, device=device),
        )
        result.append(class_boxlist)

    if not result:
        # Only the background class exists, so there is nothing to
        # concatenate; avoid calling cat_boxlist on an empty list.
        # NOTE(review): confirm create_empty_boxlist carries the same fields
        # ("scores", "ids", "labels") that callers expect.
        return self.create_empty_boxlist(device=device)
    return cat_boxlist(result)