in trainers/swag.py [0:0]
def test(models, writer, criterion, data_loader, epoch):
    """Draw one SWAG-style weight sample from ``models`` and checkpoint it.

    The first ``args.num_models`` models are flattened into weight vectors,
    from which a single sample is formed as::

        mean + (1/sqrt(2)) * sqrt(diag_var) * z1
             + (1/sqrt(2 * (n - 1))) * sum_i (w_i - mean) * z2_i

    with ``z1`` an element-wise standard normal vector and ``z2_i`` scalar
    standard normal draws. The sample is loaded into ``models[0]`` (which is
    therefore overwritten), ``utils.update_bn`` is run over
    ``data_loader.train_loader``, and the resulting state dict is saved to
    ``<args.tmp_dir>/model_<args.j>.pt``.

    Args:
        models: Indexable collection of at least ``args.num_models`` models
            exposing ``state_dict()``; ``models[0]`` also receives the sample.
        writer: Unused here.
        criterion: Unused here.
        data_loader: Object exposing a ``train_loader`` attribute.
        epoch: Unused here.

    Returns:
        ``(0, {})`` -- no evaluation is performed in this function; the
        accuracy and metrics are placeholders.

    Raises:
        ValueError: If ``args.num_models`` is smaller than 1.
    """
    # NOTE(review): only NumPy's RNG is seeded here, while every random draw
    # below comes from torch's RNG -- confirm torch is seeded elsewhere if
    # reproducible samples are expected.
    np.random.seed(args.seed)

    n = args.num_models
    if n < 1:
        raise ValueError(f"args.num_models must be at least 1, got {n}")

    # Flatten every model into a vector (presumably one 1-D tensor per model;
    # see utils.sd_to_vector). clone() keeps the vectors independent of the
    # live parameters, which are overwritten further down.
    vecs = [
        utils.sd_to_vector(models[i].state_dict()).clone() for i in range(n)
    ]

    # First and second moments of the weights across the n models.
    swa_vec = vecs[0]
    square_vec = vecs[0].pow(2)
    for vec in vecs[1:]:
        swa_vec = swa_vec + vec
        square_vec = square_vec + vec.pow(2)
    swa_vec = swa_vec / n
    square_vec = square_vec / n

    # E[w^2] - E[w]^2 can come out slightly negative through floating-point
    # rounding; clamp it so the square root cannot produce NaNs.
    variance = (square_vec - swa_vec.pow(2)).clamp(min=0.0)
    swa_diag_mult = (
        (1.0 / math.sqrt(2)) * variance.pow(0.5) * torch.randn_like(swa_vec)
    )

    if n > 1:
        # Random combination of the per-model deviations from the mean, with
        # one scalar normal draw per model.
        low_rank_term = (vecs[0] - swa_vec) * torch.randn(1).item()
        for vec in vecs[1:]:
            low_rank_term = (
                low_rank_term + (vec - swa_vec) * torch.randn(1).item()
            )
        low_rank_term = (1.0 / math.sqrt(2 * (n - 1))) * low_rank_term
    else:
        # With a single model every deviation is zero, and the 1/sqrt(n - 1)
        # scale would divide by zero, so the term is simply zero.
        low_rank_term = torch.zeros_like(swa_vec)

    out = swa_vec + swa_diag_mult + low_rank_term

    # Write the sample back into models[0]. utils.vector_to_sd's return value
    # is ignored, so it presumably fills ``final_model_sd`` in place.
    final_model_sd = models[0].state_dict()
    utils.vector_to_sd(out, final_model_sd)
    models[0].load_state_dict(final_model_sd)

    # Presumably recomputes BatchNorm running statistics for the sampled
    # weights -- confirm against utils.update_bn.
    utils.update_bn(data_loader.train_loader, models[0], device=args.device)

    torch.save(
        {
            "epoch": 0,
            "iter": 0,
            "arch": args.model,
            "state_dicts": [models[0].state_dict()],
            "optimizers": None,
            "best_acc1": 0,
            "curr_acc1": 0,
        },
        os.path.join(args.tmp_dir, f"model_{args.j}.pt"),
    )

    # No evaluation happens here; return placeholder results.
    test_acc = 0
    metrics = {}
    return test_acc, metrics