def run()

in student_specialization/recon_two_layer.py


# Module-level imports assumed by this excerpt (the original snippet starts at the
# function definition). Helpers such as init(), convert(), forward(), backward(),
# normalize(), after_epoch_eval() and the `log` logger are defined elsewhere in
# recon_two_layer.py.
import math
import random

import torch


def run(cfg):
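    """Teacher-student training loop.

    A fixed two-layer teacher (W1_t, W2_t) provides the target outputs; a two-layer
    student (W1_s, W2_s) is trained with (stochastic) gradient descent to match them.
    Returns the list of per-epoch statistics produced by after_epoch_eval().
    """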
    # Teacher (W1_t, W2_t) and student (W1_s, W2_s) weights.
    W1_t, W2_t, W1_s, W2_s = init(cfg)

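    # Gram matrix of the teacher's top-layer weights (kept only for the optional
    # debug logging below).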
    Btt = W2_t @ W2_t.t()
    # log.info(Btt)

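    # Held-out evaluation inputs: N_eval i.i.d. Gaussian samples scaled by data_std.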
    X_eval = torch.randn(cfg.N_eval, cfg.d) * cfg.data_std

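    # Theory-suggested training set: for every teacher node, place samples on its
    # activation boundary and just on either side of it, instead of sampling
    # uniformly from the Gaussian input distribution.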
    if cfg.theory_suggest_train:
        X_train = []
        for i in range(cfg.m):
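            # ~N_train / (3m) Gaussian points in homogeneous coordinates
            # (the last entry acts as the bias and is set to 1).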
            data = torch.randn(math.ceil(cfg.N_train / (cfg.m * 3)), cfg.d + 1).double().cuda() * cfg.data_std
            data[:, -1] = 1
            # Teacher node i's weight vector (in homogeneous coordinates).
            w = W1_t[:, i]
            # Project the samples onto the node's boundary plane {z : w @ z = 0}.
            data = data - torch.ger(data @ w, w) / w.pow(2).sum()
            # Rescale so the homogeneous coordinate equals 1, then drop it.
            data = data[:, :-1] / data[:, -1][:, None]

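            # Shift each boundary point by +/- alpha along w (bias entry dropped) so
            # the samples straddle teacher node i's activation boundary.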
            alpha = torch.rand(data.size(0)).double().cuda() * cfg.theory_suggest_sigma + cfg.theory_suggest_mean
            data_plus = data + torch.ger(alpha, w[:-1]) 
            data_minus = data - torch.ger(alpha, w[:-1])

            X_train.extend([data, data_plus, data_minus])
            # X_train.extend([data_plus, data_minus])

        X_train = torch.cat(X_train, dim=0)
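        # Rescale every sample to norm data_std * sqrt(d), the typical norm of a
        # d-dimensional Gaussian input.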
        X_train /= X_train.norm(dim=1)[:, None]
        X_train *= cfg.data_std * math.sqrt(cfg.d)
        print(f"Use dataset from theory: N_train = {X_train.size(0)}") 
        cfg.N_train = X_train.size(0)
    else:
        X_train = torch.randn(cfg.N_train, cfg.d) * cfg.data_std

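    # convert() is defined elsewhere in recon_two_layer.py (e.g. dtype/device
    # conversion of the datasets before training).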
    X_train, X_eval = convert(X_train, X_eval)

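    # Per-node norms of the teacher's top-layer (fan-out) weights.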
    t_norms = W2_t.norm(dim=1)
    print(f"teacher norm: {t_norms}")

    init_stat = dict(W1_t=W1_t.cpu(), W2_t=W2_t.cpu(), W1_s=W1_s.cpu(), W2_s=W2_s.cpu())
    init_stat.update(after_epoch_eval(-1, X_train, X_eval, W1_t, W2_t, W1_s, W2_s, cfg))

    stats = []
    stats.append(init_stat)

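    # Index pool for minibatch sampling; lr is copied to a local so it can be decayed.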
    train_set_sel = list(range(cfg.N_train))
    lr = cfg.lr

    for i in range(cfg.num_epoch):
        W1_s_old = W1_s.clone()
        W2_s_old = W2_s.clone()

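        # Step schedule: halve the learning rate every cfg.lr_reduction epochs.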
        if cfg.lr_reduction > 0 and i > 0 and (i % cfg.lr_reduction == 0):
            lr = lr / 2
            log.info(f"{i}: reducing learning rate: {lr}")

        for j in range(cfg.num_iter_per_epoch):
            if cfg.use_sgd:
                # Randomly pick a minibatch (sampled with replacement).
                sel = random.choices(train_set_sel, k=cfg.batchsize)
                X = X_train[sel, :].clone()
            else:
                # Full-batch gradient descent on the whole training set.
                X = X_train

            # Teacher's output. 
            X_aug, h1_t, h1_ng_t, output_t = forward(X, W1_t, W2_t, nonlinear=cfg.nonlinear)

            # Student's output.
            X_aug, h1_s, h1_ng_s, output_s = forward(X, W1_s, W2_s, nonlinear=cfg.nonlinear)

            # Backpropagation. 
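            # g2 is the teacher-student output gap, i.e. the negative gradient of the
            # squared loss w.r.t. output_s, so the updates below are descent steps.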
            g2 = output_t - output_s
            deltaW1_s, deltaW2_s, _ = backward(X_aug, W1_s, W2_s, h1_s, h1_ng_s, g2, nonlinear=cfg.nonlinear)
            deltaW1_s /= X.size(0)
            deltaW2_s /= X.size(0)

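            # Apply the updates (unless the corresponding layer is frozen).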
            if not cfg.feature_fixed:
                W1_s += lr * deltaW1_s
                if cfg.normalize:
                    normalize(W1_s)

            if not cfg.top_layer_fixed:
                W2_s += lr * deltaW2_s

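            # With no_bias, keep the bias row of each weight matrix pinned at zero.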
            if cfg.no_bias:
                W1_s[-1, :] = 0
                W2_s[-1, :] = 0


        stat = after_epoch_eval(i, X_train, X_eval, W1_t, W2_t, W1_s, W2_s, cfg)
        stats.append(stat)

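        # Optionally resample a fresh Gaussian training set for the next epoch.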
        if cfg.regen_dataset:
            X_train = torch.randn(cfg.N_train, cfg.d) * cfg.data_std
            X_train = convert(X_train)[0]

        log.info(f"|W1|={W1_s.norm()}, |W2|={W2_s.norm()}")
        log.info(f"|deltaW1|={(W1_s - W1_s_old).norm()}, |deltaW2|={(W2_s - W2_s_old).norm()}")

    return stats