Dassl.pytorch/dassl/evaluation/evaluator.py
import numpy as np
import os.path as osp
from collections import OrderedDict, defaultdict
import torch
from sklearn.metrics import f1_score, confusion_matrix
from .build import EVALUATOR_REGISTRY
class EvaluatorBase:
"""Base evaluator."""
def __init__(self, cfg):
self.cfg = cfg
def reset(self):
raise NotImplementedError
def process(self, mo, gt):
raise NotImplementedError
def evaluate(self):
raise NotImplementedError
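# A concrete evaluator implements the reset -> process -> evaluate lifecycle:
# reset() clears accumulated state, process(mo, gt) ingests one batch of model
# outputs and ground-truth labels, and evaluate() returns an OrderedDict of
# final metrics.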
@EVALUATOR_REGISTRY.register()
class Classification(EvaluatorBase):
"""Evaluator for classification."""
def __init__(self, cfg, lab2cname=None, **kwargs):
super().__init__(cfg)
self._lab2cname = lab2cname
self._correct = 0
self._total = 0
self._per_class_res = None
self._y_true = []
self._y_pred = []
if cfg.TEST.PER_CLASS_RESULT:
assert lab2cname is not None
self._per_class_res = defaultdict(list)
def reset(self):
self._correct = 0
self._total = 0
self._y_true = []
self._y_pred = []
if self._per_class_res is not None:
self._per_class_res = defaultdict(list)
def process(self, mo, gt):
# mo (torch.Tensor): model output [batch, num_classes]
# gt (torch.LongTensor): ground truth [batch]
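        # Predicted class = index of the highest score along the class dimension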
pred = mo.max(1)[1]
matches = pred.eq(gt).float()
self._correct += int(matches.sum().item())
self._total += gt.shape[0]
        self._y_true.extend(gt.detach().cpu().numpy().tolist())
        self._y_pred.extend(pred.detach().cpu().numpy().tolist())
if self._per_class_res is not None:
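            # _per_class_res maps each ground-truth label to a list of 0/1
            # match flags; per-class accuracy is derived from it in evaluate()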
for i, label in enumerate(gt):
label = label.item()
matches_i = int(matches[i].item())
self._per_class_res[label].append(matches_i)
def evaluate(self):
results = OrderedDict()
acc = 100.0 * self._correct / self._total
err = 100.0 - acc
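        # Macro F1 is averaged only over the classes that actually appear in
        # the ground truth (labels=np.unique(self._y_true)), so classes absent
        # from the test set do not contribute zero scores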
macro_f1 = 100.0 * f1_score(
self._y_true,
self._y_pred,
average="macro",
labels=np.unique(self._y_true)
)
# The first value will be returned by trainer.test()
results["accuracy"] = acc
results["error_rate"] = err
results["macro_f1"] = macro_f1
print(
"=> result\n"
f"* total: {self._total:,}\n"
f"* correct: {self._correct:,}\n"
f"* accuracy: {acc:.2f}%\n"
f"* error: {err:.2f}%\n"
f"* macro_f1: {macro_f1:.2f}%"
)
if self._per_class_res is not None:
labels = list(self._per_class_res.keys())
labels.sort()
print("=> per-class result")
accs = []
for label in labels:
classname = self._lab2cname[label]
res = self._per_class_res[label]
correct = sum(res)
total = len(res)
acc = 100.0 * correct / total
accs.append(acc)
print(
f"* class: {label} ({classname})\t"
f"total: {total:,}\t"
f"correct: {correct:,}\t"
f"acc: {acc:.1f}%"
)
mean_acc = np.mean(accs)
print(f"* average: {mean_acc:.1f}%")
results["perclass_accuracy"] = mean_acc
if self.cfg.TEST.COMPUTE_CMAT:
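            # normalize="true" row-normalizes the matrix: each row sums to 1
            # and gives the distribution of predictions for that ground-truth
            # class. The saved object is a NumPy array serialized via torch.save.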
cmat = confusion_matrix(
self._y_true, self._y_pred, normalize="true"
)
save_path = osp.join(self.cfg.OUTPUT_DIR, "cmat.pt")
torch.save(cmat, save_path)
print(f"Confusion matrix is saved to {save_path}")
return results
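# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the upstream file). The real
# trainer passes the project's config object; here a SimpleNamespace stands in
# for the only fields this evaluator reads (TEST.PER_CLASS_RESULT,
# TEST.COMPUTE_CMAT, OUTPUT_DIR). Run as a module, e.g.
# `python -m dassl.evaluation.evaluator`, so the relative import at the top
# resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        TEST=SimpleNamespace(PER_CLASS_RESULT=True, COMPUTE_CMAT=False),
        OUTPUT_DIR="./output",
    )
    lab2cname = {0: "cat", 1: "dog", 2: "bird"}

    evaluator = Classification(cfg, lab2cname=lab2cname)
    evaluator.reset()
    for _ in range(4):  # stand-in for iterating over a test loader
        logits = torch.randn(8, 3)          # [batch, num_classes]
        labels = torch.randint(0, 3, (8,))  # [batch]
        evaluator.process(logits, labels)

    results = evaluator.evaluate()
    print(results["accuracy"])  # first metric, i.e. what trainer.test() returns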