`unit_norm_decoder_grad_adjustment_()`

Defined in sparse_autoencoder/train.py [0:0]


def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
    """Project the decoder-weight gradient onto the tangent space of the unit sphere.

    For each dictionary vector (column of the decoder weight), the component of
    the gradient parallel to that vector is removed in place, so optimizer steps
    do not change the vectors' norms to first order.

    Assumes the decoder weight columns are already unit-normalized.
    Mutates ``autoencoder.decoder.weight.grad``; returns ``None``.
    """
    weight = autoencoder.decoder.weight.data
    grad = autoencoder.decoder.weight.grad
    assert grad is not None

    # Per-column dot product <weight[:, n], grad[:, n]> — the magnitude of the
    # gradient component parallel to each (unit-norm) dictionary vector.
    parallel = (weight * grad).sum(dim=0)

    # In-place fused update: grad += c * parallel * weight with c = -1,
    # i.e. subtract the parallel component from the gradient.
    triton_add_mul_(grad, parallel, weight, c=-1)