in maga_transformer/cpp/devices/rocm_impl/ROCmAttentionOp.cc [77:305]
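// Prefill (context) attention on ROCm. QKV bias-add / RoPE / cache-write is handled by a fused
// kernel; the attention itself runs through the composable-kernel (CK) FMHA path, with a naive
// GEMM + softmax fallback when CK FMHA cannot handle the configuration.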
AttentionModuleOutput ROCmDevice::contextAttention(const AttentionModuleParams& params) {
auto datatype = params.input.type();
auto token_num = params.input.shape()[0];
auto batch_size = params.common.context_batch_size;
auto decoder_batch_size = params.common.decoder_batch_size;
auto seq_len = params.common.context_max_seq_len;
auto seq_len_with_prefix = seq_len + params.common.max_prefix_length;
// auto context_token_num = params.common.context_token_num;
auto head_num = params.configs.head_num;
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
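// Intermediate Q/K/V buffers in [batch, head, seq, head_dim] layout; K/V are sized to also hold
// the reused prefix (prompt-cache) tokens in front of the current sequence.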
auto q_output = allocateBuffer({params.input.type(),
{batch_size, head_num, seq_len, size_per_head},
AllocationType::DEVICE},
{"q_output"});
auto k_output = allocateBuffer({params.input.type(),
{batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
AllocationType::DEVICE},
{"k_output"});
auto v_output = allocateBuffer({params.input.type(),
{batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
AllocationType::DEVICE},
{"v_output"});
BufferPtr kv_cache_block_id = nullptr;
KVBlockArray kv_block_array;
PrefixPromptBatchWeightsParam prefix_prompt_param;
if (params.common.kv_cache) {
const auto max_blocks_per_batch = params.common.kv_cache->kv_cache_block_id->shape()[1];
kv_cache_block_id = allocateBuffer({DataType::TYPE_INT32, {batch_size, 1, 2, max_blocks_per_batch}, AllocationType::DEVICE},
{"kv_cache_block_id"});
kv_block_array = getKVBlockArray(params, *kv_cache_block_id, batch_size, false, stream_);
prefix_prompt_param.kv_block_array = kv_block_array;
if (params.common.prefix_prompt_lengths) {
prefix_prompt_param.d_prefix_prompt_lengths = params.common.prefix_prompt_lengths->data<int>();
prefix_prompt_param.max_prefix_prompt_length = params.common.max_prefix_length;
prefix_prompt_param.count_length = 1;
}
}
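// Debug dumps of the attention inputs (expected to be no-ops unless buffer printing is enabled).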
printBufferData(*params.common.input_lengths, "input_lengths");
if (params.common.cu_seqlens) {
printBufferData(*params.common.cu_seqlens, "cu_seqlens");
printBufferData(*params.common.cu_kv_seqlens, "cu_kv_seqlens");
}
printBufferData(params.input, "fa_input");
// int8 quantization is not used on this path: no output scale and int8_mode = 0.
float* scale_out_ptr = nullptr;
int int8_mode = 0;
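// Reload the cached prefix K/V from the paged cache into the front of k_output / v_output.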
if (prefix_prompt_param.max_prefix_prompt_length > 0) {
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeLoadPrefixKVCache,
q_output->data(),
k_output->data(),
v_output->data(),
&prefix_prompt_param,
batch_size,
seq_len,
head_num,
kv_head_num,
size_per_head,
scale_out_ptr,
int8_mode,
stream_
);
}
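// Flags controlling what invokeAddFusedQKVBiasTranspose writes out: presumably the packed QKV
// input (in place), the separate Q/K/V buffers above, and, when a cache is attached, the paged
// KV cache itself.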
bool store_qkv = true;
bool store_q = true;
bool store_kv = true;
bool store_cache = params.common.kv_cache.has_value();
// If all of these conditions hold, invokeAddFusedQKVBiasTranspose can be skipped entirely.
bool skip_add_bias_transpose = (params.configs.rope_config.style == RopeStyle::No &&
!params.common.kv_cache &&
!params.configs.fuse_qkv_add_bias);
RTP_LLM_LOG_DEBUG("skip_add_bias_transpose: %d", skip_add_bias_transpose);
if (!skip_add_bias_transpose)
{
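// Fused kernel: add the QKV bias, apply RoPE / logn scaling, remove padding, and scatter the
// results to the destinations selected by the store_* flags above.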
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeAddFusedQKVBiasTranspose,
q_output->data(),
k_output->data(),
v_output->data(),
&prefix_prompt_param,
params.input.data(),
nullptr,
params.common.position_ids ? params.common.position_ids->dataWithOffset<int>(decoder_batch_size * params.configs.rope_config.index_factor): nullptr,
params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ? params.weights.qkv_weight->bias->data() : nullptr,
params.common.padding_offset->data<int>(),
params.common.cu_seqlens->data<int>(),
batch_size,
seq_len,
token_num,
head_num,
kv_head_num,
size_per_head,
params.configs.rope_config,
params.configs.use_logn_attn,
scale_out_ptr,
int8_mode,
false,
store_qkv,
store_q,
store_kv,
store_cache,
stream_
);
sync_check_cuda_error();
writeCacheStore(params);
}
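// Configure the CK FMHA runner for this data type, mask type, head geometry and q scaling.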
fmha_runner_->setup(datatype,
params.configs.mask_type,
head_num,
kv_head_num,
size_per_head,
params.configs.q_scaling);
// auto seq_len_round_32 = (seq_len + 31) / 32 * 32;
// auto softmax_lse_ = allocateBuffer({DataType::TYPE_FP32, // params.output.type(),
// {batch_size, head_num, seq_len_round_32},
// AllocationType::DEVICE},
// {"softmax_lse"});
printBufferData(*q_output, "q_output");
// printBufferData(*k_output, "k_output");
// printBufferData(*v_output, "v_output");
// if (v_output->shape()[0]>1) {
// printBufferData(*(v_output->index(1)), "v_output_batch1");
// }
const size_t hidden_units = head_num * size_per_head;
const size_t hidden_units_kv = kv_head_num * size_per_head;
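// First runCKFmha call passes a null LSE-accumulation buffer; judging by how the return value is
// used below, it serves only to query the required size of that workspace.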
uint32_t lse_acc_buf_sz =
fmha_runner_->runCKFmha(params.input.data(),
params.input.dataWithOffset(hidden_units),
params.input.dataWithOffset(hidden_units + hidden_units_kv),
params.output.data(),
nullptr, // buffer to store the output softmax_lse; appears to be unused by RTP
batch_size,
seq_len,
// context_token_num,
params.common.cu_seqlens->data(),
params.common.cu_kv_seqlens->data(),
nullptr,
params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data() : nullptr,
nullptr);
auto lse_acc_buf = allocateBuffer({DataType::TYPE_FP32, {lse_acc_buf_sz}, AllocationType::DEVICE}, {"lse_acc_buf"});
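// Second call runs CK flash attention on the packed, unpadded QKV input and writes directly into
// params.output; a false return falls through to the naive fallback below.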
if (fmha_runner_->runCKFmha(params.input.data(),
params.input.dataWithOffset(hidden_units),
params.input.dataWithOffset(hidden_units + hidden_units_kv),
params.output.data(),
nullptr, // buffer to store the output softmax_lse; appears to be unused by RTP
batch_size,
seq_len,
// context_token_num,
params.common.cu_seqlens->data(),
params.common.cu_kv_seqlens->data(),
lse_acc_buf->data(),
params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data() : nullptr,
nullptr)) {
return;
} else {
RTP_LLM_LOG_WARNING("ck fmha failed, falling to default implementation. This decreases performance drastically.");
auto qk_output = gemm({*q_output,
*k_output,
std::nullopt,
nullptr,
DataType::TYPE_FP32,
TransposeOperation::NONE,
TransposeOperation::TRANSPOSE});
printBufferData(*qk_output, "qk_output: ");
float scale = (1.0f / sqrtf(size_per_head * 1.0f));
// TODO(lidongjin): only supports float32 (input) / float16 (output).
auto softmax_type = qk_output->type();
auto lengths_host = clone({
params.common.input_lengths->view(decoder_batch_size, batch_size),
AllocationType::HOST
});
auto prefix_lengths_host = params.common.prefix_prompt_lengths
? clone({*params.common.prefix_prompt_lengths, AllocationType::HOST})
: BufferPtr(new Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_INVALID, {0}, nullptr));
auto attention_mask = attentionMask({
*lengths_host,
*prefix_lengths_host,
q_output->type(),
params.configs.mask_type == AttentionMaskType::causalMask
});
RUNTIME_ASSERT_OP_ARG(
params.common.attention_mask,
"attention_mask must be provided for default context attention implementation");
auto softmax_qk_output = softmax({std::move(qk_output),
*attention_mask,
nullopt,
scale,
datatype});
printBufferData(*softmax_qk_output, "softmax_qk_output: ");
auto qkv_output = gemm({*softmax_qk_output, *v_output});
auto &qkv_transpose_output = params.output;
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeTransposeQKV,
qkv_transpose_output.data(),
qkv_output->data(),
batch_size,
seq_len,
head_num,
size_per_head,
nullptr,
0,
stream_);
}
}