in trainers/eval_one_dim_subspaces_multigpu.py [0:0]
def _set_module_attrs(net, **attrs):
    """Set every ``name=value`` pair in *attrs* on each sub-module of *net*."""

    def _assign(module):
        for name, value in attrs.items():
            setattr(module, name, value)

    net.apply(_assign)


def _accumulate_calibration_bins(probs, preds, correct_vec, acc_bins, conf_bins):
    """Add one batch's per-bin correctness and confidence sums, in place.

    Args:
        probs: ``(batch, classes)`` class probabilities of one predictor.
        preds: ``(batch, 1)`` class indices predicted by that SAME predictor;
            the confidence of a sample is ``probs`` evaluated at ``preds``.
        correct_vec: ``(batch, 1)`` boolean tensor, True where ``preds``
            matches the label.
        acc_bins: 1-D tensor; receives the number of correct samples per bin.
        conf_bins: 1-D tensor (same length); receives summed confidence per bin.
    """
    n_bins = acc_bins.numel()
    conf = probs.gather(1, preds).view(-1)
    # Equal-width bins over [0, 1]; a confidence of exactly 1.0 would index
    # one past the end, so it is clamped into the last bin.
    bin_idx = (conf * n_bins).long().clamp(max=n_bins - 1).to(acc_bins.device)
    # ``.to(acc_bins)`` aligns both dtype and device with the accumulators.
    acc_bins.index_add_(0, bin_idx, correct_vec.view(-1).to(acc_bins))
    conf_bins.index_add_(0, bin_idx, conf.to(conf_bins))


def _binned_ece(acc_bins, conf_bins, n_samples):
    """Return ``sum_b |acc_sum_b - conf_sum_b| / n_samples``.

    This equals the usual count-weighted ECE, because the per-bin sample
    count cancels out of ``(count_b / N) * |acc_b - conf_b|``; no separate
    count accumulator is therefore needed.
    """
    return (acc_bins - conf_bins).abs().double().sum().item() / n_samples


