in maga_transformer/cpp/devices/base_impl/AttentionLayer.cc [10:223]
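// Default DeviceBase attention layer implementation:
// 1) QKV projection (fused with an all-gather when sequence parallelism is enabled),
// 2) self/context attention over the decode and prefill parts of the batch against the paged KV cache,
// 3) optional activation quantization and the output projection (fused with a reduce-scatter under sequence parallelism).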
AttentionLayerOutput DeviceBase::attentionLayer(const AttentionLayerParams& params) {
const auto &input = params.input;
const auto &input_lengths = *params.common.input_lengths;
const auto &sequence_lengths = *params.common.sequence_lengths;
const auto &output_weight = params.weights.output_weight;
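// The batch is laid out with decode (generate) requests first, followed by context (prefill) requests;
// sequence_lengths has one entry per decode request, input_lengths one entry per request overall.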
const auto generate_batch_size = sequence_lengths.shape()[0];
const auto context_batch_size = input_lengths.shape()[0] - generate_batch_size;
const auto context_token_num = params.common.context_token_num;
const auto pad_token_num = params.enable_sp ? params.pad_token_num : (context_token_num + generate_batch_size);
RUNTIME_ASSERT_OP_ARG(!params.residual, "default attention layer impl does not support residual!");
const auto& layer_kv_cache = params.common.kv_cache;
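// Sanity-check the paged KV cache: the block-id table must be [batch_size, block_length], the K and V
// cache buffers must agree in shape and match [*, kv_head_num, tokens_per_block, size_per_head], and the
// optional K/V scale buffers are checked against the same block/head/token dimensions.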
if (layer_kv_cache) {
const auto &kv_cache = layer_kv_cache.value();
const auto &kv_cache_block_id = *kv_cache.kv_cache_block_id;
const auto &shape = kv_cache.kv_cache_block_id->shape();
RUNTIME_ASSERT_OP_ARG(
((shape.size() == 2) && (shape[0] == input_lengths.shape()[0])),
"kv_cache_block_id shape in attention layer should be [batch_size, block_length]"
", but got %s", kv_cache_block_id.debugString().c_str());
RUNTIME_ASSERT_OP_ARG(
kv_cache.k_cache_buffer && kv_cache.v_cache_buffer,
"kv cache buffer should has value when use kv_cache_block_id");
const auto& k_cache_shape = kv_cache.k_cache_buffer->shape();
const auto& v_cache_shape = kv_cache.v_cache_buffer->shape();
RUNTIME_ASSERT_OP_ARG(
((k_cache_shape.size() == 4) && (v_cache_shape.size() == 4) && \
(k_cache_shape[0] == v_cache_shape[0]) && (k_cache_shape[1] == v_cache_shape[1]) && \
(k_cache_shape[2] == v_cache_shape[2]) && (k_cache_shape[3] == v_cache_shape[3]) && \
(k_cache_shape[1] == params.configs.kv_head_num) && \
(k_cache_shape[2] == params.configs.tokens_per_block) && \
(k_cache_shape[3] == params.configs.size_per_head)),
"kv cache buffer check shape failed. k_cache_buffer: %s, v_cache_buffer: %s",
kv_cache.k_cache_buffer->debugString().c_str(),
kv_cache.v_cache_buffer->debugString().c_str());
if (kv_cache.k_scale_buffer) {
const auto& k_scale_shape = kv_cache.k_scale_buffer->shape();
const auto& v_scale_shape = kv_cache.v_scale_buffer->shape();
RUNTIME_ASSERT_OP_ARG(
((k_scale_shape.size() == 3) && (v_scale_shape.size() == 3) && \
(k_scale_shape[0] == v_scale_shape[0]) && (k_scale_shape[1] == v_scale_shape[1]) && \
(k_scale_shape[2] == v_scale_shape[2]) && (k_cache_shape[0] == k_scale_shape[0]) && \
(k_scale_shape[1] == params.configs.kv_head_num) && \
(k_scale_shape[2] == params.configs.tokens_per_block)),
"kv scale check buffer failed. k_scale_buffer: %s, v_scale_buffer: %s",
kv_cache.k_scale_buffer->debugString().c_str(),
kv_cache.v_scale_buffer->debugString().c_str());
}
}
// typically local_head_num * size_per_head
const auto qkv_hidden_size = output_weight->kernel->shape()[0];
BufferPtr qkv = nullptr;
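// Sequence-parallel path: from the second layer on, the input is the token-sharded output of the previous
// layer's reduce-scatter, so the QKV projection is fused with an all-gather (allGatherloraLinear). The
// receive buffer is sized to pad_token_num rows and mirrors the input's quantization scheme.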
if (params.enable_sp && params.layer_id > 0) {
BufferPtr ag_recv_buffer = nullptr;
BufferPtr attn_input_ptr = nullptr;
if (params.qscheme == NoQuantize) {
attn_input_ptr = params.input.slice(0, params.input.shape()[0]);
ag_recv_buffer = allocateBuffer({attn_input_ptr->type(), {pad_token_num, attn_input_ptr->shape()[1]}}, {"ag_recv_buffer"});
} else if (params.qscheme == Qint8PerToken){
attn_input_ptr = reinterpret_cast<const QBuffer&>(params.input).qslice(0, params.input.shape()[0]);
BufferPtr kernel = allocateBuffer({attn_input_ptr->type(), {pad_token_num, attn_input_ptr->shape()[1]}}, {"ag_recv_buffer_kernel"});
BufferPtr scales = allocateBuffer({DataType::TYPE_FP32,
{pad_token_num},
AllocationType::DEVICE},
{"ag_recv_buffer_scale"});
ag_recv_buffer = BufferPtr(new QBuffer(std::move(kernel),
std::move(scales),
BufferPtr(new Buffer(MemoryType::MEMORY_GPU,
DataType::TYPE_INVALID,
{0},
nullptr))));
} else if (params.qscheme == Qfp8PerTensor){
attn_input_ptr = reinterpret_cast<const QBuffer&>(params.input).qslicePerTensor(0, params.input.shape()[0]);
BufferPtr kernel = allocateBuffer({attn_input_ptr->type(), {pad_token_num, attn_input_ptr->shape()[1]}}, {"ag_recv_buffer_kernel"});
BufferPtr scales = reinterpret_cast<const QBuffer&>(params.input).scalesPtr();
ag_recv_buffer = BufferPtr(new QBuffer(std::move(kernel),
std::move(scales),
BufferPtr(new Buffer(MemoryType::MEMORY_GPU,
DataType::TYPE_INVALID,
{0},
nullptr))));
} else {
throw OpException({OpErrorType::ERROR_UNIMPLEMENTED, "allGatherloraLinear qscheme type not supported"});
}
printBufferData(*attn_input_ptr, "attn_ag_input");
GemmParams qkv_gemm_params = GemmParams(*ag_recv_buffer, *(params.weights.qkv_weight->kernel));
AllGatherLoraLinearOutput all_gather_output = allGatherloraLinear(
{LoraLinearParams(qkv_gemm_params, params.common.lora_input.qkv_lora_input),
attn_input_ptr, ag_recv_buffer, params.qscheme, params.output->type()});
// syncAndCheck();
ag_recv_buffer = all_gather_output.all_gather_recv_buffer;
qkv = all_gather_output.output;
printBufferData(*ag_recv_buffer, "attn_ag_inter_output");
printBufferData(*qkv, "attn_ag_final_output");
} else {
// NOTE: the CUDA implementation fuses the qkv_weight->bias addition into the invokeAddFusedQKVBiasTranspose kernel,
// so other devices need to handle the bias themselves.
// TODO: consider exposing this as a device property.
qkv = params.configs.use_mla ? mlaQKVGemm(params): mhaQKVGemm(params);
}
printBufferData(*qkv, "qkv");
// attention layer output is preallocated to avoid memory fragmentation
// note that this output is returned and further used as residual
auto qscheme = params.qscheme;
auto dtype = (input.isQBuffer() && qscheme != QScheme::Qfp8PerTensor ? qkv->type() : input.type());
BufferPtr qkv_output = nullptr;
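// Fp8 per-tensor: the attention output is written as fp8 and wrapped in a QBuffer that reuses the
// static quantization scale, so the output projection GEMM can consume it directly.
// Otherwise a plain buffer of the activation dtype is allocated.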
if (qscheme == QScheme::Qfp8PerTensor) {
auto scales = params.weights.static_quant_weight->kernel;
qkv_output = BufferPtr(new QBuffer(allocateBuffer({DataType::TYPE_FP8_E4M3, {pad_token_num, qkv_hidden_size}}, {"qkv_output"}),
BufferPtr(new Buffer(scales->where(), scales->type(), scales->shape(), scales->data())),
BufferPtr(new Buffer(scales->where(), scales->type(), {0}, nullptr))));
} else {
#if defined(__aarch64__)
// The Arm attention op only supports the fp32 data type.
qkv_output = allocateBuffer({DataType::TYPE_FP32, {pad_token_num, qkv_hidden_size}}, {"qkv_output"});
#else
qkv_output = allocateBuffer({dtype, {pad_token_num, qkv_hidden_size}}, {"qkv_output"});
#endif
}
auto kv_cache_block_id = layer_kv_cache ? layer_kv_cache->kv_cache_block_id : nullptr;
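// Run decode and context attention on disjoint row views of the packed qkv/qkv_output buffers,
// temporarily narrowing the KV-cache block-id table to the corresponding batch slice.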
if (generate_batch_size) {
auto generate_qkv = qkv->view(0, generate_batch_size);
auto generate_output = qkv_output->view(0, generate_batch_size);
if (layer_kv_cache) {
params.common.kv_cache->kv_cache_block_id = kv_cache_block_id->slice(0, generate_batch_size);
}
decoderSelfAttention({params.layer_id, generate_qkv, generate_output, params.common, params.weights, params.configs, params.qscheme});
}
if (context_batch_size) {
auto context_qkv = qkv->view(generate_batch_size, context_token_num);
auto context_output = qkv_output->view(generate_batch_size, context_token_num);
if (layer_kv_cache) {
params.common.kv_cache->kv_cache_block_id = kv_cache_block_id->slice(generate_batch_size, context_batch_size);
}
contextAttention({params.layer_id, context_qkv, context_output, params.common, params.weights, params.configs, params.qscheme});
}
if (layer_kv_cache) {
params.common.kv_cache->kv_cache_block_id = kv_cache_block_id;
}
printBufferData(*qkv_output, "qkv_output");
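// Pick the output projection destination: with sequence parallelism the GEMM writes to a scratch buffer
// that is then reduce-scattered into params.output; otherwise it writes into params.output directly
// (or into a freshly allocated buffer when none was provided).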
BufferPtr gemm_output = nullptr;
BufferPtr attn_output = nullptr;
if (params.enable_sp) {
gemm_output = allocateBuffer({qkv->type(), {pad_token_num, output_weight->kernel->shape()[1]}},
{"attn_layer_out"});
attn_output = params.output;
} else {
gemm_output = params.output ? params.output
: allocateBuffer({qkv->type(), {pad_token_num, output_weight->kernel->shape()[1]}},
{"attn_layer_out"});
attn_output = gemm_output;
}
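// For int8 schemes, quantize the attention output (applying the optional smoother/shift and static
// scales) before the output projection GEMM.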
if (params.qscheme == QScheme::Qint8PerTensor || params.qscheme == QScheme::Qint8PerToken) {
OptionalConstBufferRef smoother_weight = params.weights.smoother_weight ?
(OptionalConstBufferRef)*(params.weights.smoother_weight->kernel) : std::nullopt;
OptionalConstBufferRef shift_weight = params.weights.shift_weight ?
(OptionalConstBufferRef)*(params.weights.shift_weight->kernel) : std::nullopt;
OptionalConstBufferRef static_scale_weight = params.weights.static_quant_weight ?
(OptionalConstBufferRef)*(params.weights.static_quant_weight->kernel) : std::nullopt;
OptionalConstBufferRef static_scale_reciprocal_weight = params.weights.static_scale_reciprocal_weight ?
(OptionalConstBufferRef)*(params.weights.static_scale_reciprocal_weight->kernel) : std::nullopt;
auto quant_data_type = DataType::TYPE_INT8;
auto quant_params = QuantizeParams(
*qkv_output,
quant_data_type,
1,
params.qscheme,
smoother_weight,
shift_weight,
static_scale_weight,
static_scale_reciprocal_weight);
qkv_output = quantize(quant_params);
}
#if defined(__aarch64__)
// The Arm attention op only supports fp32; convert the result back to the original dtype.
GemmParams output_gemm_params = GemmParams(*qkv_output, *(output_weight->kernel), nullopt, gemm_output, dtype);
#else
GemmParams output_gemm_params = GemmParams(*qkv_output, *(output_weight->kernel), nullopt, gemm_output);
#endif
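// Under sequence parallelism the output projection is fused with a reduce-scatter back to the
// token-sharded layout; otherwise a plain LoRA-aware GEMM produces the final attention output.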
if (params.enable_sp) {
printBufferData(*qkv_output, "attn_rs_input");
ReduceScatterLoraLinearOutput reduce_scatter_output = loraLinearReduceScatter(
{LoraLinearParams(output_gemm_params, params.common.lora_input.out_lora_input),
params.output, params.qscheme, params.output->type()});
// syncAndCheck();
gemm_output = reduce_scatter_output.reduce_scatter_recv_buffer;
attn_output = reduce_scatter_output.output;
printBufferData(*gemm_output, "attn_rs_inter_output");
printBufferData(*attn_output, "attn_rs_final_output");
} else {
loraLinear(LoraLinearParams(output_gemm_params, params.common.lora_input.out_lora_input));
}
return {std::move(attn_output)};
}