def mla_decode_triton()

in benchmark/bench_flash_mla.py [0:0]


def mla_decode_triton(
    q_nope,
    q_pe,
    kv_c_cache,
    k_pe_cache,
    o,
    req_to_tokens,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,