in sparse_autoencoder/train.py [0:0]
import torch

# triton_add_mul_ is the repo's fused in-place kernel (assumed to live in
# sparse_autoencoder/kernels.py); it performs `x += c * a * b` with broadcasting.
from sparse_autoencoder.kernels import triton_add_mul_


def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
    """Project out the gradient component parallel to each dictionary vector;
    assumes the decoder columns are already unit-normed."""
    assert autoencoder.decoder.weight.grad is not None
    triton_add_mul_(
        autoencoder.decoder.weight.grad,
        # per-latent dot products <decoder_n, grad_n> over the model dimension
        torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad),
        autoencoder.decoder.weight.data,
        c=-1,
    )
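# For reference, the same adjustment can be written in plain PyTorch without the
# Triton kernel. This is a minimal sketch, assuming decoder.weight has shape
# [d_model, n_latents] (columns are the dictionary vectors). The names with a
# `_reference_` suffix and the toy training step below are illustrative only,
# not part of this repo's API.
#
#     import torch
#     import torch.nn as nn
#
#
#     def unit_norm_decoder_grad_adjustment_reference_(decoder: nn.Linear) -> None:
#         """Equivalent update: grad_n -= <d_n, grad_n> * d_n for each column n."""
#         W, G = decoder.weight.data, decoder.weight.grad
#         parallel = torch.einsum("bn,bn->n", W, G)  # per-latent dot products, shape [n_latents]
#         G -= parallel * W                          # broadcast over the model dimension
#
#
#     def unit_norm_decoder_reference_(decoder: nn.Linear) -> None:
#         """Re-normalize decoder columns to unit L2 norm."""
#         decoder.weight.data /= decoder.weight.data.norm(dim=0, keepdim=True)
#
#
#     # toy usage: one training step, pairing the tangent-space gradient adjustment
#     # with an explicit re-normalization after the optimizer step
#     d_model, n_latents = 16, 64
#     decoder = nn.Linear(n_latents, d_model, bias=False)
#     unit_norm_decoder_reference_(decoder)
#
#     opt = torch.optim.SGD(decoder.parameters(), lr=1e-3)
#     x = torch.randn(8, n_latents)
#     loss = decoder(x).pow(2).mean()
#     loss.backward()
#
#     unit_norm_decoder_grad_adjustment_reference_(decoder)  # remove the radial grad component
#     opt.step()
#     unit_norm_decoder_reference_(decoder)                  # snap columns back onto the unit sphere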