in trainers/ensemble.py
import torch
from torch import nn

import utils  # project-local helpers; assumed to provide update_bn(loader, model, device)

# `args` is assumed to be a module-level config object (e.g., parsed CLI flags)
# exposing num_models, device, update_bn, and save.


def test(models, writer, criterion, data_loader, epoch):
    # Put every ensemble member in eval mode and ask each module to return
    # its features alongside the logits.
    for model in models:
        model.eval()
        model.apply(lambda m: setattr(m, "return_feats", True))
    # Running totals; the *_len counters track how many (pair, sample)
    # observations each statistic was accumulated over.
    test_loss = 0
    correct = 0
    tvdist_sum = 0
    tvdist_len = 0
    feat_cossim = 0
    percent_disagree_sum = 0
    percent_disagree_len = 0
    percent_disagree_correct_sum = 0
    percent_disagree_correct_len = 0
    val_loader = data_loader.val_loader
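    # Optionally refresh BatchNorm running statistics with a forward pass over
    # the training data before evaluating (useful e.g. after weight averaging);
    # `utils.update_bn` is assumed to behave like torch.optim.swa_utils.update_bn.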
    if args.update_bn:
        for model in models:
            utils.update_bn(data_loader.train_loader, model, args.device)
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)
            # Ensemble by summing member logits; keep per-member softmax
            # probabilities and features for the diversity statistics below.
            output, f = models[0](data)
            probs = [nn.functional.softmax(output, dim=1)]
            feats = [f]
            for t in range(1, args.num_models):
                modelt_output, model_feats_t = models[t](data)
                feats.append(model_feats_t)
                probs.append(nn.functional.softmax(modelt_output, dim=1))
                output += modelt_output
            # Alternative combination: geometric mean of member probabilities.
            # output = 0
            # for p in probs:
            #     output += p.log()
            # output = (output / args.num_models).exp()
            # Validation loss of the ensemble, computed on the mean logits.
            test_loss += criterion(output / args.num_models, target).item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
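            # Total variation distance between two categorical distributions
            # p and q over K classes:
            #     TV(p, q) = 0.5 * sum_k |p_k - q_k|
            # It lies in [0, 1]: 0 when two members produce identical
            # predictive distributions, 1 when their supports are disjoint.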
            # Pairwise diversity statistics between members i and j.
            for i in range(args.num_models):
                for j in range(i + 1, args.num_models):
                    # Cosine similarity of the two members' feature vectors.
                    feat_cossim += (
                        nn.functional.cosine_similarity(feats[i], feats[j], dim=1)
                        .sum()
                        .item()
                    )
                    pairwise_tvdist = 0.5 * (probs[i] - probs[j]).abs().sum(dim=1)
                    tvdist_len += pairwise_tvdist.size(0)
                    tvdist_sum += pairwise_tvdist.sum().item()
                    model_i_pred = probs[i].argmax(dim=1, keepdim=True)
                    model_j_pred = probs[j].argmax(dim=1, keepdim=True)
                    # How often the two members predict different classes.
                    percent_disagree_len += data.size(0)
                    percent_disagree_sum += (model_i_pred != model_j_pred).sum().item()
                    # Disagreements where at least one of the two members is correct.
                    percent_disagree_correct_len += data.size(0)
                    percent_disagree_correct_sum += (
                        (model_i_pred != model_j_pred)
                        & (
                            model_i_pred.eq(target.view_as(model_i_pred))
                            | model_j_pred.eq(target.view_as(model_j_pred))
                        )
                    ).sum().item()
    # Average each running sum over the number of observations it covers.
    feat_cossim = feat_cossim / tvdist_len if tvdist_len > 0 else 0
    tvdist = tvdist_sum / tvdist_len if tvdist_len > 0 else 0
    percent_disagree = (
        percent_disagree_sum / percent_disagree_len if percent_disagree_len > 0 else 0
    )
    percent_disagree_correct = (
        percent_disagree_correct_sum / percent_disagree_correct_len
        if percent_disagree_correct_len > 0
        else 0
    )
    test_loss /= len(val_loader)
    test_acc = float(correct) / len(val_loader.dataset)
    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f}), TVDist: ({tvdist:.4f})\n"
    )
    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/acc", test_acc, epoch)
    metrics = {
        "tvdist": tvdist,
        "percent_disagree": percent_disagree,
        "percent_disagree_correct": percent_disagree_correct,
        "feat_cossim": feat_cossim,
    }
    return test_acc, metrics
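

# A minimal, self-contained sketch of how `test` might be driven. Everything
# below is illustrative, not the project's actual entry point: ToyNet, the
# SimpleNamespace stand-ins for `args` and the data wrapper, and the random
# data are all assumptions, chosen only so the function above can be
# smoke-tested end to end.
if __name__ == "__main__":
    import types

    from torch.utils.data import DataLoader, TensorDataset

    class ToyNet(nn.Module):
        """Tiny stand-in member whose forward returns (logits, features)."""

        def __init__(self, in_dim=8, feat_dim=4, num_classes=3):
            super().__init__()
            self.body = nn.Linear(in_dim, feat_dim)
            self.head = nn.Linear(feat_dim, num_classes)
            self.return_feats = True

        def forward(self, x):
            feats = torch.relu(self.body(x))
            return self.head(feats), feats

    # Shadow the module-level `args` with just the fields `test` reads.
    args = types.SimpleNamespace(
        num_models=2, device="cpu", update_bn=False, save=False
    )
    xs, ys = torch.randn(32, 8), torch.randint(0, 3, (32,))
    loader = DataLoader(TensorDataset(xs, ys), batch_size=8)
    data_loader = types.SimpleNamespace(train_loader=loader, val_loader=loader)

    models = [ToyNet() for _ in range(args.num_models)]
    acc, metrics = test(
        models,
        writer=None,
        criterion=nn.CrossEntropyLoss(),
        data_loader=data_loader,
        epoch=0,
    )
    print(acc, metrics)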