in maga_transformer/cpp/devices/cuda_impl/CudaDevice.cc [365:416]
DevicePrepOutput CudaDevice::prepareModelRun(const DevicePrepParams& params) {
DevicePrepOutput output;
fmha_type_ = FMHAType::NONE;
if (params.attn_dtype == DataType::TYPE_FP32) {
fmha_type_ = FMHAType::NONE;
output.need_mask = true;
} else if (params.context_batch_size) {
selectCuFMHARunner(params);
bool paged_kv_fmha =
params.diff_qkv_len && params.has_kv_cache && (params.configs.kv_cache_dtype == KvCacheDataType::BASE);
if (paged_kv_fmha) {
if (use_trtv2_fmha_paged && cufmha_runner_->trtV2FmhaPagedSupport()) {
fmha_type_ = FMHAType::PAGED_TRT_V2;
} else if (use_open_source_fmha_paged && cufmha_runner_->openSourceFmhaSupport()
&& params.configs.tokens_per_block % 256 == 0) {
fmha_type_ = FMHAType::PAGED_OPEN_SOURCE;
}
} else if (!params.diff_qkv_len) {
if (use_trtv2_fmha && cufmha_runner_->trtV2FmhaSupport()) {
fmha_type_ = FMHAType::TRT_V2;
} else if (use_open_source_fmha && cufmha_runner_->openSourceFmhaSupport()) {
fmha_type_ = FMHAType::OPEN_SOURCE;
} else if (use_trtv1_fmha && cufmha_runner_->trtV1FmhaSupport() && mla_ops_type == MlaOpsType::MHA) {
fmha_type_ = FMHAType::TRT_V1;
}
} else {
fmha_type_ = FMHAType::NONE;
}
output.need_mask = (fmha_type_ == FMHAType::NONE);
}
output.decode_flash_infer_attn_params = FlashInferAttnParams::prepare(
this,
params.configs,
nullptr,
params.sequence_lengths->slice(0, params.decoder_batch_size, false),
params.input_lengths->slice(0, params.decoder_batch_size, false),
params.kv_cache_block_id ? params.kv_cache_block_id->slice(0, params.decoder_batch_size, false) : nullptr,
params.kv_cache_block_id_device ? params.kv_cache_block_id_device->slice(0, params.decoder_batch_size, false) : nullptr,
params.attn_dtype);
output.prefill_flash_infer_attn_params = FlashInferAttnParams::prepare(
this,
params.configs,
params.prefix_lengths,
nullptr,
params.input_lengths->slice(params.decoder_batch_size, params.context_batch_size, false),
params.kv_cache_block_id ? params.kv_cache_block_id->slice(params.decoder_batch_size, params.context_batch_size, false) : nullptr,
params.kv_cache_block_id_device ? params.kv_cache_block_id_device->slice(params.decoder_batch_size, params.context_batch_size, false) : nullptr,
params.attn_dtype);
return output;
}