in seamseg/algos/detection.py [0:0]
def __call__(self, boxes, scores):
    """Select the final detections of every image through class-wise NMS

    Each image is handled on its own and, within an image, each class is handled on its own: candidates
    whose score does not exceed `self.score_threshold` are discarded, then boxes whose coordinate 2 is not
    greater than coordinate 0 or whose coordinate 3 is not greater than coordinate 1 are discarded, and what
    is left goes through `nms` with `self.nms_threshold`. The survivors of all classes are concatenated and,
    when there are more than `self.max_predictions` of them, only the highest-scoring ones are retained.

    Parameters
    ----------
    boxes : sequence of torch.Tensor
        Sequence of N tensors of class-specific bounding boxes with shapes M_i x C x 4, entries can be None
    scores : sequence of torch.Tensor
        Sequence of N tensors of class probabilities with shapes M_i x (C + 1), entries can be None

    Returns
    -------
    bbx_pred : PackedSequence
        A sequence of N tensors of bounding boxes with shapes S_i x 4, entries are None for images in which
        no detection can be kept according to the selection parameters
    cls_pred : PackedSequence
        A sequence of N tensors of thing class predictions with shapes S_i, entries are None for images in
        which no detection can be kept according to the selection parameters
    obj_pred : PackedSequence
        A sequence of N tensors of detection confidences with shapes S_i, entries are None for images in
        which no detection can be kept according to the selection parameters
    """
    out_bbx, out_cls, out_obj = [], [], []

    for img_bbx, img_obj in zip(boxes, scores):
        try:
            if img_bbx is None or img_obj is None:
                raise Empty

            # Split along the class dimension. The leading score column has no matching box slice
            # (presumably the background / void class -- confirm against the caller) and is skipped, so
            # `cls_id` below indexes the box slices directly.
            bbx_by_class = torch.unbind(img_bbx, dim=1)
            obj_by_class = torch.unbind(img_obj, dim=1)[1:]

            kept_bbx, kept_cls, kept_obj = [], [], []
            for cls_id, (cand_bbx, cand_obj) in enumerate(zip(bbx_by_class, obj_by_class)):
                # Stage 1: discard low-scoring candidates
                confident = cand_obj > self.score_threshold
                if not confident.any().item():
                    continue
                cand_bbx, cand_obj = cand_bbx[confident], cand_obj[confident]

                # Stage 2: discard degenerate (empty) boxes
                non_empty = (cand_bbx[:, 2] > cand_bbx[:, 0]) & (cand_bbx[:, 3] > cand_bbx[:, 1])
                if not non_empty.any().item():
                    continue
                cand_bbx, cand_obj = cand_bbx[non_empty], cand_obj[non_empty]

                # Stage 3: non-maximum suppression within this class
                survivors = nms(cand_bbx.contiguous(), cand_obj.contiguous(), threshold=self.nms_threshold, n_max=-1)
                if survivors.numel() == 0:
                    continue
                cand_bbx, cand_obj = cand_bbx[survivors], cand_obj[survivors]

                # Record what is left, labelling every box with the current class id
                kept_bbx.append(cand_bbx)
                kept_cls.append(cand_bbx.new_full((cand_bbx.size(0),), cls_id, dtype=torch.long))
                kept_obj.append(cand_obj)

            # No class produced any detection for this image
            if not kept_bbx:
                raise Empty

            # Merge the per-class results into single tensors
            sel_bbx = torch.cat(kept_bbx, dim=0)
            sel_cls = torch.cat(kept_cls, dim=0)
            sel_obj = torch.cat(kept_obj, dim=0)

            # Cap the number of detections, keeping the most confident ones
            if sel_bbx.size(0) > self.max_predictions:
                top = sel_obj.topk(self.max_predictions)[1]
                sel_bbx, sel_cls, sel_obj = sel_bbx[top], sel_cls[top], sel_obj[top]
        except Empty:
            # Keep the outputs aligned with the inputs: one (None) entry per image
            out_bbx.append(None)
            out_cls.append(None)
            out_obj.append(None)
        else:
            out_bbx.append(sel_bbx)
            out_cls.append(sel_cls)
            out_obj.append(sel_obj)

    return PackedSequence(out_bbx), PackedSequence(out_cls), PackedSequence(out_obj)