# init() — excerpt from student_specialization/recon_two_layer.py


def init(cfg):
    """Build teacher and student two-layer weights on the GPU.

    Teacher matrices (W1_t, W2_t) get structured columns from
    init_separate_w plus a fixed bias row; student matrices (W1_s, W2_s)
    are small random Gaussians with zero bias. Optionally decays teacher
    second-layer rows and aligns/anti-aligns the student first layer with
    the last teacher unit. Returns the converted tuple
    (W1_t, W2_t, W1_s, W2_s).
    """
    d = cfg.d
    m = cfg.m
    n = int(cfg.m * cfg.multi)
    c = cfg.c

    log.info(f"d = {d}, m = {m}, n = {n}, c = {c}")

    # Teacher first layer: structured weights on top, constant bias row last.
    W1_t = torch.randn(d + 1, m).cuda() * cfg.teacher_scale
    W1_t[:-1, :] = torch.from_numpy(init_separate_w(m, d, cfg.choices)).t()
    W1_t[-1, :] = cfg.bias

    # Teacher second layer: same construction, m inputs -> c outputs.
    W2_t = torch.randn(m + 1, c).cuda() * cfg.teacher_scale
    W2_t[:-1, :] = torch.from_numpy(init_separate_w(c, m, cfg.choices)).t()

    # Optionally weaken later teacher units; row 0 keeps full strength
    # (its divisor would be 1 ** decay == 1 anyway).
    if cfg.teacher_strength_decay > 0:
        for row in range(1, m):
            W2_t[row, :] = W2_t[row, :] / ((row + 1) ** cfg.teacher_strength_decay)

    W2_t[-1, :] = cfg.bias

    # Student layers: plain Gaussian init, bias rows zeroed.
    W1_s = torch.randn(d + 1, n).cuda() * cfg.student_scale
    W1_s[-1, :] = 0

    W2_s = torch.randn(n + 1, c).cuda() * cfg.student_scale
    W2_s[-1, :] = 0

    # Deliberately orient each student unit relative to the LAST teacher
    # unit: "adv" pushes them away (flip positive dot products), "help"
    # pulls them toward it (flip negative dot products), "none" leaves
    # the random init alone.
    if cfg.adv_init in ("adv", "help"):
        flip_positive = cfg.adv_init == "adv"
        last_teacher = W1_t[:-1, -1]
        for col in range(n):
            dot = (last_teacher * W1_s[:-1, col]).sum().item()
            should_flip = dot > 0 if flip_positive else dot < 0
            if should_flip:
                W1_s[:-1, col] *= -1
    elif cfg.adv_init != "none":
        raise RuntimeError(f"Invalid adv_init: {cfg.adv_init}")

    W1_t, W2_t, W1_s, W2_s = convert(W1_t, W2_t, W1_s, W2_s)

    # Teacher first layer is always normalized; student only on request.
    normalize(W1_t)

    if cfg.normalize:
        normalize(W1_s)

    return W1_t, W2_t, W1_s, W2_s