in student_specialization/recon_two_layer.py [0:0]
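# Relies on module-level names defined elsewhere in the repo:
# torch, log (a logger), init_separate_w(), convert(), normalize().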
def init(cfg):
    d = cfg.d
    m = cfg.m
    n = int(cfg.m * cfg.multi)
    c = cfg.c
    log.info(f"d = {d}, m = {m}, n = {n}, c = {c}")
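    # Teacher network: W1_t is (d+1) x m and W2_t is (m+1) x c, with the
    # last row of each holding the bias. init_separate_w (defined elsewhere)
    # presumably supplies well-separated weight vectors.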
    W1_t = torch.randn(d + 1, m).cuda() * cfg.teacher_scale
    W1_t[:-1, :] = torch.from_numpy(init_separate_w(m, d, cfg.choices)).t()
    W1_t[-1, :] = cfg.bias
    W2_t = torch.randn(m + 1, c).cuda() * cfg.teacher_scale
    W2_t[:-1, :] = torch.from_numpy(init_separate_w(c, m, cfg.choices)).t()
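    # Optional teacher strength decay: scale node i's outgoing weights by
    # 1 / (i + 1)^decay, so later teacher nodes are progressively weaker;
    # node 0 and the bias row are left untouched.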
    if cfg.teacher_strength_decay > 0:
        for i in range(1, m):
            W2_t[i, :] /= pow(i + 1, cfg.teacher_strength_decay)
    W2_t[-1, :] = cfg.bias
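    # Student network: same layout with n = m * cfg.multi hidden units,
    # small random weights (cfg.student_scale) and zero bias.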
    W1_s = torch.randn(d + 1, n).cuda() * cfg.student_scale
    # Bias = 0
    W1_s[-1, :] = 0
    W2_s = torch.randn(n + 1, c).cuda() * cfg.student_scale
    # Bias = 0
    W2_s[-1, :] = 0
    # "adv": deliberately flip each student weight away from the last
    # teacher node; "help": flip it toward the last teacher node.
    if cfg.adv_init == "adv":
        for i in range(n):
            if (W1_t[:-1, -1] * W1_s[:-1, i]).sum().item() > 0:
                W1_s[:-1, i] *= -1
    elif cfg.adv_init == "help":
        for i in range(n):
            if (W1_t[:-1, -1] * W1_s[:-1, i]).sum().item() < 0:
                W1_s[:-1, i] *= -1
    elif cfg.adv_init != "none":
        raise RuntimeError(f"Invalid adv_init: {cfg.adv_init}")
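    # convert() and normalize() are helpers defined elsewhere; the teacher's
    # first layer is always normalized, the student's only if cfg.normalize.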
    W1_t, W2_t, W1_s, W2_s = convert(W1_t, W2_t, W1_s, W2_s)
    normalize(W1_t)
    if cfg.normalize:
        normalize(W1_s)
    return W1_t, W2_t, W1_s, W2_s
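# A minimal sketch of how init() might be invoked, assuming a plain
# namespace config; the field names mirror those read above, but every
# value below is illustrative, not a repo default:
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(
#         d=20, m=10, multi=5, c=10, choices=None,
#         teacher_scale=1.0, student_scale=1e-2, bias=0.0,
#         teacher_strength_decay=0.0, adv_init="none", normalize=True,
#     )
#     W1_t, W2_t, W1_s, W2_s = init(cfg)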