def init_from_data_()

in sparse_autoencoder/train.py [0:0]


def init_from_data_(
    ae,
    stats_acts_sample,
    comms,
    *,
    n_median_samples: int = 32768,
    n_probe_samples: int = 256,
):
    """Initialize ``ae`` in place from a sample of activations.

    Two steps run in this order. The order matters: step 2 reads the
    ``pre_bias`` written by step 1 and draws from the global torch RNG.

    1. ``ae.pre_bias`` is set to the geometric median of the first
       ``n_median_samples`` rows of ``stats_acts_sample``. The median is
       computed in float32 on CPU, moved to CUDA as float32, and then passed
       to ``comms.all_broadcast``.
    2. ``n_probe_samples`` random vectors are normalized to unit norm,
       shifted by ``pre_bias``, passed to ``comms.all_broadcast``, and run
       through ``ae``. ``ae.encoder.weight`` is then divided by the mean
       norm of the reconstructions, measured relative to ``pre_bias``.

    Args:
        ae: Autoencoder module. It must expose ``pre_bias``,
            ``encoder.weight`` and ``d_model``. Calling it must return a
            tuple whose first element is the reconstruction.
        stats_acts_sample: Tensor of sample activations. Its dtype is used
            for the probe vectors. Presumably shaped (n_samples, d_model);
            confirm against caller.
        comms: Communication helper exposing ``all_broadcast``.
            NOTE(review): presumably broadcasts in place so all ranks hold
            identical values; confirm.
        n_median_samples: Maximum number of leading rows of
            ``stats_acts_sample`` fed to the geometric median. The default
            reproduces the previous hard-coded value.
        n_probe_samples: Number of random probe vectors used to calibrate
            the encoder scale. The default reproduces the previous
            hard-coded value.

    Raises:
        ValueError: If the measured mean reconstruction norm is zero, NaN or
            infinite. Dividing by it would otherwise silently corrupt the
            encoder weights.
    """
    from geom_median.torch import compute_geometric_median

    # Step 1: pre_bias <- geometric median of (up to) the first
    # n_median_samples rows, computed in float32 on CPU.
    median_input = stats_acts_sample[:n_median_samples].float().cpu()
    ae.pre_bias.data = compute_geometric_median(median_input).median.cuda().float()
    comms.all_broadcast(ae.pre_bias.data)

    # Step 2: encoder initialization. In ablations there was no clear
    # evidence this is beneficial; it is kept to exactly match the internal
    # codebase, so do not reorder the RNG / forward calls below.
    d_model = ae.d_model
    with torch.no_grad():
        x = torch.randn(n_probe_samples, d_model).cuda().to(stats_acts_sample.dtype)
        # Unit-normalize each probe row, then offset it by pre_bias.
        x /= x.norm(dim=-1, keepdim=True)
        x += ae.pre_bias.data
        # Keep probes identical across ranks (see NOTE on `comms` above).
        comms.all_broadcast(x)
        recons, _ = ae(x)
        recons_norm = (recons - ae.pre_bias.data).norm(dim=-1).mean()

        scale = recons_norm.item()
        # The chained comparison is False for 0, negatives, inf and NaN,
        # so any degenerate scale is rejected before it reaches the weights.
        if not (0.0 < scale < float("inf")):
            raise ValueError(
                f"init_from_data_: mean reconstruction norm is {scale}; "
                "cannot rescale encoder weights"
            )
        ae.encoder.weight.data /= scale

        # Diagnostics: mean probe norm, and mean reconstruction norm after
        # rescaling (this runs one extra forward pass).
        print0("x norm", x.norm(dim=-1).mean().item())
        print0("out norm", (ae(x)[0] - ae.pre_bias.data).norm(dim=-1).mean().item())