in student_specialization/recon_two_layer.py [0:0]
def run(cfg):
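    """Train a two-layer student to reconstruct a two-layer teacher; return a list of per-epoch stats."""
    # Teacher (_t) and student (_s) weights for both layers.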
W1_t, W2_t, W1_s, W2_s = init(cfg)
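    # Gram matrix of the teacher's top-layer weights (unused below; logging disabled).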
Btt = W2_t @ W2_t.t()
# log.info(Btt)
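    # Held-out evaluation inputs, drawn i.i.d. Gaussian.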
X_eval = torch.randn(cfg.N_eval, cfg.d) * cfg.data_std
if cfg.theory_suggest_train:
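        # Theory-suggested training set: for each teacher node, sample points on its
        # plane and add symmetric offsets on both sides of that plane.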
X_train = []
for i in range(cfg.m):
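            # Sample in homogeneous coordinates: d features plus a bias column set to 1 below.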
data = torch.randn(math.ceil(cfg.N_train / (cfg.m * 3)), cfg.d + 1).double().cuda() * cfg.data_std
data[:, -1] = 1
            # Project each sample onto the plane of teacher node i (w^T x = 0).
            w = W1_t[:, i]
            data = data - torch.ger(data @ w, w) / w.pow(2).sum()
            # Back to bias = 1 coordinates: divide by the last (bias) component and drop it.
            data = data[:, :-1] / data[:, -1][:, None]
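            # Shift each in-plane point by a random amount alpha along w, on both sides of the plane.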
alpha = torch.rand(data.size(0)).double().cuda() * cfg.theory_suggest_sigma + cfg.theory_suggest_mean
data_plus = data + torch.ger(alpha, w[:-1])
data_minus = data - torch.ger(alpha, w[:-1])
X_train.extend([data, data_plus, data_minus])
# X_train.extend([data_plus, data_minus])
X_train = torch.cat(X_train, dim=0)
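        # Rescale every sample to norm data_std * sqrt(d), the typical norm of a Gaussian input.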
        X_train /= X_train.norm(dim=1)[:, None]
X_train *= cfg.data_std * math.sqrt(cfg.d)
print(f"Use dataset from theory: N_train = {X_train.size(0)}")
cfg.N_train = X_train.size(0)
else:
X_train = torch.randn(cfg.N_train, cfg.d) * cfg.data_std
X_train, X_eval = convert(X_train, X_eval)
t_norms = W2_t.norm(dim=1)
print(f"teacher norm: {t_norms}")
init_stat = dict(W1_t=W1_t.cpu(), W2_t=W2_t.cpu(), W1_s=W1_s.cpu(), W2_s=W2_s.cpu())
init_stat.update(after_epoch_eval(-1, X_train, X_eval, W1_t, W2_t, W1_s, W2_s, cfg))
    stats = [init_stat]
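    # Index pool used for minibatch sampling.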
train_set_sel = list(range(cfg.N_train))
lr = cfg.lr
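    # Main training loop.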
for i in range(cfg.num_epoch):
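        # Snapshot the student weights to log per-epoch movement below.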
W1_s_old = W1_s.clone()
W2_s_old = W2_s.clone()
if cfg.lr_reduction > 0 and i > 0 and (i % cfg.lr_reduction == 0):
lr = lr / 2
log.info(f"{i}: reducing learning rate: {lr}")
for j in range(cfg.num_iter_per_epoch):
            if cfg.use_sgd:
                # SGD: sample a random minibatch (with replacement).
                sel = random.choices(train_set_sel, k=cfg.batchsize)
                X = X_train[sel, :].clone()
            else:
                # Full-batch gradient descent.
                X = X_train
# Teacher's output.
X_aug, h1_t, h1_ng_t, output_t = forward(X, W1_t, W2_t, nonlinear=cfg.nonlinear)
# Student's output.
X_aug, h1_s, h1_ng_s, output_s = forward(X, W1_s, W2_s, nonlinear=cfg.nonlinear)
            # Backpropagation: the error signal at the output is (teacher - student).
            g2 = output_t - output_s
            deltaW1_s, deltaW2_s, _ = backward(X_aug, W1_s, W2_s, h1_s, h1_ng_s, g2, nonlinear=cfg.nonlinear)
            # Average the gradients over the batch.
            deltaW1_s /= X.size(0)
            deltaW2_s /= X.size(0)
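            # Update the student weights; either layer can be kept fixed via cfg.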
if not cfg.feature_fixed:
W1_s += lr * deltaW1_s
if cfg.normalize:
normalize(W1_s)
if not cfg.top_layer_fixed:
W2_s += lr * deltaW2_s
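            # Zero out the bias rows when biases are disabled.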
if cfg.no_bias:
W1_s[-1, :] = 0
W2_s[-1, :] = 0
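        # Evaluate after this epoch and record statistics.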
stat = after_epoch_eval(i, X_train, X_eval, W1_t, W2_t, W1_s, W2_s, cfg)
stats.append(stat)
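        # Optionally draw a fresh training set for the next epoch.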
if cfg.regen_dataset:
X_train = torch.randn(cfg.N_train, cfg.d) * cfg.data_std
X_train = convert(X_train)[0]
log.info(f"|W1|={W1_s.norm()}, |W2|={W2_s.norm()}")
log.info(f"|deltaW1|={(W1_s - W1_s_old).norm()}, |deltaW2|={(W2_s - W2_s_old).norm()}")
return stats