in sparse_autoencoder/train.py [0:0]
def init_from_data_(ae, stats_acts_sample, comms, *, n_median_samples=32768, n_probe=256):
    """Initialize ``ae`` in place from a sample of activations.

    Two steps are performed:

    1. ``ae.pre_bias`` is set to the geometric median of the first
       ``n_median_samples`` rows of ``stats_acts_sample``. The median is computed
       on CPU in float32, then stored as float32 on the device ``ae.pre_bias``
       already lives on. It is passed to ``comms.all_broadcast``.
    2. ``ae.encoder.weight`` is rescaled. ``n_probe`` random unit-norm vectors,
       shifted by ``pre_bias``, are run through ``ae``. The encoder weight is then
       divided by the mean norm of ``reconstruction - pre_bias``. Presumably this
       aims for roughly unit-norm reconstructions of unit-norm inputs; the exact
       effect depends on ``ae``'s forward pass, which is not visible here.

    Args:
        ae: the autoencoder. It must expose ``pre_bias``, ``encoder.weight`` and
            ``d_model``. ``ae(x)`` must return a tuple whose first element is the
            reconstruction.
        stats_acts_sample: activation sample, sliced along its first dimension.
            It is assumed to be (n, d_model) -- TODO confirm against caller. Its
            dtype is used for the probe inputs.
        comms: communicator whose ``all_broadcast`` is called on ``pre_bias`` and
            on the probe inputs. Presumably this makes every rank use identical
            values -- confirm against the comms implementation.
        n_median_samples: maximum number of leading rows used for the geometric
            median. The default (32768) is the previously hard-coded value.
        n_probe: number of random probe vectors. The default (256) is the
            previously hard-coded value.

    Raises:
        ValueError: if the mean reconstruction norm is zero, NaN or infinite.
            Dividing by it would silently corrupt the encoder weights.
    """
    from geom_median.torch import compute_geometric_median

    # Follow the device the parameters already live on instead of the "current"
    # CUDA device. This is the same device in the usual one-GPU-per-rank setup.
    device = ae.pre_bias.device

    median = compute_geometric_median(stats_acts_sample[:n_median_samples].float().cpu()).median
    ae.pre_bias.data = median.to(device=device, dtype=torch.float32)
    comms.all_broadcast(ae.pre_bias.data)

    # encoder initialization (note: in our ablations we couldn't find clear evidence
    # that this is beneficial, this is just to ensure exact match with internal codebase)
    d_model = ae.d_model
    with torch.no_grad():
        # Sample on CPU with the default generator and only then move, so the random
        # stream is the same as sampling-then-.cuda().
        x = torch.randn(n_probe, d_model).to(device).to(stats_acts_sample.dtype)
        x /= x.norm(dim=-1, keepdim=True)  # random unit-norm directions
        x += ae.pre_bias.data  # ... shifted by the freshly computed pre_bias
        comms.all_broadcast(x)
        recons, _ = ae(x)
        recons_norm = (recons - ae.pre_bias.data).norm(dim=-1).mean().item()
        # NaN fails both comparisons, so this rejects 0, negatives, NaN and inf.
        if not 0.0 < recons_norm < float("inf"):
            raise ValueError(
                f"init_from_data_: mean reconstruction norm is {recons_norm}; "
                "cannot rescale encoder weights"
            )
        ae.encoder.weight.data /= recons_norm
        print0("x norm", x.norm(dim=-1).mean().item())
        print0("out norm", (ae(x)[0] - ae.pre_bias.data).norm(dim=-1).mean().item())