in meshrcnn/utils/metrics.py [0:0]
import torch
import torch.nn.functional as F
from pytorch3d.ops import knn_gather, knn_points


def _compute_sampling_metrics(pred_points, pred_normals, gt_points, gt_normals, thresholds, eps):
"""
Compute metrics that are based on sampling points and normals:
- L2 Chamfer distance
- Precision at various thresholds
- Recall at various thresholds
- F1 score at various thresholds
- Normal consistency (if normals are provided)
- Absolute normal consistency (if normals are provided)
Inputs:
- pred_points: Tensor of shape (N, S, 3) giving coordinates of sampled points
for each predicted mesh
- pred_normals: Tensor of shape (N, S, 3) giving normals of points sampled
from the predicted mesh, or None if such normals are not available
- gt_points: Tensor of shape (N, S, 3) giving coordinates of sampled points
for each ground-truth mesh
- gt_normals: Tensor of shape (N, S, 3) giving normals of points sampled from
      the ground-truth mesh, or None if such normals are not available
- thresholds: Distance thresholds to use for precision / recall / F1
- eps: epsilon value to handle numerically unstable F1 computation
Returns:
- metrics: A dictionary where keys are metric names and values are Tensors of
shape (N,) giving the value of the metric for the batch
"""
metrics = {}
lengths_pred = torch.full(
(pred_points.shape[0],), pred_points.shape[1], dtype=torch.int64, device=pred_points.device
)
lengths_gt = torch.full(
(gt_points.shape[0],), gt_points.shape[1], dtype=torch.int64, device=gt_points.device
)
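    # knn_points supports ragged batches via per-cloud lengths; every cloud here
    # contains exactly S points, so the lengths tensors are constant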
    # For each predicted point, find its nearest-neighbor GT point
knn_pred = knn_points(pred_points, gt_points, lengths1=lengths_pred, lengths2=lengths_gt, K=1)
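    # knn_pred.dists holds the squared L2 distance to the nearest GT point and
    # knn_pred.idx holds its index, which knn_gather uses below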
    # Compute squared and unsquared L2 distances between each pred point and its nearest GT point
pred_to_gt_dists2 = knn_pred.dists[..., 0] # (N, S)
pred_to_gt_dists = pred_to_gt_dists2.sqrt() # (N, S)
if gt_normals is not None:
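        # Gather, for each predicted point, the normal of its nearest GT point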
pred_normals_near = knn_gather(gt_normals, knn_pred.idx, lengths_gt)[..., 0, :] # (N, S, 3)
else:
pred_normals_near = None
# For each GT point, find its nearest-neighbor predicted point
knn_gt = knn_points(gt_points, pred_points, lengths1=lengths_gt, lengths2=lengths_pred, K=1)
    # Compute squared and unsquared L2 distances between each GT point and its nearest pred point
gt_to_pred_dists2 = knn_gt.dists[..., 0] # (N, S)
gt_to_pred_dists = gt_to_pred_dists2.sqrt() # (N, S)
if pred_normals is not None:
gt_normals_near = knn_gather(pred_normals, knn_gt.idx, lengths_pred)[..., 0, :] # (N, S, 3)
else:
gt_normals_near = None
    # Compute the L2 Chamfer distance: the sum of mean squared nearest-neighbor distances in both directions
chamfer_l2 = pred_to_gt_dists2.mean(dim=1) + gt_to_pred_dists2.mean(dim=1)
metrics["Chamfer-L2"] = chamfer_l2
    # Compute normal consistency and absolute normal consistency only if
    # normals were provided for both meshes
if pred_normals is not None and gt_normals is not None:
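        # Cosine similarity between each sampled normal and the normal of its
        # nearest neighbor, averaged over samples in each direction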
pred_to_gt_cos = F.cosine_similarity(pred_normals, pred_normals_near, dim=2)
gt_to_pred_cos = F.cosine_similarity(gt_normals, gt_normals_near, dim=2)
pred_to_gt_cos_sim = pred_to_gt_cos.mean(dim=1)
pred_to_gt_abs_cos_sim = pred_to_gt_cos.abs().mean(dim=1)
gt_to_pred_cos_sim = gt_to_pred_cos.mean(dim=1)
gt_to_pred_abs_cos_sim = gt_to_pred_cos.abs().mean(dim=1)
normal_dist = 0.5 * (pred_to_gt_cos_sim + gt_to_pred_cos_sim)
abs_normal_dist = 0.5 * (pred_to_gt_abs_cos_sim + gt_to_pred_abs_cos_sim)
metrics["NormalConsistency"] = normal_dist
metrics["AbsNormalConsistency"] = abs_normal_dist
# Compute precision, recall, and F1 based on L2 distances
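    # Precision@t is the percentage of predicted points within distance t of some GT
    # point; Recall@t is the percentage of GT points within distance t of some
    # predicted point; F1@t is their harmonic mean, with eps guarding against
    # division by zero when both are zero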
for t in thresholds:
precision = 100.0 * (pred_to_gt_dists < t).float().mean(dim=1)
recall = 100.0 * (gt_to_pred_dists < t).float().mean(dim=1)
f1 = (2.0 * precision * recall) / (precision + recall + eps)
metrics["Precision@%f" % t] = precision
metrics["Recall@%f" % t] = recall
metrics["F1@%f" % t] = f1
# Move all metrics to CPU
metrics = {k: v.cpu() for k, v in metrics.items()}
return metrics
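

# Illustrative usage sketch: sample points and normals with pytorch3d's
# sample_points_from_meshes and compare two meshes of different resolution.
# The mesh choice, threshold values, and eps below are arbitrary examples, not
# the settings used elsewhere in this module.
if __name__ == "__main__":
    from pytorch3d.ops import sample_points_from_meshes
    from pytorch3d.utils import ico_sphere

    # Two sphere meshes of different resolution stand in for a predicted and a
    # ground-truth mesh
    pred_mesh = ico_sphere(level=2)
    gt_mesh = ico_sphere(level=3)

    # Sample the same number of points (with normals) from each mesh
    pred_points, pred_normals = sample_points_from_meshes(pred_mesh, num_samples=10000, return_normals=True)
    gt_points, gt_normals = sample_points_from_meshes(gt_mesh, num_samples=10000, return_normals=True)

    metrics = _compute_sampling_metrics(
        pred_points, pred_normals, gt_points, gt_normals, thresholds=[0.1, 0.3, 0.5], eps=1e-8
    )
    # Each value is a tensor of shape (N,); here N == 1
    print({k: v.item() for k, v in metrics.items()})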