in sparse_autoencoder/kernels.py [0:0]
def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    """Return ``mse(recon, xs)`` divided by the MSE of a column-mean baseline.

    Only used for the auxk loss.

    The baseline predicts every row of ``xs`` with the mean of ``xs`` taken
    over dim 0. The result is therefore the reconstruction error relative to
    that trivial predictor.

    Args:
        recon: Reconstruction of ``xs``. It is passed to ``mse`` together with
            ``xs``; presumably it must have the same shape (TODO confirm
            against ``mse``).
        xs: Target tensor. Dim 0 is averaged away. Presumably
            (n_samples, n_features).

    Returns:
        The quotient of the two ``mse`` results.

    NOTE(review): the denominator is not guarded. If every column of ``xs``
    is constant (for example, a batch with a single row), the baseline equals
    ``xs``. The denominator is then presumably 0 and the result inf/nan.
    """
    if xs.dtype == torch.float16:
        # fp16 inputs use a dedicated column-sum helper. Judging by its name,
        # it accumulates in fp32 (confirm). Dividing by the row count turns
        # the sum into a mean.
        column_mean = triton_sum_dim0_in_fp32(xs) / xs.shape[0]
    else:
        column_mean = xs.mean(dim=0)

    recon_error = mse(recon, xs)
    mean_baseline = column_mean[None, :].broadcast_to(xs.shape)
    baseline_error = mse(mean_baseline, xs)
    return recon_error / baseline_error