in tensorflow_graphics/projects/points_to_3Dobjects/utils/evaluator.py [0:0]
def evaluate(self):
  """Computes the mean average precision (AP) over all classes.

  Detections in `self.predicted_boxes` and ground truth in
  `self.labeled_boxes` are regrouped by class and then by scene, and
  `self._eval_detections_per_class` is called once per class with
  `self.threshold`. Progress and per-class AP values are printed.

  Both inputs map a scene id to a tuple of objects exposing `.numpy()`
  (presumably `tf.Tensor`s -- confirm against the caller):
    predicted_boxes[scene_id] = (bboxes, classnames, scores)
    labeled_boxes[scene_id] = (bboxes, classnames)
  Entry i of each array describes the same box.

  Returns:
    The mean of the per-class AP values. Every class that appears in either
    the predictions or the labels is included; a class that is labeled but
    never predicted contributes an AP of 0.0. Returns NaN when there are no
    classes at all.
  """
  # {classname: {scene_id: [(bbox, score), ...]}}
  predictions_per_class = {}
  # {classname: {scene_id: [bbox, ...]}}
  labels_per_class = {}

  for scene_id, (bboxes, classnames, scores) in self.predicted_boxes.items():
    for classname, bbox, score in zip(
        classnames.numpy(), bboxes.numpy(), scores.numpy()):
      predictions_per_class.setdefault(classname, {}).setdefault(
          scene_id, []).append((bbox, score))
      # Also register an (initially empty) label list for this class/scene,
      # so every scene with a prediction of a class has a label entry for
      # that class -- even if the scene has no ground truth for it.
      labels_per_class.setdefault(classname, {}).setdefault(scene_id, [])

  for scene_id, (bboxes, classnames) in self.labeled_boxes.items():
    for classname, bbox in zip(classnames.numpy(), bboxes.numpy()):
      labels_per_class.setdefault(classname, {}).setdefault(
          scene_id, []).append(bbox)

  ap_per_class = {}
  for classname, labels in labels_per_class.items():
    print('Computing AP for class: ', classname)
    if classname in predictions_per_class:
      # Recall and precision are returned as well but are not aggregated.
      _, _, ap = self._eval_detections_per_class(
          predictions_per_class[classname], labels, self.threshold)
    else:
      # The per-class evaluator does not work for a class that was never
      # predicted, so it is skipped and the class counts as AP 0.0.
      ap = 0.0
    ap_per_class[classname] = ap
    print(classname, ap)

  # np.mean of an empty sequence warns ("Mean of empty slice") and returns
  # NaN; produce the same NaN explicitly, without the warning.
  ap_values = list(ap_per_class.values())
  mean = np.mean(ap_values) if ap_values else float('nan')
  print(mean)
  return mean