in siammot/modelling/box_head/inference.py [0:0]
def forward(self, x, boxes):
    """
    Turn box-head outputs into per-image post-processed detections.

    Arguments:
        x (tuple[tensor, tensor]): ``(class_logits, box_regression)`` from
            the box predictor, one row per reference box across the whole
            batch (rows are split back per image using ``len(box)``).
        boxes (list[BoxList]): bounding boxes that are used as reference,
            one for each image. Optional fields read here:
            ``ids`` (per-box track id; boxes with ``ids >= 0`` are treated
            as tracks, missing field means every id is -1) and ``labels``
            (presumably the class index of each box; used as a column index
            into the class probabilities).

    Returns:
        results (list[BoxList]): one BoxList per image, produced by
            ``self.prepare_boxlist`` followed by ``self.filter_results``.
    """
    class_logits, box_regression = x
    class_prob = F.softmax(class_logits, -1)
    num_classes = class_prob.shape[1]
    device = class_logits.device

    # TODO think about a representation of batch of boxes
    image_shapes = [box.size for box in boxes]
    boxes_per_image = [len(box) for box in boxes]
    concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)

    if self.cls_agnostic_bbox_reg:
        # Class-agnostic regression: only the last 4 outputs are decoded.
        box_regression = box_regression[:, -4:]
    proposals = self.box_coder.decode(
        box_regression.view(sum(boxes_per_image), -1), concat_boxes
    )
    if self.cls_agnostic_bbox_reg:
        # Replicate the single decoded box once per class so downstream code
        # sees the same per-class column layout as class-specific regression.
        proposals = proposals.repeat(1, num_classes)

    proposals = proposals.split(boxes_per_image, dim=0)
    class_prob = class_prob.split(boxes_per_image, dim=0)

    results = []
    for prob, img_proposals, image_shape, ref_box in zip(
            class_prob, proposals, image_shapes, boxes):
        # Per-box track ids; boxes without an ``ids`` field default to -1
        # (i.e. "not a track").
        if ref_box.has_field('ids'):
            ids = ref_box.get_field('ids')
        else:
            ids = torch.full((len(ref_box),), -1, dtype=torch.int64,
                             device=device)

        # Only track boxes are expected to carry a ``labels`` field.
        if ref_box.has_field('labels'):
            labels = ref_box.get_field('labels')
            # view(-1) keeps a 1-D index even when there is exactly one
            # track (squeeze would yield a 0-d tensor there).
            track_inds = torch.nonzero(ids >= 0).view(-1)
            if track_inds.numel() > 0:
                # ``labels`` has one entry per box, so select only the track
                # rows' labels; using the full vector misaligns (or silently
                # broadcasts over) rows when tracks and non-tracks are mixed.
                track_labels = labels[track_inds]
                # Advanced indexing returns a copy, so read before zeroing.
                boosted = prob[track_inds, track_labels] + 1.
                # Avoid track bbs being suppressed during nms: keep only the
                # track's own class and add 1 to its probability.
                prob[track_inds, :] = 0.
                prob[track_inds, track_labels] = boosted

        boxlist = self.prepare_boxlist(img_proposals, prob, image_shape, ids)
        if not self.amodal_inference:
            boxlist = boxlist.clip_to_image(remove_empty=False)
        results.append(self.filter_results(boxlist, num_classes))
    return results