in scripts/train_panoptic.py [0:0]
def validate(model, dataloader, loss_weights, **varargs):
    """Run one validation pass; compute loss, mIoU, panoptic quality and (optionally) COCO AP.

    The model is switched to eval mode and run without gradient tracking on every batch of
    ``dataloader``. For each sample, the network outputs are converted into a panoptic
    prediction with ``varargs["make_panoptic"]``. That prediction is accumulated into
    panoptic statistics, a semantic confusion matrix and, when ``eval_coco`` is set,
    COCO-style detection/mask results. The confusion matrix is summed across processes
    with ``distributed.all_reduce`` before mIoU is computed.

    Args:
        model: network called as ``model(**batch, do_loss=True, do_prediction=True)`` and
            returning ``(losses, pred, conf)``.
        dataloader: validation loader. Its ``batch_sampler`` must provide ``set_epoch``.
            Its dataset must expose ``num_stuff``, ``num_categories`` and ``categories``.
        loss_weights: weights applied, in order, to the values of the all-reduced loss dict.
        **varargs: run configuration. Keys read here:
            - ``epoch``, ``num_epochs``, ``device``;
            - ``eval_mode`` (``"panoptic"`` or ``"separate"``);
            - ``make_panoptic``, ``eval_coco``, ``coco_gt``, ``log_dir``;
            - ``summary``, ``log_interval``, ``global_step``.

    Returns:
        The overall panoptic score returned by ``get_panoptic_scores``.
    """
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])

    num_stuff = dataloader.dataset.num_stuff
    num_classes = dataloader.dataset.num_categories
    eval_mode = varargs["eval_mode"]
    eval_coco = varargs["eval_coco"]

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    # Accumulators for AP, mIoU and panoptic computation
    panoptic_buffer = torch.zeros(4, num_classes, dtype=torch.double)
    # NOTE(review): 256x256 presumably gives every 8-bit label, including the 255 "void" id
    # checked below, its own slot; only the first num_classes rows are used for mIoU
    conf_mat = torch.zeros(256, 256, dtype=torch.double)
    coco_struct = []
    img_list = []

    data_time = time.time()
    # Gradients are never needed during validation, so disable them for the whole loop
    with torch.no_grad():
        for it, batch in enumerate(dataloader):
            idxs = batch["idx"]
            batch_sizes = [img.shape[-2:] for img in batch["img"]]
            original_sizes = batch["size"]

            # Upload only the entries the network consumes
            batch = {k: batch[k].cuda(device=varargs["device"], non_blocking=True) for k in NETWORK_INPUTS}
            assert all(msk.size(0) == 1 for msk in batch["msk"]), \
                "Mask R-CNN + segmentation requires panoptic ground truth"

            data_time_meter.update(torch.tensor(time.time() - data_time))
            batch_time = time.time()

            # Run network
            losses, pred, conf = model(**batch, do_loss=True, do_prediction=True)
            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            if eval_mode == "separate":
                # Directly accumulate the confusion matrix produced by the network
                conf_mat[:num_classes, :num_classes] += conf["sem_conf"].to(conf_mat)

            # Update meters (loss.cpu() also waits for the GPU work, so batch_time is meaningful)
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses, conf

            # Accumulate COCO AP and panoptic predictions, one sample at a time
            for i, (sem_pred, bbx_pred, cls_pred, obj_pred, msk_pred, msk_gt, cat_gt, iscrowd) in enumerate(zip(
                    pred["sem_pred"], pred["bbx_pred"], pred["cls_pred"], pred["obj_pred"], pred["msk_pred"],
                    batch["msk"], batch["cat"], batch["iscrowd"])):
                msk_gt = msk_gt.squeeze(0)
                # Per-pixel semantic ground truth, taken before crowd instances are removed
                sem_gt = cat_gt[msk_gt]

                # Remove crowd from gt: non-crowd instance ids are renumbered 0..K-1 and the
                # category list is filtered to match.
                # NOTE(review): crowd pixels are mapped to id 0, which is also the new id of the
                # first non-crowd instance -- presumably instance 0 is a void/crowd segment; confirm
                cmap = msk_gt.new_zeros(cat_gt.numel())
                cmap[~iscrowd] = torch.arange(0, (~iscrowd).long().sum().item(), dtype=cmap.dtype, device=cmap.device)
                msk_gt = cmap[msk_gt]
                cat_gt = cat_gt[~iscrowd]

                # Compute panoptic output
                panoptic_pred = varargs["make_panoptic"](sem_pred, bbx_pred, cls_pred, obj_pred, msk_pred, num_stuff)

                # Panoptic evaluation
                panoptic_buffer += torch.stack(
                    panoptic_stats(msk_gt, cat_gt, panoptic_pred, num_classes, num_stuff), dim=0)

                if eval_mode == "panoptic":
                    # Confusion matrix on the per-pixel semantic map implied by the panoptic output
                    pan_sem_pred = panoptic_pred[1][panoptic_pred[0]]
                    conf_mat += confusion_matrix(sem_gt.cpu(), pan_sem_pred).to(conf_mat)

                    # Update COCO AP from the panoptic output, only when it contains at least one
                    # segment whose class is >= num_stuff and not 255.
                    # NOTE(review): images skipped here are also left out of img_list and hence out
                    # of the COCO evaluation -- confirm this is intended
                    if eval_coco and ((panoptic_pred[1] >= num_stuff) & (panoptic_pred[1] != 255)).any():
                        coco_struct += coco_ap.process_panoptic_prediction(
                            panoptic_pred, num_stuff, idxs[i], batch_sizes[i], original_sizes[i])
                        img_list.append(idxs[i])
                elif eval_mode == "separate":
                    # Update COCO AP from the raw detection output
                    if eval_coco and bbx_pred is not None:
                        coco_struct += coco_ap.process_prediction(
                            bbx_pred, cls_pred + num_stuff, obj_pred, msk_pred, batch_sizes[i], idxs[i],
                            original_sizes[i])
                        img_list.append(idxs[i])

            del pred, batch

            # Log batch
            if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
                logging.iteration(
                    None, "val", varargs["global_step"],
                    varargs["epoch"] + 1, varargs["num_epochs"],
                    it + 1, len(dataloader),
                    OrderedDict([
                        ("loss", loss_meter),
                        ("data_time", data_time_meter),
                        ("batch_time", batch_time_meter)
                    ])
                )

            data_time = time.time()

    # Finalize mIoU computation: sum the confusion matrices of all processes
    conf_mat = conf_mat.to(device=varargs["device"])
    distributed.all_reduce(conf_mat, distributed.ReduceOp.SUM)
    conf_mat = conf_mat.cpu()[:num_classes, :]
    # Per-class IoU = diag / (row sum + column sum - diag).
    # A class with no pixels in either the rows or the columns yields 0/0 = NaN.
    miou = conf_mat.diag() / (conf_mat.sum(dim=1) + conf_mat.sum(dim=0)[:num_classes] - conf_mat.diag())
    # Average only over classes with a defined IoU, so a single absent class does not turn
    # the reported mean into NaN (the per-class tensor is still logged unchanged below)
    miou_mean = miou[~torch.isnan(miou)].mean().item()

    # Finalize AP computation
    if eval_coco:
        det_map, msk_map = coco_ap.summarize_mp(coco_struct, varargs["coco_gt"], img_list, varargs["log_dir"], True)

    # Finalize panoptic computation
    panoptic_score, stuff_pq, thing_pq = get_panoptic_scores(panoptic_buffer, varargs["device"], num_stuff)

    # Log results
    log_info("Validation done")
    if varargs["summary"] is not None:
        metrics = OrderedDict()
        metrics["loss"] = loss_meter.mean.item()
        if eval_coco:
            metrics["det_map"] = det_map
            metrics["msk_map"] = msk_map
        metrics["miou"] = miou_mean
        metrics["panoptic"] = panoptic_score
        metrics["stuff_pq"] = stuff_pq
        metrics["thing_pq"] = thing_pq
        metrics["data_time"] = data_time_meter.mean.item()
        metrics["batch_time"] = batch_time_meter.mean.item()

        logging.iteration(
            varargs["summary"], "val", varargs["global_step"],
            varargs["epoch"] + 1, varargs["num_epochs"],
            len(dataloader), len(dataloader), metrics
        )

    log_miou(miou, dataloader.dataset.categories)

    return panoptic_score