def triton_sum_dim0_in_fp32(xs)

in sparse_autoencoder/kernels.py

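Sums a contiguous fp16 matrix of shape (a, b) over dim 0, returning a length-b fp32 vector. Accumulating in fp32 keeps long reductions from overflowing or losing precision in half precision.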

import torch
import triton


def triton_sum_dim0_in_fp32(xs):
    a, b = xs.shape  # a: rows reduced away, b: columns kept in the output

    assert xs.is_contiguous()
    assert xs.dtype == torch.float16  # the kernel assumes a dense fp16 input

    # Block sizes must be powers of 2 for tl.arange; cap the row tile at 512.
    BLOCK_SIZE_A = min(triton.next_power_of_2(a), 512)
    BLOCK_SIZE_B = 64  # 64 fp16 values = 128 bytes, one cache line

    # fp32 output buffer; the kernel also accumulates in fp32.
    out = torch.zeros(b, dtype=torch.float32, device=xs.device)

    # Launch one program per BLOCK_SIZE_B-wide block of columns.
    grid = lambda META: (triton.cdiv(b, META["BLOCK_SIZE_B"]),)

    # The @triton.jit reduction kernel, defined elsewhere in kernels.py.
    triton_sum_dim0_in_fp32_kernel[grid](
        xs,
        out,
        stride_a=xs.stride(0),
        a=a,
        b=b,
        BLOCK_SIZE_A=BLOCK_SIZE_A,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
    )

    return out
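
The @triton.jit companion kernel is not shown on this page. Below is a minimal sketch consistent with the launch arguments above, assuming each program reduces one BLOCK_SIZE_B-wide block of columns by walking down the rows in BLOCK_SIZE_A-tall tiles; it is illustrative, not the repo's exact kernel.

import triton
import triton.language as tl


@triton.jit
def triton_sum_dim0_in_fp32_kernel(
    xs_ptr,
    out_ptr,
    stride_a,
    a,
    b,
    BLOCK_SIZE_A: tl.constexpr,
    BLOCK_SIZE_B: tl.constexpr,
):
    # Each program owns one BLOCK_SIZE_B-wide block of columns.
    pid = tl.program_id(0)
    offsets_b = pid * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B)

    # fp32 running sums for this column block.
    acc = tl.zeros((BLOCK_SIZE_B,), dtype=tl.float32)

    # March down the rows in BLOCK_SIZE_A-tall tiles.
    for i in range(0, a, BLOCK_SIZE_A):
        offsets_a = i + tl.arange(0, BLOCK_SIZE_A)
        tile = tl.load(
            xs_ptr + offsets_a[:, None] * stride_a + offsets_b[None, :],
            mask=(offsets_a < a)[:, None] & (offsets_b < b)[None, :],
            other=0.0,
        )
        # Upcast before summing so the reduction itself runs in fp32.
        acc += tl.sum(tile.to(tl.float32), axis=0)

    tl.store(out_ptr + offsets_b, acc, mask=offsets_b < b)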
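
A quick usage check against a plain PyTorch baseline; the baseline upcasts first so it also accumulates in fp32, and the shapes here are illustrative.

import torch

xs = torch.randn(4096, 768, dtype=torch.float16, device="cuda")
out = triton_sum_dim0_in_fp32(xs)  # shape (768,), dtype torch.float32

torch.testing.assert_close(out, xs.float().sum(dim=0))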