in shapenet/evaluation/eval.py [0:0]
def evaluate_test(model, data_loader, vis_preds=False, *, vis_dir="/tmp/output"):
    """
    This function evaluates the model on the dataset defined by data_loader.
    The metrics reported are described in Table 2 of our paper.

    For every ShapeNet synset id in ``class_names`` below, this counts the test
    instances and accumulates the per-instance *sums* of Chamfer-L2, absolute
    normal consistency and F1 at thresholds 0.1 / 0.3 / 0.5 (as returned by
    ``compare_meshes(..., reduce=False)``), then hands counts and sums to
    ``vis_utils.print_instances_class_histogram`` for reporting.

    Args:
        model: called as ``model(imgs)``; the second element of its return
            value is a sequence of mesh predictions whose last entry is the
            one evaluated.
        data_loader: iterable of raw batches that also provides
            ``postprocess(batch, device)`` and ``len()``.
        vis_preds: if True, every final prediction is passed to
            ``vis_utils.visualize_prediction``.
        vis_dir: output directory handed to ``visualize_prediction``; only
            used when ``vis_preds`` is True.

    Raises:
        KeyError: if a batch contains a synset id not listed in ``class_names``.
        ValueError: if a per-instance metric does not have exactly one value
            per instance in the batch.
    """
    # Note that all eval runs on main process
    assert comm.is_main_process()
    deprocess = imagenet_deprocess(rescale_image=False)
    device = torch.device("cuda:0")
    # ShapeNet synset id -> human-readable class name.
    class_names = {
        "02828884": "bench",
        "03001627": "chair",
        "03636649": "lamp",
        "03691459": "speaker",
        "04090263": "firearm",
        "04379243": "table",
        "04530566": "watercraft",
        "02691156": "plane",
        "02933112": "cabinet",
        "02958343": "car",
        "03211117": "monitor",
        "04256520": "couch",
        "04401088": "cellphone",
    }
    # Histogram name -> key of the per-instance metric in compare_meshes' output.
    metric_keys = {
        "chamfer": "Chamfer-L2",
        "normal": "AbsNormalConsistency",
        "f1_01": "F1@%f" % 0.1,
        "f1_03": "F1@%f" % 0.3,
        "f1_05": "F1@%f" % 0.5,
    }
    num_instances = {sid: 0 for sid in class_names}
    totals = {name: {sid: 0 for sid in class_names} for name in metric_keys}

    for num_batch_evaluated, batch in enumerate(data_loader, start=1):
        batch = data_loader.postprocess(batch, device)
        imgs, meshes_gt, _, _, _, id_strs = batch
        # The synset id is the part of each instance id before the first "-".
        sids = [id_str.split("-")[0] for id_str in id_strs]
        for sid in sids:
            num_instances[sid] += 1

        # Evaluation never needs gradients; no_grad avoids building an autograd
        # graph for the forward pass (harmless if grads are already disabled).
        with inference_context(model), torch.no_grad():
            _, meshes_pred = model(imgs)
            final_meshes = meshes_pred[-1]
            cur_metrics = compare_meshes(final_meshes, meshes_gt, reduce=False)
            _accumulate_per_class(totals, cur_metrics, metric_keys, sids)

            if vis_preds:
                for i, id_str in enumerate(id_strs):
                    img = image_to_numpy(deprocess(imgs[i]))
                    vis_utils.visualize_prediction(id_str, img, final_meshes[i], vis_dir)

        logger.info("Evaluated %d / %d batches", num_batch_evaluated, len(data_loader))

    vis_utils.print_instances_class_histogram(num_instances, class_names, totals)


def _accumulate_per_class(totals, metrics, metric_keys, sids):
    """
    Add one batch of per-instance metric values into per-class running sums.

    Args:
        totals: ``{histogram_name: {sid: running_sum}}``; updated in place.
        metrics: mapping from metric key to a tensor holding one value per
            instance, in batch order.
        metric_keys: ``{histogram_name: key_in_metrics}``.
        sids: synset id of each instance in the batch, in batch order.

    Raises:
        KeyError: if a sid is not already a key of the per-class sums.
        ValueError: if a metric does not hold exactly one value per instance.
    """
    for name, metric_key in metric_keys.items():
        # One device->host copy per metric rather than one ``.item()`` sync per
        # instance; reshape(-1) yields one value per instance like ``[i].item()``.
        values = metrics[metric_key].detach().reshape(-1).tolist()
        if len(values) != len(sids):
            raise ValueError(
                "metric %r has %d values for %d instances"
                % (metric_key, len(values), len(sids))
            )
        class_totals = totals[name]
        for sid, value in zip(sids, values):
            class_totals[sid] += value