# from sparse_autoencoder/kernels.py
def backward(ctx, grad_output):
    """Backward pass for the sparse-decoder matmul.

    Recovers the saved forward inputs and produces gradients for the
    sparse activation values and the decoder weight. The first and last
    entries of the returned tuple are ``None`` (no gradient for the
    corresponding forward arguments). NOTE(review): signature matches the
    ``torch.autograd.Function`` backward convention — confirm against the
    enclosing class.
    """
    sparse_indices, sparse_values, decoder_weight = ctx.saved_tensors

    # The Triton kernels below assume a contiguous gradient buffer.
    assert grad_output.is_contiguous(), "grad_output must be contiguous; this is probably because the subsequent op was a .sum() or something like that, which returns a non contiguous gradient"

    # Gradient w.r.t. the sparse activation values (dense @ dense, sampled
    # only at the active sparse positions).
    grad_sparse_values = triton_dense_dense_sparseout_matmul(
        grad_output, decoder_weight, sparse_indices
    )

    # Gradient w.r.t. the decoder weight. The kernel produces the transposed
    # layout, so take .T; the decoder is contiguous when transposed, making
    # this a matching layout.
    grad_decoder_weight = triton_sparse_transpose_dense_matmul(
        sparse_indices, sparse_values, grad_output, N=decoder_weight.shape[1]
    ).T

    return (None, grad_sparse_values, grad_decoder_weight, None)