def __call__()

in seamseg/utils/panoptic.py [0:0]


    def __call__(self, sem_pred, bbx_pred, cls_pred, obj_pred, msk_pred, num_stuff):
        """Fuse semantic and instance predictions into a single panoptic segmentation.

        Instances ("things") are placed first, in order of decreasing score; each one claims
        only pixels not already claimed by a higher-scoring instance. Classes from the
        semantic prediction then fill the remaining unclaimed pixels.

        Reads ``self.score_threshold``, ``self.overlap_threshold`` and ``self.min_stuff_area``.

        Args:
            sem_pred: semantic prediction; ``sem_pred.size(0)`` and ``sem_pred.size(1)`` are
                used as the image size and its values are treated as class ids (presumably a
                2-D integer tensor -- confirm against callers).
            bbx_pred: per-instance boxes, passed to ``invert_roi_bbx``.
            cls_pred: per-instance thing class ids.
            obj_pred: per-instance scores, compared against ``self.score_threshold``.
            msk_pred: per-instance ROI mask logits (a sigmoid is applied before resampling).
                If any of ``bbx_pred``/``cls_pred``/``obj_pred``/``msk_pred`` is ``None``,
                instance processing is skipped.
            num_stuff: number of stuff classes; thing class ``c`` is emitted as
                ``c + num_stuff``.

        Returns:
            Tuple ``(msk, cat, obj, iscrowd)``:
            - ``msk``: on CPU, same shape and dtype as ``sem_pred``; maps each pixel to a
              segment index.
            - ``cat``: long tensor of per-segment categories.
            - ``obj``: float tensor of per-segment scores.
            - ``iscrowd``: uint8 tensor of per-segment crowd flags.
            Segment 0 is the void segment (category 255).
        """
        img_size = [sem_pred.size(0), sem_pred.size(1)]

        # Initialize outputs. Masks are kept as bool throughout: PyTorch forbids `-` on
        # bool tensors and deprecates indexing with uint8 masks.
        occupied = torch.zeros_like(sem_pred, dtype=torch.bool)
        msk = torch.zeros_like(sem_pred)
        cat = [255]
        obj = [0]
        iscrowd = [0]

        # Process things
        try:
            if bbx_pred is None or cls_pred is None or obj_pred is None or msk_pred is None:
                raise Empty

            # Remove low-confidence instances
            keep = obj_pred > self.score_threshold
            if not keep.any():
                raise Empty
            obj_pred, bbx_pred, cls_pred, msk_pred = obj_pred[keep], bbx_pred[keep], cls_pred[keep], msk_pred[keep]

            # Up-sample ROI masks to full image resolution, then binarize (result is bool)
            bbx_inv = invert_roi_bbx(bbx_pred, list(msk_pred.shape[-2:]), img_size)
            bbx_idx = torch.arange(0, msk_pred.size(0), dtype=torch.long, device=msk_pred.device)
            msk_pred = roi_sampling(msk_pred.unsqueeze(1).sigmoid(), bbx_inv, bbx_idx, tuple(img_size), padding="zero")
            msk_pred = msk_pred.squeeze(1) > 0.5

            # Sort by score so higher-confidence instances claim pixels first
            idx = torch.argsort(obj_pred, descending=True)

            # Process instances
            for msk_i, cls_i, obj_i in zip(msk_pred[idx], cls_pred[idx], obj_pred[idx]):
                # An empty mask would give 0/0 = NaN below; NaN > threshold is False, so
                # without this check a segment with no pixels would be emitted.
                if not msk_i.any():
                    continue

                # Drop instances mostly covered by higher-scoring ones
                intersection = occupied & msk_i
                if intersection.float().sum() / msk_i.float().sum() > self.overlap_threshold:
                    continue

                # Add non-intersecting part to output
                msk_i = msk_i & ~intersection
                msk[msk_i] = len(cat)
                cat.append(cls_i.item() + num_stuff)
                obj.append(obj_i.item())
                iscrowd.append(0)

                # Update occupancy mask
                occupied |= msk_i
        except Empty:
            pass

        # Process stuff: every class id present in sem_pred fills the still-unclaimed pixels
        for cls_i in range(sem_pred.max().item() + 1):
            msk_i = sem_pred == cls_i

            # Remove occupied part and check remaining area
            msk_i = msk_i & ~occupied
            if msk_i.float().sum() < self.min_stuff_area:
                continue

            # Add non-intersecting part to output
            msk[msk_i] = len(cat)
            cat.append(cls_i)
            obj.append(1)
            # Semantic pixels of a thing class that no instance covered are marked as crowd
            iscrowd.append(cls_i >= num_stuff)

            # Update occupancy mask
            occupied |= msk_i

        # Wrap in tensors
        cat = torch.tensor(cat, dtype=torch.long)
        obj = torch.tensor(obj, dtype=torch.float)
        iscrowd = torch.tensor(iscrowd, dtype=torch.uint8)

        return msk.cpu(), cat, obj, iscrowd