in sparse_autoencoder/kernels.py [0:0]
def triton_dense_dense_sparseout_matmul_kernel(
dense1_ptr,
dense2_ptr,
at_indices_ptr,
out_ptr,
stride_d1a,
stride_d1b,
stride_d2b,
stride_d2n,
A,
B,
N,
K,
BLOCK_SIZE_B: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,