in maga_transformer/cpp/devices/cuda_impl/CudaFlashInfer.cc [300:408]
// Build FlashInfer/FlashMLA attention parameters for the current batch, or
// return nullptr when the flashinfer path cannot (or should not) be used so
// the caller falls back to another attention implementation.
//
// Returns nullptr when:
//   - the GPU is pre-SM80, the batch is empty, `input_lengths_host` is null,
//     or `device` is not a CudaDevice;
//   - DISABLE_FLASH_INFER=1 is set and the flashinfer kernel would be chosen;
//   - (non-MLA only) the dtype / kv-cache dtype / rope style / mask type /
//     q-scaling / head size / group size combination is outside the supported set.
FlashInferAttnParamsPtr FlashInferAttnParams::prepare(
        rtp_llm::DeviceBase *device,
        const rtp_llm::AttentionConfigs &attn_configs,
        const BufferPtr &prefix_lengths_host,
        const BufferPtr &sequence_lengths_host,
        const BufferPtr &input_lengths_host,
        const BufferPtr &kv_cache_block_id_host,
        const BufferPtr &kv_cache_block_id_device,
        rtp_llm::DataType dtype)
{
    // flashinfer kernels require at least SM80 (Ampere).
    if (rtp_llm::get_sm() < 80) {
        return nullptr;
    }
    // Guard before the first dereference: the sibling buffer arguments are all
    // null-checked before use, but input_lengths_host was not.
    if (!input_lengths_host) {
        return nullptr;
    }
    const int batch_size = input_lengths_host->shape()[0];
    if (batch_size == 0) {
        return nullptr;
    }
    auto cuda_device = dynamic_cast<CudaDevice*>(device);
    if (!cuda_device) {
        return nullptr;
    }

    // FLASH_MLA requires a uniform per-request query length in [0, 32]
    // (sameQLength writes the common length into q_length); otherwise fall
    // back to the FLASH_INFER MLA path. q_length stays -1 when not probed.
    MlaOpsType mla_ops_type = device->mla_ops_type;
    int q_length = -1;
    if (mla_ops_type == MlaOpsType::FLASH_MLA &&
        (!sameQLength(input_lengths_host, batch_size, q_length) || q_length == -1 || q_length > 32)) {
        mla_ops_type = MlaOpsType::FLASH_INFER;
    }

    // Honor DISABLE_FLASH_INFER=1 whenever the flashinfer kernel would be used
    // (always for non-MLA; for MLA only when the FLASH_INFER op was selected).
    const char* disable_flash_infer_env = getenv("DISABLE_FLASH_INFER");
    const bool disable_flash_infer =
        disable_flash_infer_env && strcmp(disable_flash_infer_env, "1") == 0;
    if ((!attn_configs.use_mla || mla_ops_type == MlaOpsType::FLASH_INFER) && disable_flash_infer) {
        return nullptr;
    }

    const int local_head_num    = attn_configs.head_num;
    const int local_head_num_kv = attn_configs.kv_head_num;
    const int size_per_head     = attn_configs.size_per_head;
    const int group_size        = local_head_num / local_head_num_kv;
    const int tokens_per_block  = attn_configs.tokens_per_block;

    // Map the quantized-fp8 logical dtype to the underlying buffer dtype.
    if (dtype == DataType::TYPE_QFP8_E4M3) {
        dtype = DataType::TYPE_FP8_E4M3;
    }

    // Non-MLA attention: reject configurations outside flashinfer's support
    // matrix (dtype, kv-cache quantization, rope style, mask, scaling,
    // logn attention, head dim, GQA group size).
    if (!attn_configs.use_mla) {
        if ((dtype != DataType::TYPE_FP16 && dtype != DataType::TYPE_BF16 && dtype != DataType::TYPE_FP8_E4M3) ||
            (attn_configs.kv_cache_dtype != KvCacheDataType::BASE &&
             !(attn_configs.kv_cache_dtype == KvCacheDataType::FP8 && dtype == DataType::TYPE_FP8_E4M3)) ||
            (attn_configs.rope_config.style != RopeStyle::Base && attn_configs.rope_config.style != RopeStyle::No) ||
            attn_configs.mask_type != causalMask ||
            attn_configs.q_scaling != 1.0f ||
            attn_configs.use_logn_attn ||
            (size_per_head != 64 && size_per_head != 128 && size_per_head != 192) ||
            (group_size > 10 && group_size != 16))
        {
            return nullptr;
        }
    }

    // Total query tokens this step: sum of per-request input lengths when
    // prefix lengths are present (prefill); otherwise one token per request
    // (decode-only step), i.e. batch_size.
    int input_token_num = 0;
    if (prefix_lengths_host) {
        input_token_num = std::accumulate(input_lengths_host->data<int>(),
                                          input_lengths_host->data<int>() + batch_size,
                                          0);
    } else {
        input_token_num = batch_size;  // == input_lengths_host->shape()[0]
    }

    // Allocate (or reuse cached) params sized to at least the cache minimums,
    // and hand ownership to the custom-deleter smart pointer immediately so
    // the raw pointer below is never an owning pointer.
    auto params = FlashInferAttnParams::create(cuda_device,
                                               max(MIN_CACHE_BATCH_SIZE, batch_size),
                                               max(MIN_CACHE_INPUT_TOKEN_NUM, input_token_num),
                                               MIN_CACHE_PAGE_NUM);
    FlashInferAttnParamsPtr ret(params, FlashInferAttnParamsDel);

    if (kv_cache_block_id_device) {
        params->kv_cache_block_id_d = Buffer2torchTensor(kv_cache_block_id_device, false);
    }
    params->mla_ops_type = mla_ops_type;
    params->dtype = dtype;
    params->fillFlashInfer(prefix_lengths_host,
                           sequence_lengths_host,
                           input_lengths_host,
                           kv_cache_block_id_host,
                           batch_size,
                           tokens_per_block);
    params->refreshFlashInferBuf(cuda_device, batch_size, input_token_num);

    // Use the decode kernel only for small GQA group sizes — presumably wider
    // groups are served better by the prefill kernel; TODO confirm threshold.
    params->decode = (group_size <= 5);

    params->genPlan(batch_size,
                    q_length,
                    local_head_num,
                    local_head_num_kv,
                    size_per_head,
                    tokens_per_block,
                    attn_configs.kv_lora_rank,
                    attn_configs.use_mla,
                    reinterpret_cast<int64_t>(cuda_device->getStream())); // cuda_stream
    return ret;
}