in benchmark/bench_flash_mla.py [0:0]
import math

import torch
import triton
import flashinfer


def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    # Fill the unused tail of each sequence's paged KV cache with NaNs so that any
    # out-of-range read by the kernel shows up in the output.
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")

    assert d > dv, "mla with rope dim should be larger than no rope dim"
    # Split queries and cached keys into the no-RoPE part (dv dims) and the RoPE part (d - dv dims).
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    # Build CSR-style page indices: kv_indices[kv_indptr[i]:kv_indptr[i + 1]] are the pages of sequence i.
    kv_indptr = [0]
    kv_indices = []
    for i in range(b):
        seq_len = cache_seqlens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_table[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
    q_indptr = torch.arange(0, b + 1).int() * s_q
    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)

    # FlashInfer MLA wrapper with a 128 MiB workspace buffer, using the FA3 backend.
    mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        torch.empty(128 * 1024 * 1024, dtype=torch.int8),
        backend="fa3",
    )
    mla_wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        cache_seqlens,
        h_q,
        dv,          # head dim of the no-RoPE (compressed KV) part
        d - dv,      # head dim of the RoPE part
        block_size,
        causal,
        1 / math.sqrt(d),
        q.dtype,
        blocked_k.dtype,
    )

    def flash_infer():
        output, lse = mla_wrapper.run(
            q_nope.view(-1, h_q, dv),
            q_pe.view(-1, h_q, d - dv),
            blocked_k_nope,
            blocked_k_pe,
            return_lse=True,
        )
        # The lse reshape assumes s_q == 1 (single-token decode).
        return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)

    # Run once to get outputs for correctness checking, then time with Triton's do_bench.
    out_flash, lse_flash = flash_infer()
    t = triton.testing.do_bench(flash_infer)
    return out_flash, lse_flash, t
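

# Usage sketch (not part of the original benchmark): one way run_flash_infer might be
# driven with synthetic inputs. The sizes, dtype, and block-table layout below are
# assumptions inferred from how the arguments are used inside the function (e.g.
# blocked_k must reshape to (b, max_seqlen_pad, h_kv, d)); they are not taken from
# the rest of this file.
def _example_run_flash_infer():
    torch.set_default_device("cuda")  # assumed: the wrapper's workspace buffer lives on the default device
    torch.manual_seed(0)

    b, s_q, h_q, h_kv = 4, 1, 16, 1       # assumed decode-style problem sizes
    d, dv, block_size = 576, 512, 64      # assumed MLA head dims: dv no-RoPE dims + (d - dv) RoPE dims
    dtype = torch.bfloat16

    cache_seqlens = torch.randint(1, 1024, (b,), dtype=torch.int32)
    max_seqlen = int(cache_seqlens.max().item())
    max_seqlen_pad = (max_seqlen + block_size - 1) // block_size * block_size
    num_pages = b * max_seqlen_pad // block_size

    q = torch.randn(b, s_q, h_q, d, dtype=dtype)
    blocked_k = torch.randn(num_pages, block_size, h_kv, d, dtype=dtype)
    # Trivial page assignment: sequence i owns a contiguous slice of pages.
    block_table = torch.arange(num_pages, dtype=torch.int32).view(b, -1)

    out, lse, t = run_flash_infer(
        q, block_table, blocked_k, max_seqlen_pad, block_size,
        b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal=True, dtype=dtype,
    )
    print(f"flashinfer MLA: {t:.3f} ms, out {tuple(out.shape)}, lse {tuple(lse.shape)}")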