def run_flash_mla_triton()

in benchmark/bench_flash_mla.py


def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    
    # Poison cache slots past each sequence's true length with NaN so that any
    # out-of-bounds read by the kernel surfaces as NaN in the output.
    for i in range(b):
        blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
    # MLA shares K/V storage: the first dv channels of the cache serve as V.
    blocked_v = blocked_k[..., :dv]
    
    assert d > dv, "mla head dim with rope (d) should be larger than the no-rope dim (dv)"
    # Split queries and cached keys into the no-RoPE part (dv dims) and the RoPE
    # part (d - dv dims); the Triton decode kernel takes them as separate tensors.
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()

    def flash_mla_triton():
        num_kv_splits = 32
        # Output and per-split scratch buffers, allocated with the current default
        # dtype/device; the scratch holds dv accumulators plus one extra slot per split.
        o = torch.empty([b * s_q, h_q, dv])
        attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
        mla_decode_triton(
            q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv),
            blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d - dv),
            o, block_table, cache_seqlens, attn_logits,
            num_kv_splits, 1 / math.sqrt(d), block_size,
        )
        return o.view([b, s_q, h_q, dv])

    # Run once to capture the output for correctness checks, then time the same
    # closure with Triton's benchmarking helper (time in ms).
    out_flash = flash_mla_triton()
    t = triton.testing.do_bench(flash_mla_triton)
    return out_flash, None, t
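
Illustrative usage (not part of the benchmark file): a minimal sketch of how this helper could be driven on its own. It assumes the shape conventions implied by the views above (q as [b, s_q, h_q, d], the paged cache as [b * max_seqlen_pad // block_size, block_size, h_kv, d], block_table as [b, max_seqlen_pad // block_size], cache_seqlens as [b]), and that a CUDA default device and dtype are set, since the empty() calls above rely on whatever defaults are active. The dimension values and the import path are hypothetical placeholders, not the benchmark's actual configuration.

import torch

# Hypothetical: assumes the benchmark/ directory is on sys.path.
from bench_flash_mla import run_flash_mla_triton

torch.set_default_device("cuda")
torch.set_default_dtype(torch.bfloat16)

# Hypothetical MLA decode shapes: d = dv + rope dim, single KV head.
b, s_q, h_q, h_kv, d, dv = 4, 1, 16, 1, 576, 512
block_size = 64
cache_seqlens = torch.randint(1024, 4096, (b,), dtype=torch.int32)
max_seqlen_pad = ((cache_seqlens.max().item() + block_size - 1) // block_size) * block_size

q = torch.randn(b, s_q, h_q, d)
blocked_k = torch.randn(b * max_seqlen_pad // block_size, block_size, h_kv, d)
# Each row of block_table lists the cache blocks of one sequence, laid out so the
# view(b, max_seqlen_pad, h_kv, d) inside the helper sees them contiguously.
block_table = torch.arange(
    b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)

out, _, ms = run_flash_mla_triton(
    q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
    cache_seqlens, h_q, h_kv, d, dv, causal=True, dtype=torch.bfloat16,
)
print(out.shape, ms)

The returned tuple carries the attention output, a placeholder in the second slot, and the time measured by triton.testing.do_bench, presumably so it matches the shape of the other run_* helpers in this benchmark.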