in sparse_autoencoder/kernels.py [0:0]
def triton_sum_dim0_in_fp32(xs):
    """Sum a 2-D fp16 tensor over dim 0, accumulating in fp32.

    Launches ``triton_sum_dim0_in_fp32_kernel`` (defined elsewhere in this
    module); intended to match ``xs.float().sum(dim=0)`` while accumulating
    in fp32 to avoid fp16 overflow / precision loss.

    Args:
        xs: contiguous fp16 tensor of shape ``(a, b)`` on a CUDA device —
            device requirement presumed from the Triton launch; confirm
            against callers.

    Returns:
        fp32 tensor of shape ``(b,)`` holding the per-column sums.

    Raises:
        ValueError: if ``xs`` is not contiguous or not fp16.
    """
    a, b = xs.shape

    # Explicit raises instead of `assert` so validation survives `python -O`.
    if not xs.is_contiguous():
        raise ValueError("xs must be contiguous")
    if xs.dtype != torch.float16:
        raise ValueError(f"xs must be float16, got {xs.dtype}")

    # Tile sizes: up to 512 rows per program; 64 fp16 columns = 128 bytes,
    # i.e. one cache line per column tile.
    BLOCK_SIZE_A = min(triton.next_power_of_2(a), 512)
    BLOCK_SIZE_B = 64  # cache line is 128 bytes

    # Zero-initialized fp32 output so the kernel can accumulate into it.
    out = torch.zeros(b, dtype=torch.float32, device=xs.device)

    # One program instance per BLOCK_SIZE_B-wide slice of the columns.
    grid = lambda META: (triton.cdiv(b, META["BLOCK_SIZE_B"]),)
    triton_sum_dim0_in_fp32_kernel[grid](
        xs,
        out,
        stride_a=xs.stride(0),
        a=a,
        b=b,
        BLOCK_SIZE_A=BLOCK_SIZE_A,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
    )
    return out