in seamseg/utils/panoptic.py [0:0]
def __call__(self, sem_pred, bbx_pred, cls_pred, obj_pred, msk_pred, num_stuff):
    """Merge semantic and instance predictions into a single panoptic segmentation.

    Instances are pasted first, in order of decreasing score. The pixels they leave
    unoccupied are then filled from the semantic prediction.

    Reads ``self.score_threshold``, ``self.overlap_threshold`` and ``self.min_stuff_area``.

    Args:
        sem_pred: per-pixel semantic class ids. Its first two dimensions give the image
            size (presumably a (H, W) integer tensor -- confirm against callers).
        bbx_pred: per-instance boxes, or None when there are no instance predictions.
        cls_pred: per-instance thing class ids, or None. ``num_stuff`` is added to get
            the panoptic category.
        obj_pred: per-instance confidence scores, or None.
        msk_pred: per-instance mask logits in RoI coordinates, or None. A sigmoid is
            applied here before thresholding at 0.5.
        num_stuff: number of stuff classes. Semantic classes ``>= num_stuff`` are things.

    Returns:
        A tuple ``(msk, cat, obj, iscrowd)``:
        - msk: tensor shaped like ``sem_pred``, moved to CPU, holding each pixel's
          segment index. Index 0 is the void segment.
        - cat: long tensor with the category of each segment (255 for void).
        - obj: float tensor with each segment's score. Things get their instance
          score, segments taken from ``sem_pred`` get 1, and void gets 0.
        - iscrowd: uint8 tensor, 1 for segments that come from thing classes in
          ``sem_pred`` (i.e. thing pixels not covered by any pasted instance).
    """
    img_size = [sem_pred.size(0), sem_pred.size(1)]

    # Initialize outputs. Segment 0 is the void segment.
    # `occupied` is kept as a bool mask to match the bool masks produced by the
    # comparisons below. Mixing uint8 and bool makes `-` fail and turns `~` into a
    # bitwise NOT on PyTorch >= 1.2.
    occupied = torch.zeros_like(sem_pred, dtype=torch.bool)
    msk = torch.zeros_like(sem_pred)
    cat = [255]
    obj = [0]
    iscrowd = [0]

    # Process things
    try:
        if bbx_pred is None or cls_pred is None or obj_pred is None or msk_pred is None:
            raise Empty

        # Remove low-confidence instances
        keep = obj_pred > self.score_threshold
        if not keep.any():
            raise Empty
        obj_pred, bbx_pred, cls_pred, msk_pred = obj_pred[keep], bbx_pred[keep], cls_pred[keep], msk_pred[keep]

        # Up-sample masks from RoI to image coordinates, then binarize them
        bbx_inv = invert_roi_bbx(bbx_pred, list(msk_pred.shape[-2:]), img_size)
        bbx_idx = torch.arange(0, msk_pred.size(0), dtype=torch.long, device=msk_pred.device)
        msk_pred = roi_sampling(msk_pred.unsqueeze(1).sigmoid(), bbx_inv, bbx_idx, tuple(img_size), padding="zero")
        msk_pred = msk_pred.squeeze(1) > 0.5

        # Paste instances by decreasing score, so higher-scoring instances claim
        # contested pixels first
        idx = torch.argsort(obj_pred, descending=True)
        for msk_i, cls_i, obj_i in zip(msk_pred[idx], cls_pred[idx], obj_pred[idx]):
            area = msk_i.sum().item()
            # An empty mask has an undefined (0 / 0) overlap ratio and would only add
            # a segment with no pixels
            if area == 0:
                continue

            # Drop instances that mostly overlap pixels already claimed
            intersection = (occupied & msk_i).sum().item()
            if intersection / area > self.overlap_threshold:
                continue

            # Add only the non-intersecting part to the output
            msk_i = msk_i & ~occupied
            if not msk_i.any():  # only reachable when overlap_threshold >= 1
                continue
            msk[msk_i] = len(cat)
            cat.append(cls_i.item() + num_stuff)
            obj.append(obj_i.item())
            iscrowd.append(0)

            # Update occupancy mask
            occupied |= msk_i
    except Empty:
        pass

    # Fill the remaining pixels from the semantic prediction. Pixels of thing classes
    # that reach this point are not covered by any instance, so they are marked as crowd.
    for cls_i in range(sem_pred.max().item() + 1):
        # Remove occupied part and check remaining area
        msk_i = (sem_pred == cls_i) & ~occupied
        area = msk_i.sum().item()
        if area == 0 or area < self.min_stuff_area:
            continue

        msk[msk_i] = len(cat)
        cat.append(cls_i)
        obj.append(1)
        iscrowd.append(int(cls_i >= num_stuff))

        # Update occupancy mask
        occupied |= msk_i

    # Wrap in tensors
    cat = torch.tensor(cat, dtype=torch.long)
    obj = torch.tensor(obj, dtype=torch.float)
    iscrowd = torch.tensor(iscrowd, dtype=torch.uint8)
    return msk.cpu(), cat, obj, iscrowd