in maga_transformer/cpp/devices/cuda_impl/CudaAttentionOp.cc [25:176]
AttentionModuleOutput CudaDevice::contextAttention(const AttentionModuleParams& params) {
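// Prefill (context) attention: stage Q/K/V, apply bias/RoPE via the fused transpose kernel,
// update the paged KV cache when present, then run the selected FMHA (or the unfused fallback).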
auto datatype = params.input.type();
auto token_num = params.input.shape()[0];
auto batch_size = params.common.context_batch_size;
auto decoder_batch_size = params.common.decoder_batch_size;
auto seq_len = params.common.context_max_seq_len;
auto seq_len_with_prefix = seq_len + params.common.max_prefix_length;
auto head_num = params.configs.head_num;
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
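// Staging buffers in [batch, head, seq, head_dim] layout; K/V are sized to also hold any
// reused prompt prefix (seq_len_with_prefix).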
auto q_output = allocateBuffer({params.input.type(),
{batch_size, head_num, seq_len, size_per_head},
AllocationType::DEVICE},
{"q_output"});
auto k_output = allocateBuffer({params.input.type(),
{batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
AllocationType::DEVICE},
{"k_output"});
auto v_output = allocateBuffer({params.input.type(),
{batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
AllocationType::DEVICE},
{"v_output"});
BufferPtr qkv_buf_fp8;
if (use_fp8_fmha_) {
qkv_buf_fp8 = allocateBuffer({DataType::TYPE_FP8_E4M3,
{batch_size, (head_num + kv_head_num * 2), seq_len_with_prefix, size_per_head},
AllocationType::DEVICE},
{"qkv_fp8_output"});
cudaMemsetAsync(qkv_buf_fp8->data(), 0, qkv_buf_fp8->sizeBytes(), stream_);
}
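// Unfused fallback path: zero the staging buffers, presumably so that positions the transpose
// kernel does not write (padding / unused prefix slots) read back as zero.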
if (fmha_type_ == FMHAType::NONE) {
cudaMemsetAsync(q_output->data(), 0, q_output->sizeBytes(), stream_);
cudaMemsetAsync(k_output->data(), 0, k_output->sizeBytes(), stream_);
cudaMemsetAsync(v_output->data(), 0, v_output->sizeBytes(), stream_);
}
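// Paged KV-cache metadata: a per-request [2, max_blocks_per_batch] block-id table copied to a
// device buffer and wrapped in a KVBlockArray for the kernels.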
BufferPtr kv_cache_block_id = nullptr;
BufferPtr kv_cache_offset_host = nullptr;
KVBlockArray kv_block_array;
PrefixPromptBatchWeightsParam prefix_prompt_param;
if (params.common.kv_cache) {
const auto max_blocks_per_batch = params.common.kv_cache->kv_cache_block_id->shape()[1];
kv_cache_block_id = allocateBuffer({DataType::TYPE_INT32, {batch_size, 1, 2, max_blocks_per_batch}, AllocationType::DEVICE},
{"kv_cache_block_id"});
kv_block_array = getKVBlockArray(params, *kv_cache_block_id, batch_size, use_fp8_fmha_);
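// The paged TRT-V2 FMHA on SM90 also needs the block offsets available on the host.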
if (is_sm90() && fmha_type_ == FMHAType::PAGED_TRT_V2) {
kv_cache_offset_host = allocateBuffer({DataType::TYPE_INT32, {batch_size, 1, 2, max_blocks_per_batch}, AllocationType::HOST},
{"kv_cache_offset_host"});
this->copy({*kv_cache_offset_host, *kv_cache_block_id});
kv_block_array.pagedKVBlockOffsetsOnHost = kv_cache_offset_host->data();
}
prefix_prompt_param.kv_block_array = kv_block_array;
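// If part of the prompt is reused from cache, hand the per-request prefix lengths to the kernels.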
if (params.common.prefix_prompt_lengths) {
prefix_prompt_param.d_prefix_prompt_lengths = params.common.prefix_prompt_lengths->data<int>();
prefix_prompt_param.max_prefix_prompt_length = params.common.max_prefix_length;
prefix_prompt_param.count_length = 1;
}
}
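// Unfused path with a reused prefix: copy the cached prefix K/V out of the paged cache into the
// contiguous k/v staging buffers before attention.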
if (fmha_type_ == FMHAType::NONE && prefix_prompt_param.max_prefix_prompt_length > 0) {
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype,
invokeLoadPrefixKVCache,
q_output->data(),
k_output->data(),
v_output->data(),
&prefix_prompt_param,
batch_size,
seq_len,
head_num,
kv_head_num,
size_per_head,
nullptr, // scale_out_ptr,
0, //int8_mode,
stream_);
sync_check_cuda_error();
}
// If all of these conditions hold, invokeAddFusedQKVBiasTranspose can be skipped entirely.
bool skip_add_bias_transpose = (params.configs.rope_config.style == RopeStyle::No && !params.common.kv_cache
&& !params.configs.fuse_qkv_add_bias && fmha_type_ != FMHAType::NONE);
RTP_LLM_LOG_DEBUG("skip_add_bias_transpose: %d", skip_add_bias_transpose);
if (!skip_add_bias_transpose) {
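// Choose what the fused add-bias/transpose kernel writes: a packed QKV buffer for the in-kernel
// FMHA backends, a separate Q (plus K/V for the unfused fallback) for the paged-TRT-V2 and
// fallback paths, and the paged KV cache whenever one is attached.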
bool store_qkv = fmha_type_ != FMHAType::PAGED_TRT_V2 && fmha_type_ != FMHAType::NONE;
bool store_q = fmha_type_ == FMHAType::PAGED_TRT_V2 || fmha_type_ == FMHAType::NONE;
bool store_kv = fmha_type_ == FMHAType::NONE;
// if the MLA cache is used, there is no need to store into the KV cache
bool store_cache = params.common.kv_cache.has_value();
DISPATCH_CUDA_FUNCTION_DATA_TYPE(
datatype,
invokeAddFusedQKVBiasTranspose,
q_output->data(),
k_output->data(),
v_output->data(),
&prefix_prompt_param,
params.input.data(),
qkv_buf_fp8 != nullptr ? qkv_buf_fp8->data() : nullptr,
params.common.position_ids ? params.common.position_ids->dataWithOffset<int>(decoder_batch_size * params.configs.rope_config.index_factor): nullptr,
params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ? params.weights.qkv_weight->bias->data() : nullptr,
params.common.padding_offset->data<int>(),
params.common.cu_seqlens->data<int>(),
batch_size,
seq_len,
token_num,
head_num,
kv_head_num,
size_per_head,
params.configs.rope_config,
params.configs.use_logn_attn,
nullptr, // scale_out_ptr,
0, // int8_mode,
fmha_type_ == FMHAType::PAGED_TRT_V2,
store_qkv,
store_q,
store_kv,
store_cache,
stream_);
sync_check_cuda_error();
printBufferData(params.input, "after invoke transpose");
if (qkv_buf_fp8) {
RTP_LLM_LOG_DEBUG("now print qkv_buf_fp8");
printBufferData(*qkv_buf_fp8, "after invoke transpose fp8");
}
sync_check_cuda_error();
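// Hand the freshly written KV blocks over to the cache store, presumably for reuse by later requests.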
if (store_cache) {
writeCacheStore(params);
}
printBufferData(*q_output, "Q after invoke transpose");
printBufferData(*k_output, "K after invoke transpose");
printBufferData(*v_output, "V after invoke transpose");
}
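// Run any MoE computation inserted at this point in the layer, then dispatch the actual prefill
// attention with the prepared buffers (packed FP8 QKV when available).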
computeInsertedMoE();
prefillAttention(params, kv_block_array, q_output, k_output, v_output, qkv_buf_fp8);
}