in sparse_autoencoder/kernels.py [0:0]
def triton_sum_dim0_in_fp32_kernel( xs_ptr, out_ptr, stride_a, a, b, BLOCK_SIZE_A: tl.constexpr, BLOCK_SIZE_B: tl.constexpr,