def triton_add_mul_kernel()

in sparse_autoencoder/kernels.py [0:0]


def triton_add_mul_kernel(
    x_ptr,
    a_ptr,
    b_ptr,
    c,
    stride_x0,
    stride_x1,
    stride_a0,
    stride_a1,
    stride_b0,
    stride_b1,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,