in sparse_autoencoder/kernels.py [0:0]
def triton_add_mul_kernel(
x_ptr,
a_ptr,
b_ptr,
c,
stride_x0,
stride_x1,
stride_a0,
stride_a1,
stride_b0,
stride_b1,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
M: tl.constexpr,
N: tl.constexpr,