def __call__()

in seamseg/algos/detection.py [0:0]


    def __call__(self, boxes, scores):
        """Select the final detections of each image via per-class thresholding and NMS

        Every thing class of an image is processed on its own. First, candidates whose score is not
        strictly above `self.score_threshold` are dropped. Next, boxes whose coordinate 2 does not
        exceed coordinate 0, or whose coordinate 3 does not exceed coordinate 1, are dropped. The
        remaining boxes go through NMS with `self.nms_threshold`. The survivors of all classes are
        concatenated. If more than `self.max_predictions` remain, only the highest-scoring ones are
        kept.

        Parameters
        ----------
        boxes : sequence of torch.Tensor
            Sequence of N tensors of class-specific bounding boxes with shapes M_i x C x 4, entries can be None
        scores : sequence of torch.Tensor
            Sequence of N tensors of class probabilities with shapes M_i x (C + 1), entries can be None. The
            first column is skipped, so column c + 1 scores the boxes of class c (presumably column 0 holds
            the background probability)

        Returns
        -------
        bbx_pred : PackedSequence
            A sequence of N tensors of bounding boxes with shapes S_i x 4, entries are None for images in which no
            detection can be kept according to the selection parameters
        cls_pred : PackedSequence
            A sequence of N tensors of thing class predictions with shapes S_i, entries are None for images in which no
            detection can be kept according to the selection parameters
        obj_pred : PackedSequence
            A sequence of N tensors of detection confidences with shapes S_i, entries are None for images in which no
            detection can be kept according to the selection parameters
        """

        def _filter_class(cls_bbx, cls_obj):
            # Boxes / scores of one class that survive every stage, or None as soon as a stage empties them
            keep = cls_obj > self.score_threshold
            if not keep.any().item():
                return None
            cls_bbx, cls_obj = cls_bbx[keep], cls_obj[keep]

            # Discard degenerate boxes (assumes an x0, y0, x1, y1 layout -- TODO confirm)
            keep = (cls_bbx[:, 2] > cls_bbx[:, 0]) & (cls_bbx[:, 3] > cls_bbx[:, 1])
            if not keep.any().item():
                return None
            cls_bbx, cls_obj = cls_bbx[keep], cls_obj[keep]

            # n_max=-1 is passed through to nms unchanged; presumably it means "no cap on kept boxes"
            keep = nms(cls_bbx.contiguous(), cls_obj.contiguous(), threshold=self.nms_threshold, n_max=-1)
            if keep.numel() == 0:
                return None
            return cls_bbx[keep], cls_obj[keep]

        def _select_image(img_bbx, img_obj):
            # (boxes, class ids, scores) kept for one image; raises Empty when nothing can be kept
            if img_bbx is None or img_obj is None:
                raise Empty

            kept_bbx, kept_cls, kept_obj = [], [], []
            class_pairs = zip(torch.unbind(img_bbx, dim=1), torch.unbind(img_obj, dim=1)[1:])
            for class_id, (cls_bbx, cls_obj) in enumerate(class_pairs):
                survivors = _filter_class(cls_bbx, cls_obj)
                if survivors is None:
                    continue
                cls_bbx, cls_obj = survivors
                kept_bbx.append(cls_bbx)
                kept_cls.append(cls_bbx.new_full((cls_bbx.size(0),), class_id, dtype=torch.long))
                kept_obj.append(cls_obj)

            if not kept_bbx:
                raise Empty
            out_bbx = torch.cat(kept_bbx, dim=0)
            out_cls = torch.cat(kept_cls, dim=0)
            out_obj = torch.cat(kept_obj, dim=0)

            # Cap the number of detections, keeping the most confident ones
            if out_bbx.size(0) > self.max_predictions:
                _, top = out_obj.topk(self.max_predictions)
                out_bbx, out_cls, out_obj = out_bbx[top], out_cls[top], out_obj[top]

            return out_bbx, out_cls, out_obj

        bbx_pred, cls_pred, obj_pred = [], [], []
        for bbx_i, obj_i in zip(boxes, scores):
            try:
                selected = _select_image(bbx_i, obj_i)
            except Empty:
                # Images with missing inputs or no surviving detection get None in every output
                selected = (None, None, None)
            for collected, item in zip((bbx_pred, cls_pred, obj_pred), selected):
                collected.append(item)

        return PackedSequence(bbx_pred), PackedSequence(cls_pred), PackedSequence(obj_pred)