def optimize()

in student_specialization/recon_multilayer.py [0:0]


def optimize(train_loader, eval_loader, teacher, student, loss_func, train_stats_op, eval_stats_op, args, lrs):
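    """Run the training/evaluation loop for `student`.

    `lrs` is a dict mapping epoch index -> learning rate; `lrs[0]` seeds the optimizer
    and later keys update it when that epoch is reached. Each epoch runs `train_model`
    followed by `eval_model`, and the collected per-epoch stats dicts are returned.
    """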
    if args.optim_method == "sgd":
        optimizer = optim.SGD(student.parameters(), lr=lrs[0], momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optim_method == "adam":
        optimizer = optim.Adam(student.parameters(), lr=lrs[0], weight_decay=args.weight_decay)
    else:
        raise RuntimeError(f"Unknown optim method: {args.optim_method}")

    stats = []

    last_total_diff = None
    log.info("Before optimization: ")

    if args.normalize:
        student.normalize()
    
    init_student = deepcopy(student)

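    # Evaluate the untrained student once so stats[0] captures the initial state (iter = -1).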
    eval_stats = eval_model(-1, eval_loader, teacher, student, eval_stats_op)
    eval_stats["iter"] = -1
    stats.append(eval_stats)

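    # Main loop: apply the learning-rate schedule, train for one epoch, then evaluate.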
    for i in range(args.num_epoch):
        if i in lrs:
            lr = lrs[i]
            log.info(f"[{i}]: lr = {lr}")
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        train_stats = train_model(i, train_loader, teacher, student, train_stats_op, loss_func, optimizer, args)

        this_stats = dict(iter=i)
        this_stats.update(train_stats)

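        # train_model can request early termination by including an "exit" key in its stats.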
        if "exit" in train_stats:
            stats.append(this_stats)
            return stats

        eval_stats = eval_model(i, eval_loader, teacher, student, eval_stats_op)

        this_stats.update(eval_stats)
        log.info(f"[{i}]: Bytesize of stats: {utils.count_size(this_stats) / 2 ** 20} MB")

        stats.append(this_stats)

        log.info("")
        log.info("")

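        # Optionally regenerate the training dataset so the next epoch sees fresh samples.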
        if args.regen_dataset_each_epoch:
            train_loader.dataset.regenerate()

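        # Periodic checkpoint: every args.num_epoch_save_summary epochs, write the first and latest stats to disk.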
        if args.num_epoch_save_summary > 0 and i % args.num_epoch_save_summary == 0:
            # Only store starting and end stats.
            end_stats = [stats[0], stats[-1]]
            torch.save(end_stats, "summary.pth")

    # Save the summary at the end.
    end_stats = [stats[0], stats[-1]]
    torch.save(end_stats, "summary.pth")

    return stats
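
The schedule argument `lrs` follows a simple convention that the loop above depends on: it is a dict keyed by epoch index, key 0 must be present to seed the optimizer, and any later key overwrites every param_group's learning rate when that epoch begins. A minimal, self-contained sketch of that convention (the schedule values and the dummy parameter are illustrative, not from the repo):

import torch
from torch import optim

# Illustrative step-decay schedule: start at 1e-2, drop at epochs 30 and 60.
lrs = {0: 1e-2, 30: 1e-3, 60: 1e-4}

# Dummy parameter standing in for student.parameters().
params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = optim.SGD(params, lr=lrs[0], momentum=0.9)

for epoch in range(90):
    if epoch in lrs:  # same check as in optimize()
        for param_group in optimizer.param_groups:
            param_group['lr'] = lrs[epoch]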