def forward()

in siammot/modelling/box_head/inference.py [0:0]


    def forward(self, x, boxes):
        """
        Arguments:
            x (tuple[tensor, tensor]): x contains the class logits
                and the box_regression from the model.
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra fields labels and scores
        """
        class_logits, box_regression = x
        class_prob = F.softmax(class_logits, -1)
        device = class_logits.device

        # TODO think about a representation of batch of boxes
        image_shapes = [box.size for box in boxes]
        boxes_per_image = [len(box) for box in boxes]
        concat_boxes = torch.cat([a.bbox for a in boxes], dim=0)

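        # with class-agnostic regression only the last 4 regression values per box
        # are used; decode them once, then tile the decoded boxes across classes so
        # the per-class filtering below still works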
        if self.cls_agnostic_bbox_reg:
            box_regression = box_regression[:, -4:]
        proposals = self.box_coder.decode(
            box_regression.view(sum(boxes_per_image), -1), concat_boxes
        )
        if self.cls_agnostic_bbox_reg:
            proposals = proposals.repeat(1, class_prob.shape[1])

        num_classes = class_prob.shape[1]

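        # split the flattened proposals and class scores back into per-image chunks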
        proposals = proposals.split(boxes_per_image, dim=0)
        class_prob = class_prob.split(boxes_per_image, dim=0)

        results = [self.create_empty_boxlist(device) for _ in boxes]

        for i, (prob, boxes_per_img, image_shape, _box) in enumerate(zip(
                class_prob, proposals, image_shapes, boxes)):

            # get ids for each bbox
            if _box.has_field('ids'):
                ids = _box.get_field('ids')
            else:
                # default id is -1, i.e. the box is not yet associated with a track
                ids = torch.full((len(_box),), -1, dtype=torch.int64, device=device)

            # the 'labels' field is only present when track boxes are among the inputs
            if _box.has_field('labels'):
                labels = _box.get_field('labels')

                # indices of boxes that already carry a track id
                track_inds = torch.squeeze(torch.nonzero(ids >= 0))

                # boost track boxes so they cannot be suppressed during NMS:
                # zero all of their class scores, then push the score of the
                # track's own class above 1.0
                if track_inds.numel() > 0:
                    prob_cp = prob.clone()
                    prob[track_inds, :] = 0.
                    prob[track_inds, labels] = prob_cp[track_inds, labels] + 1.

            boxlist = self.prepare_boxlist(boxes_per_img, prob, image_shape, ids)
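            # amodal inference keeps boxes that extend beyond the image border,
            # so the boxes are only clipped to the image when it is disabled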
            if not self.amodal_inference:
                boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = self.filter_results(boxlist, num_classes)

            results[i] = boxlist
        return results
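
The score boosting for track boxes is the least obvious step, so here is a minimal, self-contained sketch (not part of the repository; the tensors are made-up toy data) of the same trick: every box that already carries a track id has its class scores zeroed and the score of its own class pushed above 1.0, so class-wise NMS cannot discard a track box in favour of an overlapping fresh detection.

    import torch

    # toy inputs: 4 boxes, 3 classes (values are illustrative only)
    ids = torch.tensor([-1, 3, -1, 7])   # -1 = plain detection, >= 0 = track id
    labels = torch.tensor([0, 1, 0, 2])  # per-box class labels
    prob = torch.rand(4, 3)              # softmax class scores

    track_inds = torch.squeeze(torch.nonzero(ids >= 0))
    if track_inds.numel() > 0:
        prob_cp = prob.clone()
        prob[track_inds, :] = 0.
        # select the labels of the track boxes and lift their scores above 1.0
        track_labels = labels[track_inds]
        prob[track_inds, track_labels] = prob_cp[track_inds, track_labels] + 1.

    # rows 1 and 3 now hold a single score > 1.0 at their track's class,
    # so per-class NMS ranks them ahead of any overlapping detection
    print(prob)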