// maga_transformer/cpp/devices/rocm_impl/ROCmAttentionOp.cc
#include "maga_transformer/cpp/devices/rocm_impl/ROCmDevice.h"
#include "maga_transformer/cpp/devices/CommonDefines.h"
#include "maga_transformer/cpp/cuda/Dispatch.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
#include "maga_transformer/cpp/kernels/unfused_attention_kernels.h"
#include "maga_transformer/cpp/kernels/gpt_kernels.h"
#include "maga_transformer/cpp/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention.h"
using namespace std;
namespace rtp_llm {
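// Builds a KVBlockArray descriptor for the paged KV cache: converts the per-batch block
// ids into device block-array offsets (written into kv_cache_offset_pointers) and fills
// in the per-token size, optional KV scales and cache dtype (BASE / INT8 / FP8).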
KVBlockArray getKVBlockArray(const AttentionModuleParams& params,
const Buffer& kv_cache_offset_pointers,
int batch_size,
bool use_fp8_fmha,
cudaStream_t stream) {
const auto& kv_cache = params.common.kv_cache;
const auto& kv_blocks_offset = *(kv_cache->kv_cache_block_id);
const auto& kv_block_offset = (kv_cache->k_cache_buffer)->shape()[0] * kv_cache->layer_num;
    RUNTIME_ASSERT_OP_ARG(kv_blocks_offset.shape()[0] == batch_size,
                          "context attention kv blocks batch size expected [%d] but got buffer [%s]",
                          (int)batch_size,
                          kv_blocks_offset.debugString().c_str());
const auto max_blocks_per_batch = kv_blocks_offset.shape()[1];
const auto& k_cache = *(kv_cache->k_cache_buffer);
const auto& v_cache = *(kv_cache->v_cache_buffer);
auto const elemSize = kv_cache->k_scale_buffer || use_fp8_fmha ? sizeof(int8_t) : 2; // 2 for kv cache fp16
// RTP_LLM_LOG_INFO("kv_cache[0].typeSize():%d", kv_cache[0].typeSize());
    RTP_LLM_LOG_DEBUG(
        "kv_blocks_offset size:%d, k_cache:%p, v_cache:%p, k_cache[0].sizeBytes():%d, params.configs.tokens_per_block:%d, kv_block_offset:%d",
        (int)kv_blocks_offset.size(),
        (void*)k_cache.data(),
        (void*)v_cache.data(),
        (int)k_cache[0].sizeBytes(),
        (int)params.configs.tokens_per_block,
        (int)kv_block_offset);
auto const sizePerToken = params.configs.kv_head_num * params.configs.size_per_head * elemSize;
KVBlockArray kv_cache_buffer =
KVBlockArray(batch_size,
max_blocks_per_batch,
params.configs.tokens_per_block,
sizePerToken,
0,
0,
(uint64_t*)k_cache.data(),
nullptr,
(rtp_llm::KVBlockArrayForContextFMHA::DataType*)kv_cache_offset_pointers.data());
invokeConvertOffsetToBlockArrayData((int32_t*)kv_cache_offset_pointers.data(),
(int*)kv_blocks_offset.data(),
batch_size,
max_blocks_per_batch,
kv_block_offset,
stream);
sync_check_cuda_error();
if (kv_cache->k_scale_buffer) {
        RUNTIME_ASSERT_OP_ARG(kv_cache->v_scale_buffer,
                              "v_scale_buffer must have a value when k_scale_buffer has a value");
const auto& k_scale = *(kv_cache->k_scale_buffer);
kv_cache_buffer.scale = k_scale.data();
kv_cache_buffer.mScaleBytesPerBlock = k_scale[0].sizeBytes();
}
KvCacheDataType cache_type = KvCacheDataType::BASE;
#ifdef ENABLE_FP8
if (use_fp8_fmha) {
cache_type = KvCacheDataType::FP8;
} else
#endif
if (kv_cache->k_scale_buffer && params.configs.kv_cache_dtype == KvCacheDataType::INT8) {
RTP_LLM_LOG_DEBUG("now use kv_cache int8");
cache_type = KvCacheDataType::INT8;
}
kv_cache_buffer.cache_type = cache_type;
sync_check_cuda_error();
return kv_cache_buffer;
}
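// Context (prefill) attention: splits the fused QKV input into Q/K/V buffers, optionally
// restores prefix-prompt KV and appends the new KV to the paged cache, then runs the CK
// fused-MHA kernel, falling back to an unfused gemm + softmax + gemm path if it fails.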
AttentionModuleOutput ROCmDevice::contextAttention(const AttentionModuleParams& params) {
auto datatype = params.input.type();
auto token_num = params.input.shape()[0];
auto batch_size = params.common.context_batch_size;
auto decoder_batch_size = params.common.decoder_batch_size;
auto seq_len = params.common.context_max_seq_len;
auto seq_len_with_prefix = seq_len + params.common.max_prefix_length;
// auto context_token_num = params.common.context_token_num;
auto head_num = params.configs.head_num;
auto kv_head_num = params.configs.kv_head_num;
auto size_per_head = params.configs.size_per_head;
auto q_output = allocateBuffer({params.input.type(),
{batch_size, head_num, seq_len, size_per_head},
AllocationType::DEVICE},
{"q_output"});
auto k_output = allocateBuffer({params.input.type(),
{batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
AllocationType::DEVICE},
{"k_output"});
auto v_output = allocateBuffer({params.input.type(),
{batch_size, kv_head_num, seq_len_with_prefix, size_per_head},
AllocationType::DEVICE},
{"v_output"});
BufferPtr kv_cache_block_id = nullptr;
KVBlockArray kv_block_array;
PrefixPromptBatchWeightsParam prefix_prompt_param;
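    // When a paged KV cache is attached, build its block-offset array for this batch and
    // pass it, together with any prefix-prompt lengths, on to the QKV kernels below.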
if (params.common.kv_cache) {
const auto max_blocks_per_batch = params.common.kv_cache->kv_cache_block_id->shape()[1];
kv_cache_block_id = allocateBuffer({DataType::TYPE_INT32, {batch_size, 1, 2, max_blocks_per_batch}, AllocationType::DEVICE},
{"kv_cache_block_id"});
kv_block_array = getKVBlockArray(params, *kv_cache_block_id, batch_size, false, stream_);
prefix_prompt_param.kv_block_array = kv_block_array;
if (params.common.prefix_prompt_lengths) {
prefix_prompt_param.d_prefix_prompt_lengths = params.common.prefix_prompt_lengths->data<int>();
prefix_prompt_param.max_prefix_prompt_length = params.common.max_prefix_length;
prefix_prompt_param.count_length = 1;
}
}
printBufferData(*params.common.input_lengths, "input_lengths");
if (params.common.cu_seqlens) {
printBufferData(*params.common.cu_seqlens, "cu_seqlens");
printBufferData(*params.common.cu_kv_seqlens, "cu_kv_seqlens");
}
printBufferData(params.input, "fa_input");
// int8
float* scale_out_ptr = nullptr;
int int8_mode = 0;
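    // Reused prefix-prompt tokens: copy their cached K/V from the block cache into the
    // k_output / v_output buffers ahead of the newly computed tokens.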
if (prefix_prompt_param.max_prefix_prompt_length > 0) {
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeLoadPrefixKVCache,
q_output->data(),
k_output->data(),
v_output->data(),
&prefix_prompt_param,
batch_size,
seq_len,
head_num,
kv_head_num,
size_per_head,
scale_out_ptr,
int8_mode,
stream_
);
}
bool store_qkv = true;
bool store_q = true;
bool store_kv = true;
bool store_cache = params.common.kv_cache.has_value();
    // if all conditions are satisfied, there is no need to call invokeAddFusedQKVBiasTranspose
bool skip_add_bias_transpose = (params.configs.rope_config.style == RopeStyle::No &&
!params.common.kv_cache &&
!params.configs.fuse_qkv_add_bias);
RTP_LLM_LOG_DEBUG("skip_add_bias_transpose: %d", skip_add_bias_transpose);
if (!skip_add_bias_transpose)
{
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeAddFusedQKVBiasTranspose,
q_output->data(),
k_output->data(),
v_output->data(),
&prefix_prompt_param,
params.input.data(),
nullptr,
params.common.position_ids ? params.common.position_ids->dataWithOffset<int>(decoder_batch_size * params.configs.rope_config.index_factor): nullptr,
params.configs.fuse_qkv_add_bias && params.weights.qkv_weight->bias ? params.weights.qkv_weight->bias->data() : nullptr,
params.common.padding_offset->data<int>(),
params.common.cu_seqlens->data<int>(),
batch_size,
seq_len,
token_num,
head_num,
kv_head_num,
size_per_head,
params.configs.rope_config,
params.configs.use_logn_attn,
scale_out_ptr,
int8_mode,
false,
store_qkv,
store_q,
store_kv,
store_cache,
stream_
);
sync_check_cuda_error();
writeCacheStore(params);
}
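    // Configure the CK FMHA runner, then invoke runCKFmha twice: the first call passes a
    // null lse accumulation buffer and only returns the workspace size to allocate; the
    // second call runs the kernel with that workspace and reports whether it succeeded.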
fmha_runner_->setup(datatype,
params.configs.mask_type,
head_num,
kv_head_num,
size_per_head,
params.configs.q_scaling);
// auto seq_len_round_32 = (seq_len + 31) / 32 * 32;
// auto softmax_lse_ = allocateBuffer({DataType::TYPE_FP32, // params.output.type(),
// {batch_size, head_num, seq_len_round_32},
// AllocationType::DEVICE},
// {"softmax_lse"});
printBufferData(*q_output, "q_output");
// printBufferData(*k_output, "k_output");
// printBufferData(*v_output, "v_output");
// if (v_output->shape()[0]>1) {
// printBufferData(*(v_output->index(1)), "v_output_batch1");
// }
const size_t hidden_units = head_num * size_per_head;
const size_t hidden_units_kv = kv_head_num * size_per_head;
uint32_t lse_acc_buf_sz =
fmha_runner_->runCKFmha(params.input.data(),
params.input.dataWithOffset(hidden_units),
params.input.dataWithOffset(hidden_units + hidden_units_kv),
params.output.data(),
                                nullptr,  // buffer to store softmax_lse; appears unused by RTP
batch_size,
seq_len,
// context_token_num,
params.common.cu_seqlens->data(),
params.common.cu_kv_seqlens->data(),
nullptr,
params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data() : nullptr,
nullptr);
auto lse_acc_buf = allocateBuffer({DataType::TYPE_FP32, {lse_acc_buf_sz}, AllocationType::DEVICE}, {"lse_acc_buf"});
if (fmha_runner_->runCKFmha(params.input.data(),
params.input.dataWithOffset(hidden_units),
params.input.dataWithOffset(hidden_units + hidden_units_kv),
params.output.data(),
                                nullptr,  // buffer to store softmax_lse; appears unused by RTP
batch_size,
seq_len,
// context_token_num,
params.common.cu_seqlens->data(),
params.common.cu_kv_seqlens->data(),
lse_acc_buf->data(),
params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data() : nullptr,
nullptr)) {
return;
} else {
RTP_LLM_LOG_WARNING("ck fmha failed, falling to default implementation. This decreases performance drastically.");
auto qk_output = gemm({*q_output,
*k_output,
std::nullopt,
nullptr,
DataType::TYPE_FP32,
TransposeOperation::NONE,
TransposeOperation::TRANSPOSE});
printBufferData(*qk_output, "qk_output: ");
float scale = (1.0f / sqrtf(size_per_head * 1.0f));
        // TODO(lidongjin): only supports float32 (input) / float16 (output).
auto softmax_type = qk_output->type();
auto lengths_host = clone({
params.common.input_lengths->view(decoder_batch_size, batch_size),
AllocationType::HOST
});
auto prefix_lengths_host = params.common.prefix_prompt_lengths
? clone({*params.common.prefix_prompt_lengths, AllocationType::HOST})
: BufferPtr(new Buffer(MemoryType::MEMORY_CPU, DataType::TYPE_INVALID, {0}, nullptr));
auto attention_mask = attentionMask({
*lengths_host,
*prefix_lengths_host,
q_output->type(),
params.configs.mask_type == AttentionMaskType::causalMask
});
RUNTIME_ASSERT_OP_ARG(
params.common.attention_mask,
"attention_mask must be provided for default context attention implementation");
auto softmax_qk_output = softmax({std::move(qk_output),
*attention_mask,
nullopt,
scale,
datatype});
printBufferData(*softmax_qk_output, "softmax_qk_output: ");
auto qkv_output = gemm({*softmax_qk_output, *v_output});
auto &qkv_transpose_output = params.output;
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype, invokeTransposeQKV,
qkv_transpose_output.data(),
qkv_output->data(),
batch_size,
seq_len,
head_num,
size_per_head,
nullptr,
0,
stream_);
}
}
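// Wraps the type-dispatched call to fusedQKV_masked_attention_dispatch for the decode
// step: one new token per sequence, with past K/V addressed through the paged
// kv_block_array and optional multi-block partial buffers for long sequences.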
template<typename T>
void selfAttentionwrapper(const AttentionModuleParams params,
bool use_multi_block_mode,
size_t max_seq_len_tile,
void* partial_out,
float* partial_sum,
float* partial_max,
int* block_counter,
KVBlockArray kv_block_array,
cudaStream_t stream) {
size_t token_num = params.input.shape()[0];
size_t batch_size = params.common.decoder_batch_size;
size_t step = params.common.decoder_max_seq_len + 1;
size_t local_head_num = params.configs.head_num;
size_t local_head_num_kv = params.configs.kv_head_num;
size_t size_per_head = params.configs.size_per_head;
const auto& output = params.output;
const T* qkv_buf_ptr = params.input.data<T>();
T* qkv_buf_2_ = output.data<T>();
const T* bias_ptr = (params.weights.qkv_weight->bias == nullptr) ?
nullptr :
params.weights.qkv_weight->bias->data<T>();
// TODO(lidongjin) support relative attention
const T* relative_attention_bias_ptr = nullptr;
// prefix prompt
auto prefix_lengths = params.common.prefix_prompt_lengths ? params.common.prefix_prompt_lengths->data<int>() : nullptr;
auto max_prefix_length = params.common.max_prefix_length;
const auto* input_lengths = params.common.input_lengths->data<int>();
const auto* sequence_lengths = params.common.sequence_lengths->data<int>();
float q_scaling = params.configs.q_scaling;
int relative_attention_bias_stride = 0;
const float* linear_bias_slopes = params.common.linear_bias_slopes ? params.common.linear_bias_slopes->data<float>() : nullptr;
const bool* masked_tokens = nullptr;
// TODO(lidongjin) support int8
const float* query_weight_scale_out = nullptr;
const float* attention_output_weight_scale_out = nullptr;
int int8_mode = 0;
tensorrt_llm::common::QuantMode kv_cache_quant_mode = trt_common::QuantMode::fromDescription(false, false, false, false, false, false, false, false);
if (params.configs.kv_cache_dtype == KvCacheDataType::INT8) {
kv_cache_quant_mode = trt_common::QuantMode::fromDescription(true, true, false, false, false, true, false, true);
}
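    // Hand everything (QKV buffer, bias, RoPE config, prefix/input/sequence lengths,
    // quantization mode and the paged KV block array) to the fused masked-MHA kernel,
    // which writes the per-token attention output directly into params.output.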
fusedQKV_masked_attention_dispatch<T, KVBlockArray>(
qkv_buf_ptr,
bias_ptr,
relative_attention_bias_ptr,
nullptr, // cache_indir
reinterpret_cast<T*>(qkv_buf_2_),
nullptr, // finished
sequence_lengths,
batch_size,
1, // beam_width
local_head_num,
local_head_num_kv,
size_per_head,
params.configs.rope_config,
params.configs.use_logn_attn,
nullptr,
step,
prefix_lengths,
max_prefix_length,
true, //count_prefix_lengths,
input_lengths,
step,
q_scaling,
relative_attention_bias_stride,
linear_bias_slopes,
masked_tokens,
query_weight_scale_out,
attention_output_weight_scale_out,
int8_mode,
kv_cache_quant_mode,
use_multi_block_mode,
(int)max_seq_len_tile,
reinterpret_cast<T*>(partial_out),
partial_sum,
partial_max,
block_counter,
params.configs.softmax_extra_scale,
kv_block_array,
stream);
sync_check_cuda_error();
}
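// Decoder self-attention entry point: allocates multi-block scratch buffers when that
// mode is enabled, builds the KV block array for the paged cache and dispatches
// selfAttentionwrapper on the input dtype.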
AttentionModuleOutput ROCmDevice::decoderSelfAttention(const AttentionModuleParams& params) {
auto datatype = params.input.type();
size_t max_seq_len_tile = 0;
BufferPtr partial_out = nullptr;
BufferPtr partial_sum = nullptr;
BufferPtr partial_max = nullptr;
BufferPtr block_counter = nullptr;
size_t batch_size = params.common.decoder_batch_size;
size_t local_head_num = params.configs.head_num;
size_t size_per_head = params.configs.size_per_head;
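    // Multi-block mode splits the attended sequence across several thread blocks; each
    // block produces a partial result (out / sum / max) that is later reduced, so the
    // per-batch scratch buffers and a block counter are allocated up front here.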
if (use_multi_block_mode) {
const int threads_per_value = pow2roundup(size_per_head) * getTypeSize(datatype) / 16;
        // allocate memory for partial output results, regardless of THDS_PER_BLOCK
max_seq_len_tile = 256 / threads_per_value;
partial_out = allocateBuffer(
{datatype, {batch_size, max_seq_len_tile, local_head_num, size_per_head}, AllocationType::DEVICE},
{"partial_out"});
partial_sum = allocateBuffer(
{DataType::TYPE_FP32, {batch_size, max_seq_len_tile, local_head_num}, AllocationType::DEVICE},
{"partial_sum"});
partial_max = allocateBuffer(
{DataType::TYPE_FP32, {batch_size, max_seq_len_tile, local_head_num}, AllocationType::DEVICE},
{"partial_max"});
block_counter = allocateBuffer({DataType::TYPE_INT32, {batch_size, local_head_num}, AllocationType::DEVICE},
{"block_counter"});
// TODO(lidongjin) use fill op to set zeros.
cudaMemsetAsync(block_counter->data(), 0, sizeof(int) * batch_size * local_head_num, stream_);
}
void* partial_out_data = (partial_out == nullptr) ? nullptr : partial_out->data();
float* partial_sum_data = (partial_sum == nullptr) ? nullptr : partial_sum->data<float>();
float* partial_max_data = (partial_max == nullptr) ? nullptr : partial_max->data<float>();
int* block_counter_data = (block_counter == nullptr) ? nullptr : block_counter->data<int>();
RUNTIME_ASSERT_OP_ARG(params.common.kv_cache, "kv cache can not be null for decoder self-attention");
const auto max_blocks_per_batch = params.common.kv_cache->kv_cache_block_id->shape()[1];
auto kv_cache_offset = allocateBuffer(
{DataType::TYPE_INT32, {batch_size, 1, 2, max_blocks_per_batch}, AllocationType::DEVICE}, {"kv_cache_offset"});
KVBlockArray kv_block_array = getKVBlockArray(params, *kv_cache_offset, batch_size, false, stream_);
DISPATCH_CUDA_FUNCTION_DATA_TYPE(datatype,
selfAttentionwrapper,
params,
use_multi_block_mode,
max_seq_len_tile,
partial_out_data,
partial_sum_data,
partial_max_data,
block_counter_data,
kv_block_array,
stream_);
}
} // namespace rtp_llm