def normalized_mse()

in sparse_autoencoder/kernels.py [0:0]


def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
    """Return the MSE of ``recon`` against ``xs``, scaled by a mean baseline.

    The result is ``mse(recon, xs)`` divided by the ``mse`` obtained when
    every row of ``xs`` is "reconstructed" by the mean of ``xs`` over
    dim 0.

    Args:
        recon: Reconstruction to score against ``xs``.
        xs: Reference tensor. Its dim-0 mean is indexed with ``[None, :]``
            and broadcast back to ``xs.shape``, so it presumably has a
            leading batch dimension (e.g. ``(batch, features)``) -- TODO
            confirm against callers.

    Returns:
        Whatever ``mse`` returns for the reconstruction divided by what it
        returns for the mean baseline.

    Note:
        Only used for auxk.
        NOTE(review): nothing here guards against a zero baseline error
        (e.g. all rows of ``xs`` identical); confirm callers never hit it.
    """
    if xs.dtype == torch.float16:
        # Half-precision inputs go through the dedicated dim-0 sum helper
        # (by its name it accumulates in fp32 -- defined elsewhere, verify)
        # rather than ``Tensor.mean``.
        column_mean = triton_sum_dim0_in_fp32(xs) / xs.shape[0]
    else:
        column_mean = xs.mean(dim=0)

    reconstruction_error = mse(recon, xs)

    # Baseline prediction: the dim-0 mean repeated for every row of ``xs``.
    mean_baseline = column_mean[None, :].broadcast_to(xs.shape)
    baseline_error = mse(mean_baseline, xs)

    return reconstruction_error / baseline_error