# Excerpt: optimize() — from luckmatter/recon_multilayer.py


def optimize(train_loader, eval_loader, teacher, student, loss_func, active_nodes, args):
    """Train `student` to match `teacher` outputs; return per-epoch stats.

    The student is optimized with SGD against the teacher's detached
    predictions. Correlation/eval statistics are collected once before
    training (epoch index -1) and again after every epoch. If the loss
    becomes NaN, training stops early and a ``dict(exit="nan")`` marker is
    appended to the returned list.
    """
    optimizer = optim.SGD(
        student.parameters(),
        lr=args.lr[0],
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    stats = []

    last_total_diff = None
    print("Before optimization: ")

    if args.normalize:
        student.normalize()

    # Snapshot of the (possibly normalized) student before any updates.
    init_student = deepcopy(student)

    # Baseline teacher/student correlations on both splits; these serve as
    # reference points for eval_models during training.
    _, init_corrs_train, _, _ = getCorrs(train_loader, teacher, student, args)
    _, init_corrs_eval, _, _ = getCorrs(eval_loader, teacher, student, args)

    def add_prefix(prefix, d):
        return {prefix + key: value for key, value in d.items()}

    def get_stats(epoch):
        # Evaluate on both splits and merge into one flat dict with
        # "train_" / "eval_" key prefixes.
        teacher.eval()
        student.eval()
        print("Train stats:")
        combined = add_prefix(
            "train_",
            eval_models(epoch, train_loader, teacher, student, loss_func, args,
                        init_corrs_train, init_student, active_nodes=active_nodes))
        print("Eval stats:")
        combined.update(add_prefix(
            "eval_",
            eval_models(epoch, eval_loader, teacher, student, loss_func, args,
                        init_corrs_eval, init_student, active_nodes=active_nodes)))
        return combined

    stats.append(get_stats(-1))

    for epoch in range(args.num_epoch):
        teacher.eval()
        student.train()

        # args.lr maps epoch index -> learning rate; apply scheduled drops.
        if epoch in args.lr:
            lr = args.lr[epoch]
            print(f"[{epoch}]: lr = {lr}")
            for group in optimizer.param_groups:
                group["lr"] = lr

        for x, _labels in train_loader:
            optimizer.zero_grad()
            if not args.use_cnn:
                # Flatten each sample to a vector for the fully-connected nets.
                x = x.view(x.size(0), -1)
            x = x.cuda()

            # Teacher forward first, then student — same order as the
            # original code so any RNG consumption is identical.
            teacher_out = teacher(x)
            student_out = student(x)

            err = loss_func(student_out["y"], teacher_out["y"].detach())
            if torch.isnan(err).item():
                # Loss diverged; record a marker and abort training.
                stats.append(dict(exit="nan"))
                return stats
            err.backward()
            optimizer.step()
            if args.normalize:
                student.normalize()

        stats.append(get_stats(epoch))
        if args.regen_dataset_each_epoch:
            train_loader.dataset.regenerate()

    print("After optimization: ")
    _, final_corrs, _, _ = getCorrs(eval_loader, teacher, student, args)

    # NOTE(review): final_corrs is computed on eval_loader but compared
    # against the *train* baseline (init_corrs_train) — confirm this
    # asymmetry is intentional; init_corrs_eval looks like the matching
    # baseline.
    result = compareCorrIndices(init_corrs_train, final_corrs)
    if args.json_output:
        print("json_output: " + json.dumps(result))
    print_corrs(result, active_nodes=active_nodes, first_n=5)

    return stats