def __call__()

in seamseg/algos/rpn.py [0:0]


    def __call__(self, anchors, bbx, iscrowd, valid_size):
        """Match anchors to ground truth boxes

        For each image the anchors are first restricted to those accepted by ``self._is_inside`` for that image's
        valid size and, when ``self.void_threshold != 0`` and crowd information is given, to those whose crowd overlap
        is below ``self.void_threshold``. The remaining anchors are labeled by IoU against the image's ground truth:

          - best IoU ``>= self.pos_threshold``: matched to the box giving that IoU;
          - best IoU ``< self.neg_threshold``: negative;
          - every ground truth box whose best IoU is non-zero is then force-matched to its best anchor, overriding
            whatever label that anchor had (if several boxes share the same best anchor, only one assignment survives);
          - anything else stays void.

        ``self._subsample`` is then called on those labels (its return value is ignored, so it is expected to modify
        them in place) before they are scattered back into the full set of anchors. Images without any valid anchor
        are left entirely void.

        Parameters
        ----------
        anchors : torch.Tensor
            Tensors of anchor bounding boxes with shapes M x 4
        bbx : sequence of torch.Tensor
            Sequence of N tensors of ground truth bounding boxes with shapes G_i x 4, entries can be None or empty
            (both are treated as "no ground truth": all non-void anchors become negative)
        iscrowd : sequence of torch.Tensor
            Sequence of N tensors of ground truth crowd regions (uint8, shapes H_i x W_i), or ground truth crowd
            bounding boxes (shapes K_i x 4, K_i may be 0), entries can be None
        valid_size : list of tuple of int
            List of N valid image sizes in input coordinates

        Returns
        -------
        match : torch.Tensor
            Tensor of matching results with shape N x M, with the following semantic:
              - match[i, j] == -2: the j-th anchor in image i is void
              - match[i, j] == -1: the j-th anchor in image i is negative
              - match[i, j] == k, k >= 0: the j-th anchor in image i is matched to bbx[i][k]
        """
        match = []
        for bbx_i, iscrowd_i, valid_size_i in zip(bbx, iscrowd, valid_size):
            # Default labels: everything is void
            match_i = anchors.new_full((anchors.size(0),), -2, dtype=torch.long)

            try:
                # Find anchors that are entirely within the original image area
                valid = self._is_inside(anchors, valid_size_i)

                # Void anchors that overlap crowd regions too much
                if self.void_threshold != 0 and iscrowd_i is not None:
                    # NOTE(review): only uint8 tensors are treated as crowd masks; a bool mask would be routed to the
                    # box branch below -- confirm callers always pass uint8 masks.
                    if iscrowd_i.dtype == torch.uint8:
                        overlap = mask_overlap(anchors, iscrowd_i)
                        valid = valid & (overlap < self.void_threshold)
                    elif iscrowd_i.size(0) > 0:
                        # Use the largest overlap with any crowd box. With zero crowd boxes there is nothing to void,
                        # and max() over the empty dimension would raise.
                        overlap, _ = bbx_overlap(anchors, iscrowd_i).max(dim=1)
                        valid = valid & (overlap < self.void_threshold)

                if not valid.any().item():
                    # No usable anchors: keep this image entirely void
                    raise Empty

                valid_anchors = anchors[valid]

                # An empty (0-row) box tensor is treated like None: reducing an M' x 0 IoU matrix would raise
                if bbx_i is not None and bbx_i.size(0) > 0:
                    max_a2g_iou = bbx_i.new_zeros(valid_anchors.size(0))
                    max_a2g_idx = bbx_i.new_full((valid_anchors.size(0),), -1, dtype=torch.long)
                    max_g2a_iou = []
                    max_g2a_idx = []

                    # Calculate assignments iteratively to save memory: only a chunk of CHUNK_SIZE ground truth
                    # boxes is compared against the anchors at a time
                    for j, bbx_i_j in enumerate(torch.split(bbx_i, CHUNK_SIZE, dim=0)):
                        iou = ious(valid_anchors, bbx_i_j)

                        # Anchor -> GT: keep the best box seen so far (strict '>' so earlier chunks win ties)
                        iou_max, iou_idx = iou.max(dim=1)
                        replace_idx = iou_max > max_a2g_iou

                        # Chunk-local box indices are shifted back to indices into bbx_i
                        max_a2g_idx[replace_idx] = iou_idx[replace_idx] + j * CHUNK_SIZE
                        max_a2g_iou[replace_idx] = iou_max[replace_idx]

                        # GT -> Anchor: best anchor for each box of this chunk
                        max_g2a_iou_j, max_g2a_idx_j = iou.transpose(0, 1).max(dim=1)
                        max_g2a_iou.append(max_g2a_iou_j)
                        max_g2a_idx.append(max_g2a_idx_j)

                        del iou

                    max_g2a_iou = torch.cat(max_g2a_iou, dim=0)
                    max_g2a_idx = torch.cat(max_g2a_idx, dim=0)

                    a2g_pos = max_a2g_iou >= self.pos_threshold
                    a2g_neg = max_a2g_iou < self.neg_threshold
                    g2a_pos = max_g2a_iou > 0

                    # Anchors between the two thresholds stay void (-2)
                    valid_match = valid_anchors.new_full((valid_anchors.size(0),), -2, dtype=torch.long)
                    valid_match[a2g_pos] = max_a2g_idx[a2g_pos]
                    valid_match[a2g_neg] = -1

                    # Force-match every overlapping GT box to its best anchor, overriding the labels set above.
                    # squeeze(1) keeps a 1-D index tensor even when exactly one box qualifies.
                    gt_idx = g2a_pos.nonzero().squeeze(1)
                    valid_match[max_g2a_idx[gt_idx]] = gt_idx
                else:
                    # No ground truth boxes for this image: everything that is not void is negative
                    valid_match = valid_anchors.new_full((valid_anchors.size(0),), -1, dtype=torch.long)

                # Subsample positives and negatives (modifies valid_match; the return value is not used)
                self._subsample(valid_match)

                match_i[valid] = valid_match
            except Empty:
                pass

            match.append(match_i)

        return torch.stack(match, dim=0)