in luckmatter/recon_multilayer.py [0:0]
def optimize(train_loader, eval_loader, teacher, student, loss_func, active_nodes, args):
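    """Distill the teacher into the student with SGD.

    Runs args.num_epoch epochs over train_loader, matching the student's
    output "y" to the teacher's detached output under loss_func. Returns a
    list of per-epoch statistics collected by eval_models, and prints a
    comparison of teacher/student correlations before and after training.
    Exits early with an {"exit": "nan"} entry if the loss becomes NaN.
    """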
    optimizer = optim.SGD(student.parameters(), lr = args.lr[0], momentum=args.momentum, weight_decay=args.weight_decay)
    # optimizer = optim.SGD(student.parameters(), lr = 1e-2, momentum=0.9)
    # optimizer = optim.Adam(student.parameters(), lr = 0.0001)

    stats = []
    last_total_diff = None

    print("Before optimization: ")

    if args.normalize:
        student.normalize()

    init_student = deepcopy(student)

    # Match response
    _, init_corrs_train, _, _ = getCorrs(train_loader, teacher, student, args)
    _, init_corrs_eval, _, _ = getCorrs(eval_loader, teacher, student, args)
    def add_prefix(prefix, d):
        return { prefix + k : v for k, v in d.items() }
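    # Collect prefixed train/eval metrics for epoch i (i = -1 means "before training").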
    def get_stats(i):
        teacher.eval()
        student.eval()

        print("Train stats:")
        train_stats = add_prefix("train_", eval_models(i, train_loader, teacher, student, loss_func, args, init_corrs_train, init_student, active_nodes=active_nodes))
        print("Eval stats:")
        eval_stats = add_prefix("eval_", eval_models(i, eval_loader, teacher, student, loss_func, args, init_corrs_eval, init_student, active_nodes=active_nodes))
        train_stats.update(eval_stats)
        return train_stats

    stats.append(get_stats(-1))
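    # Main training loop: one pass over train_loader per epoch, with optional
    # per-epoch learning-rate changes taken from the args.lr schedule.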
    for i in range(args.num_epoch):
        teacher.eval()
        student.train()

        if i in args.lr:
            lr = args.lr[i]
            print(f"[{i}]: lr = {lr}")
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        # sample data from Gaussian distribution.
        # xsel = Variable(X.gather(0, sel))
        for x, y in train_loader:
            optimizer.zero_grad()
            if not args.use_cnn:
                x = x.view(x.size(0), -1)
            x = x.cuda()
            output_t = teacher(x)
            output_s = student(x)
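            # Distillation loss: match the student's output to the teacher's
            # output, with gradients blocked through the teacher via detach().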
            err = loss_func(output_s["y"], output_t["y"].detach())
            if torch.isnan(err).item():
                stats.append(dict(exit="nan"))
                return stats

            err.backward()
            optimizer.step()

            if args.normalize:
                student.normalize()
        stats.append(get_stats(i))

        if args.regen_dataset_each_epoch:
            train_loader.dataset.regenerate()
print("After optimization: ")
_, final_corrs, _, _ = getCorrs(eval_loader, teacher, student, args)
result = compareCorrIndices(init_corrs_train, final_corrs)
if args.json_output:
print("json_output: " + json.dumps(result))
print_corrs(result, active_nodes=active_nodes, first_n=5)
return stats
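# Illustrative call (a sketch, not part of the original file). It assumes
# `teacher` and `student` are networks returning dicts with a "y" entry,
# the loaders yield (x, y) batches, `active_nodes` has been computed
# elsewhere in the script, and `args` provides the fields used above
# (lr dict keyed by epoch, momentum, weight_decay, num_epoch, normalize,
# use_cnn, regen_dataset_each_epoch, json_output):
#
#   stats = optimize(train_loader, eval_loader, teacher, student, loss_func,
#                    active_nodes, args)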