student_specialization/recon_multilayer.py
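import torch
from torch import optim
from copy import deepcopy

# Note: `log`, `utils`, `eval_model`, and `train_model` are assumed to be
# defined elsewhere in this module; only this function is shown here.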
def optimize(train_loader, eval_loader, teacher, student, loss_func, train_stats_op, eval_stats_op, args, lrs):
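    """Train `student` to mimic `teacher`, collecting per-epoch statistics.

    `lrs` maps epoch index -> learning rate; `lrs[0]` sets the initial rate
    and later keys change the rate when that epoch is reached. Returns the
    list of stats dicts (one pre-training entry, then one per epoch).
    """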
    if args.optim_method == "sgd":
        optimizer = optim.SGD(student.parameters(), lr=lrs[0], momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optim_method == "adam":
        optimizer = optim.Adam(student.parameters(), lr=lrs[0], weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Unknown optim method: {args.optim_method}")
    stats = []
    last_total_diff = None

    log.info("Before optimization: ")
    if args.normalize:
        student.normalize()
    init_student = deepcopy(student)
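    # Evaluate once before any training; this entry is recorded as iter = -1.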
    eval_stats = eval_model(-1, eval_loader, teacher, student, eval_stats_op)
    eval_stats["iter"] = -1
    stats.append(eval_stats)
    for i in range(args.num_epoch):
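        # Apply the learning-rate schedule: a key in `lrs` at this epoch
        # switches every parameter group to the new rate.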
        if i in lrs:
            lr = lrs[i]
            log.info(f"[{i}]: lr = {lr}")
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        train_stats = train_model(i, train_loader, teacher, student, train_stats_op, loss_func, optimizer, args)

        this_stats = dict(iter=i)
        this_stats.update(train_stats)
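        # train_model can request early termination by putting an "exit" key
        # in its stats; record the final entry and stop.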
if "exit" in train_stats:
stats.append(this_stats)
return stats
        eval_stats = eval_model(i, eval_loader, teacher, student, eval_stats_op)
        this_stats.update(eval_stats)

        log.info(f"[{i}]: Size of stats: {utils.count_size(this_stats) / 2 ** 20:.2f} MB")

        stats.append(this_stats)
log.info("")
log.info("")
        if args.regen_dataset_each_epoch:
            train_loader.dataset.regenerate()
        if args.num_epoch_save_summary > 0 and i % args.num_epoch_save_summary == 0:
            # Only store the starting and latest stats.
            end_stats = [stats[0], stats[-1]]
            torch.save(end_stats, "summary.pth")
    # Save the summary at the end.
    end_stats = [stats[0], stats[-1]]
    torch.save(end_stats, "summary.pth")

    return stats
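# A minimal usage sketch (hypothetical values; the real caller builds these
# objects from the command-line config elsewhere in this file):
#
#   lrs = {0: 0.01, 40: 0.001}  # start at 0.01, drop 10x at epoch 40
#   stats = optimize(train_loader, eval_loader, teacher, student, loss_func,
#                    train_stats_op, eval_stats_op, args, lrs)
#   torch.save(stats, "stats.pth")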