def triton_sum_dim0_in_fp32_kernel()

in sparse_autoencoder/kernels.py [0:0]


def triton_sum_dim0_in_fp32_kernel(
    xs_ptr,
    out_ptr,
    stride_a,
    a,
    b,
    BLOCK_SIZE_A: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,