in scripts/train_instance_seg.py [0:0]
def validate(model, dataloader, loss_weights, **varargs):
    """Run one validation pass and compute detection / mask mAP.

    Puts ``model`` in eval mode and runs every batch of ``dataloader`` under
    ``torch.no_grad()``. Along the way it accumulates the weighted validation
    loss and data/batch timing, and collects per-image predictions through
    ``coco_ap.process_prediction``. The collected predictions are then scored
    by ``coco_ap.summarize_mp``. When ``varargs["summary"]`` is not None,
    progress is logged every ``log_interval`` batches and the final metrics
    are written through ``logging.iteration``.

    Args:
        model: network called as
            ``model(**batch, do_loss=True, do_prediction=True)``. It returns
            ``(losses, pred)``: ``losses`` maps names to tensors, and ``pred``
            holds per-image ``bbx_pred``, ``cls_pred``, ``obj_pred`` and
            ``msk_pred`` sequences.
        dataloader: yields dict batches with at least ``"idx"``, ``"img"``,
            ``"size"`` and every key in ``NETWORK_INPUTS``. Its
            ``batch_sampler`` must provide ``set_epoch`` and its ``dataset``
            must provide ``num_stuff``.
        loss_weights: weights applied positionally to the all-reduced losses,
            in the loss dict's iteration order, to form the total loss.
        **varargs: required keys ``epoch``, ``device``, ``summary``,
            ``log_interval``, ``global_step``, ``num_epochs``, ``coco_gt`` and
            ``log_dir``.

    Returns:
        The mask mAP (second value) returned by ``coco_ap.summarize_mp``.
    """
    model.eval()
    dataloader.batch_sampler.set_epoch(varargs["epoch"])

    # Added to the predicted class ids below. Presumably instance ("thing")
    # classes are numbered after the stuff classes in the dataset's category
    # list -- TODO confirm.
    num_stuff = dataloader.dataset.num_stuff

    loss_meter = AverageMeter(())
    data_time_meter = AverageMeter(())
    batch_time_meter = AverageMeter(())

    # Accumulators for AP computation: the prediction records, and the ids of
    # every image that took part in validation.
    coco_struct = []
    img_list = []

    data_time = time.time()
    for it, batch in enumerate(dataloader):
        with torch.no_grad():
            # Read per-image metadata now, because `batch` is replaced below by
            # the device-side subset holding only NETWORK_INPUTS.
            idxs = batch["idx"]
            batch_sizes = [img.shape[-2:] for img in batch["img"]]
            original_sizes = batch["size"]

            # Upload batch
            batch = {k: batch[k].cuda(device=varargs["device"], non_blocking=True) for k in NETWORK_INPUTS}

            data_time_meter.update(torch.tensor(time.time() - data_time))
            batch_time = time.time()

            # Run network
            losses, pred = model(**batch, do_loss=True, do_prediction=True)
            losses = OrderedDict((k, v.mean()) for k, v in losses.items())
            losses = all_reduce_losses(losses)
            loss = sum(w * l for w, l in zip(loss_weights, losses.values()))

            # Update meters
            loss_meter.update(loss.cpu())
            batch_time_meter.update(torch.tensor(time.time() - batch_time))

            del loss, losses

            # Accumulate predictions for COCO AP
            for i, (bbx_pred, cls_pred, obj_pred, msk_pred) in enumerate(
                    zip(pred["bbx_pred"], pred["cls_pred"], pred["obj_pred"], pred["msk_pred"])):
                # Register every image, including those without detections.
                # img_list is handed to coco_ap.summarize_mp, which presumably
                # restricts evaluation to these ids (COCOeval params.imgIds).
                # Omitting detection-less images would hide their missed
                # ground truth and inflate mAP.
                img_list.append(idxs[i])

                # No detections: nothing to add to the prediction records
                if bbx_pred is None:
                    continue

                coco_struct += coco_ap.process_prediction(
                    bbx_pred, cls_pred + num_stuff, obj_pred, msk_pred, batch_sizes[i], idxs[i], original_sizes[i])

            del pred, batch

            # Log batch
            if varargs["summary"] is not None and (it + 1) % varargs["log_interval"] == 0:
                logging.iteration(
                    None, "val", varargs["global_step"],
                    varargs["epoch"] + 1, varargs["num_epochs"],
                    it + 1, len(dataloader),
                    OrderedDict([
                        ("loss", loss_meter),
                        ("data_time", data_time_meter),
                        ("batch_time", batch_time_meter),
                    ])
                )

            data_time = time.time()

    # Finalize AP computation. The trailing True is presumably the "also
    # evaluate masks" flag of summarize_mp -- NOTE(review): confirm.
    det_map, msk_map = coco_ap.summarize_mp(coco_struct, varargs["coco_gt"], img_list, varargs["log_dir"], True)

    # Log results
    log_info("Validation done")
    if varargs["summary"] is not None:
        logging.iteration(
            varargs["summary"], "val", varargs["global_step"],
            varargs["epoch"] + 1, varargs["num_epochs"],
            len(dataloader), len(dataloader),
            OrderedDict([
                ("loss", loss_meter.mean.item()),
                ("det_map", det_map),
                ("msk_map", msk_map),
                ("data_time", data_time_meter.mean.item()),
                ("batch_time", batch_time_meter.mean.item()),
            ])
        )

    return msk_map