def get_mla_metadata()

in flash_mla/flash_mla_interface.py [0:0]


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,