DevicePrepOutput CudaDevice::prepareModelRun()

in maga_transformer/cpp/devices/cuda_impl/CudaDevice.cc [365:416]


// Prepares per-run attention state for the CUDA device.
//
// Two things happen here:
//   1. `fmha_type_` is reset to FMHAType::NONE and then, for non-FP32 attention
//      with a non-zero context batch, possibly upgraded to one of the kernels
//      reported as supported by `cufmha_runner_`. `output.need_mask` is set to
//      true whenever no kernel was picked (and unconditionally for FP32).
//   2. FlashInfer attention parameters are prepared separately for the decode
//      slice and for the prefill (context) slice of the batch.
//
// Note: when attention is not FP32 and `params.context_batch_size` is zero,
// `output.need_mask` is not written and keeps whatever value
// DevicePrepOutput's default construction gives it.
DevicePrepOutput CudaDevice::prepareModelRun(const DevicePrepParams& params) {
    DevicePrepOutput output;

    // Baseline: no FMHA kernel selected. The branches below only ever upgrade it.
    fmha_type_ = FMHAType::NONE;

    if (params.attn_dtype == DataType::TYPE_FP32) {
        // FP32 attention never selects an FMHA kernel here, so a mask is required.
        output.need_mask = true;
    } else if (params.context_batch_size) {
        selectCuFMHARunner(params);

        if (!params.diff_qkv_len) {
            // Regular (non-paged) candidates, tried in priority order.
            if (use_trtv2_fmha && cufmha_runner_->trtV2FmhaSupport()) {
                fmha_type_ = FMHAType::TRT_V2;
            } else if (use_open_source_fmha && cufmha_runner_->openSourceFmhaSupport()) {
                fmha_type_ = FMHAType::OPEN_SOURCE;
            } else if (use_trtv1_fmha && cufmha_runner_->trtV1FmhaSupport() && mla_ops_type == MlaOpsType::MHA) {
                fmha_type_ = FMHAType::TRT_V1;
            }
        } else if (params.has_kv_cache && (params.configs.kv_cache_dtype == KvCacheDataType::BASE)) {
            // Paged candidates: q/kv lengths differ, a KV cache is present and
            // it uses the BASE cache dtype.
            if (use_trtv2_fmha_paged && cufmha_runner_->trtV2FmhaPagedSupport()) {
                fmha_type_ = FMHAType::PAGED_TRT_V2;
            } else if (use_open_source_fmha_paged && cufmha_runner_->openSourceFmhaSupport()
                       && params.configs.tokens_per_block % 256 == 0) {
                // The open-source paged path is only taken when tokens_per_block
                // is a multiple of 256.
                fmha_type_ = FMHAType::PAGED_OPEN_SOURCE;
            }
        } else {
            // Differing q/kv lengths without a usable paged configuration.
            fmha_type_ = FMHAType::NONE;
        }

        const bool no_fmha_selected = (fmha_type_ == FMHAType::NONE);
        output.need_mask            = no_fmha_selected;
    }

    // The slice() calls below are given (decode_count-based start, length, false);
    // presumably decode requests come first in the batch and context requests
    // follow them — TODO confirm against slice()'s contract.
    const auto decode_count  = params.decoder_batch_size;
    const auto context_count = params.context_batch_size;

    output.decode_flash_infer_attn_params = FlashInferAttnParams::prepare(
            this,
            params.configs,
            nullptr,
            params.sequence_lengths->slice(0, decode_count, false),
            params.input_lengths->slice(0, decode_count, false),
            params.kv_cache_block_id ? params.kv_cache_block_id->slice(0, decode_count, false) : nullptr,
            params.kv_cache_block_id_device ? params.kv_cache_block_id_device->slice(0, decode_count, false) : nullptr,
            params.attn_dtype);

    output.prefill_flash_infer_attn_params = FlashInferAttnParams::prepare(
            this,
            params.configs,
            params.prefix_lengths,
            nullptr,
            params.input_lengths->slice(decode_count, context_count, false),
            params.kv_cache_block_id ? params.kv_cache_block_id->slice(decode_count, context_count, false) : nullptr,
            params.kv_cache_block_id_device ? params.kv_cache_block_id_device->slice(decode_count, context_count, false) : nullptr,
            params.attn_dtype);

    return output;
}