FlashInferAttnParamsPtr FlashInferAttnParams::prepare()

in maga_transformer/cpp/devices/cuda_impl/CudaFlashInfer.cc [300:408]
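Decides whether the FlashInfer attention path can serve the current batch and attention configuration. If a hardware, dtype, rope, mask, or head-size constraint is not met it returns nullptr so the caller falls back to the default kernels; otherwise it allocates the params, fills the FlashInfer index buffers from the host-side length/block-id buffers, and generates the execution plan.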


FlashInferAttnParamsPtr FlashInferAttnParams::prepare(
        rtp_llm::DeviceBase *device,
        const rtp_llm::AttentionConfigs &attn_configs,
        const BufferPtr &prefix_lengths_host,
        const BufferPtr &sequence_lengths_host,
        const BufferPtr &input_lengths_host,
        const BufferPtr &kv_cache_block_id_host,
        const BufferPtr &kv_cache_block_id_device,
        rtp_llm::DataType dtype)
{
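    // FlashInfer kernels need SM80 (Ampere) or newer.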
    if (rtp_llm::get_sm() < 80) {
        return nullptr;
    }

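    // Nothing to prepare for an empty batch.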
    const int batch_size = input_lengths_host->shape()[0];
    if (batch_size == 0) {
        return nullptr;
    }

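    // The FlashInfer path is only wired up for the CUDA device backend.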
    auto cuda_device = dynamic_cast<CudaDevice*>(device);
    if (!cuda_device) {
        return nullptr;
    }

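    // FlashMLA requires a uniform per-sequence query length of at most 32;
    // otherwise MLA falls back to the FlashInfer implementation.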
    MlaOpsType mla_ops_type = device->mla_ops_type;
    int q_length = -1;
    if (mla_ops_type == MlaOpsType::FLASH_MLA &&
        (!sameQLength(input_lengths_host, batch_size, q_length) || q_length == -1 || q_length > 32)) {
        mla_ops_type = MlaOpsType::FLASH_INFER;
    }

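    // DISABLE_FLASH_INFER=1 is an escape hatch for every path that would actually
    // run FlashInfer kernels (non-MLA attention, or MLA routed to FlashInfer).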
    const char* disable_flash_infer_env = getenv("DISABLE_FLASH_INFER");
    const bool disable_flash_infer = disable_flash_infer_env && strcmp(disable_flash_infer_env, "1") == 0;
    if ((!attn_configs.use_mla || mla_ops_type == MlaOpsType::FLASH_INFER) && disable_flash_infer) {
        return nullptr;
    }

    const int local_head_num    = attn_configs.head_num;
    const int local_head_num_kv = attn_configs.kv_head_num;
    const int size_per_head = attn_configs.size_per_head;
    const int group_size = local_head_num / local_head_num_kv;
    const int tokens_per_block = attn_configs.tokens_per_block;

    // map the quantized wrapper dtype to its underlying buffer dtype
    if (dtype == DataType::TYPE_QFP8_E4M3) {
        dtype = DataType::TYPE_FP8_E4M3;
    }

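    // The non-MLA FlashInfer path only covers a restricted feature set:
    // FP16/BF16/FP8 activations, base or FP8 KV cache, base/no rope, causal
    // masking, default q scaling, no logn attention, head sizes 64/128/192,
    // and small GQA group sizes. Anything else falls back to the default kernels.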
    if (!attn_configs.use_mla) {
        if ((dtype != DataType::TYPE_FP16 && dtype != DataType::TYPE_BF16 && dtype != DataType::TYPE_FP8_E4M3) ||
            (attn_configs.kv_cache_dtype != KvCacheDataType::BASE &&
             !(attn_configs.kv_cache_dtype == KvCacheDataType::FP8 && dtype == DataType::TYPE_FP8_E4M3)) ||
            (attn_configs.rope_config.style != RopeStyle::Base && attn_configs.rope_config.style != RopeStyle::No)  ||
            attn_configs.mask_type != causalMask ||
            attn_configs.q_scaling != 1.0f ||
            attn_configs.use_logn_attn ||
            (size_per_head != 64 && size_per_head != 128 && size_per_head != 192) ||
            (group_size > 10 && group_size != 16))
        {
            return nullptr;
        }
    }

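    // Prefill batches sum the per-sequence input lengths; a decode-only batch
    // (no prefix lengths) contributes exactly one token per sequence.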
    int input_token_num = 0;
    if (prefix_lengths_host) {
        input_token_num = std::accumulate(input_lengths_host->data<int>(),
                                          input_lengths_host->data<int>() + batch_size,
                                          0);
    } else {
        input_token_num = input_lengths_host->shape()[0];
    }

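    // Size the params at no less than the cache minimums so the underlying
    // workspace buffers can be reused across calls.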
    auto params = FlashInferAttnParams::create(cuda_device,
                                               max(MIN_CACHE_BATCH_SIZE, batch_size),
                                               max(MIN_CACHE_INPUT_TOKEN_NUM, input_token_num),
                                               MIN_CACHE_PAGE_NUM);
    FlashInferAttnParamsPtr ret(params, FlashInferAttnParamsDel);

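    // Expose the device-side KV cache block-id table to the kernels as a torch tensor.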
    if (kv_cache_block_id_device) {
        params->kv_cache_block_id_d = Buffer2torchTensor(kv_cache_block_id_device, false);
    }
    params->mla_ops_type = mla_ops_type;
    params->dtype = dtype;
    params->fillFlashInfer(prefix_lengths_host,
                           sequence_lengths_host,
                           input_lengths_host,
                           kv_cache_block_id_host,
                           batch_size,
                           tokens_per_block);
    params->refreshFlashInferBuf(cuda_device, batch_size, input_token_num);

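    // Only use the decode-optimized plan for small GQA group sizes.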
    params->decode = (group_size <= 5);

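    // Build the FlashInfer execution plan for this batch on the device's CUDA stream.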
    params->genPlan(batch_size,
                    q_length,
                    local_head_num,
                    local_head_num_kv,
                    size_per_head,
                    tokens_per_block,
                    attn_configs.kv_lora_rank,
                    attn_configs.use_mla,
                    reinterpret_cast<int64_t>(cuda_device->getStream())); // cuda_stream

    return ret;
}
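
A minimal sketch of a call site, under the assumption that the caller has already assembled the host/device length and block-id buffers; the variable names simply mirror the parameters above, and the consumer of the returned params is left abstract:

    // Hypothetical call site (illustrative only): buffers are assumed to be
    // built by the caller's batching logic before attention runs.
    auto flash_infer_params = FlashInferAttnParams::prepare(device,
                                                            attn_configs,
                                                            prefix_lengths_host,
                                                            sequence_lengths_host,
                                                            input_lengths_host,
                                                            kv_cache_block_id_host,
                                                            kv_cache_block_id_device,
                                                            dtype);
    if (!flash_infer_params) {
        // nullptr means this batch/config is not supported by the FlashInfer
        // path; the caller falls back to the default attention kernels.
    }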