def test(models, writer, criterion, data_loader, epoch):
    """Evaluate ``models[0]`` against ``models[1]`` on the validation loader.

    Reads the module-level ``args`` (``t``, ``baset``, ``device``,
    ``update_bn``, ``save``) and ``utils`` (``update_bn``).

    Args:
        models: indexable; ``models[0]`` is the evaluated model and
            ``models[1]`` the reference model ("model 0"). Both are called
            with ``return_feats=True`` and must return ``(output, feats)``.
        writer: object with ``add_scalar``; only used when ``args.save``.
        criterion: callable ``(output, target) -> tensor`` whose ``.item()``
            is the batch loss.
        data_loader: object exposing ``val_loader`` (and ``train_loader``
            when ``args.update_bn`` is set).
        epoch: step value forwarded to ``writer.add_scalar``.

    Returns:
        ``(test_acc, metrics)`` where ``test_acc`` is the accuracy of
        ``models[0]`` and ``metrics`` is a dict with keys ``ece``,
        ``ece_ens``, ``ece_m0``, ``tvdist``, ``ensemble_acc``,
        ``feat_cossim`` and ``m0_acc``.
    """
    model = models[0]
    model_0 = models[1]

    for net, t_value in ((model, args.t), (model_0, args.baset)):
        net.eval()
        net.zero_grad()
        _set_module_attrs(net, return_feats=True, t=t_value, t1=t_value)

    val_loader = data_loader.val_loader

    if args.update_bn:
        utils.update_bn(data_loader.train_loader, model, device=args.device)
        utils.update_bn(data_loader.train_loader, model_0, device=args.device)

    test_loss = 0
    correct = 0
    ensemble_correct = 0
    m0_correct = 0
    tv_dist = 0.0
    feat_cosim = 0

    # Per-bin running sums used for the expected calibration error (ECE).
    n_bins = 20
    acc_bm, conf_bm = torch.zeros(n_bins), torch.zeros(n_bins)
    acc_bm_ens, conf_bm_ens = torch.zeros(n_bins), torch.zeros(n_bins)
    acc_bm_m0, conf_bm_m0 = torch.zeros(n_bins), torch.zeros(n_bins)

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(args.device), target.to(args.device)

            # t/t1 are re-applied before each forward pass, as the original
            # code did; it cannot be told from here whether update_bn or the
            # forward pass modifies them, so this is kept.
            _set_module_attrs(model, t=args.t, t1=args.t)
            output, feats = model(data)
            test_loss += criterion(output, target).item()

            # Index of the max logit is the predicted class.
            pred = output.argmax(dim=1, keepdim=True)
            labels = target.view_as(pred)
            correct_vec = pred.eq(labels)
            correct += correct_vec.sum().item()

            # Reference model ("model 0").
            _set_module_attrs(model_0, t=args.baset, t1=args.baset)
            model_0_output, model_0_feats = model_0(data)
            m0_pred = model_0_output.argmax(dim=1, keepdim=True)
            m0_correct_vec = m0_pred.eq(labels)
            m0_correct += m0_correct_vec.sum().item()

            # Ensemble = mean of the two models' raw outputs.
            ensemble_output = (model_0_output + output) / 2
            ensemble_pred = ensemble_output.argmax(dim=1, keepdim=True)
            ensemble_correct_vec = ensemble_pred.eq(labels)
            ensemble_correct += ensemble_correct_vec.sum().item()

            soft_out = output.softmax(dim=1)
            soft_out_m0 = model_0_output.softmax(dim=1)
            soft_out_ens = ensemble_output.softmax(dim=1)

            # Total variation distance between the two predictive
            # distributions, summed over the batch.
            tv_dist += 0.5 * (soft_out_m0 - soft_out).abs().sum().item()

            # Squared cosine similarity between the returned features.
            feat_cosim += (
                torch.nn.functional.cosine_similarity(feats, model_0_feats, dim=1)
                .pow(2)
                .sum()
                .item()
            )

            # Each predictor's confidence must be read at ITS OWN predicted
            # class. Previously the ensemble and model-0 confidences were
            # taken at the main model's prediction, which mismatched the
            # correctness vectors and skewed ece_ens / ece_m0.
            _accumulate_calibration_bins(
                soft_out, pred, correct_vec, acc_bm, conf_bm
            )
            _accumulate_calibration_bins(
                soft_out_ens,
                ensemble_pred,
                ensemble_correct_vec,
                acc_bm_ens,
                conf_bm_ens,
            )
            _accumulate_calibration_bins(
                soft_out_m0, m0_pred, m0_correct_vec, acc_bm_m0, conf_bm_m0
            )

    n_samples = len(val_loader.dataset)
    # The loss is averaged over batches; everything else over samples.
    test_loss /= len(val_loader)
    test_acc = float(correct) / n_samples
    m0_acc = float(m0_correct) / n_samples
    tv_dist /= n_samples
    feat_cosim /= n_samples
    ensemble_acc = float(ensemble_correct) / n_samples

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: ({test_acc:.4f})\n"
    )

    if args.save:
        writer.add_scalar("test/loss", test_loss, epoch)
        writer.add_scalar("test/acc", test_acc, epoch)

    ece = _binned_ece(acc_bm, conf_bm, n_samples)
    print("ece is", ece)

    ece_ens = _binned_ece(acc_bm_ens, conf_bm_ens, n_samples)
    print("ece_ens is", ece_ens)

    ece_m0 = _binned_ece(acc_bm_m0, conf_bm_m0, n_samples)
    print("ece_m0 is", ece_m0)

    metrics = {
        "ece": ece,
        "ece_ens": ece_ens,
        "ece_m0": ece_m0,
        "tvdist": tv_dist,
        "ensemble_acc": ensemble_acc,
        "feat_cossim": feat_cosim,
        "m0_acc": m0_acc,
    }

    return test_acc, metrics