in maga_transformer/cpp/models/GptModel.cc [849:1013]
AttentionBlockOutputs GptModel::forwardAttentionBlock(
const GptLayerInputs& inputs,
const int32_t layer_id,
rtp_llm::lora::LoraModelInputPtr lora_model_input,
const LastLayerDeferedParams& last_layer_defered_params)
{
auto hidden = inputs.hidden;
auto pre_decoder_residual = inputs.pre_decoder_residual;
auto attention_common_inputs = move(inputs.attention_common_inputs);
DevicePerfWrapper wrapper(device_, "attention_block_layer_" + std::to_string(layer_id) + "_bs_" + std::to_string(hidden->shape()[0]));
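// If the previous layer deferred its combine output, gather it here to recover this layer's input hidden states.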
if (last_layer_defered_params.combine_output) {
printBufferData(*(last_layer_defered_params.combine_output.value().all_output), "layer_" + to_string(layer_id - 1) + "_combine_output_defered");
hidden = device_->gatherCombineOutput(last_layer_defered_params.combine_output.value()).hidden_states;
}
attention_common_inputs.layer_id = layer_id;
const auto& layer = weights_.layers[layer_id];
bool enable_sp = inputs.enable_sp;
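// Under sequence parallelism each TP rank only handles pad_token_num / tp_size tokens;
// size the attention output buffer for this rank's shard.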
// hidden->dtype may be int8 here, so use the dtype of the embedding lookup result instead
size_t rank_pad_token_num = enable_sp ? inputs.pad_token_num / device_props_.tp_size : hidden->shape()[0];
BufferPtr attn_out_buf = device_->allocateBuffer({inputs.dtype, {rank_pad_token_num, hidden->shape()[1]}}, {"attn_out_buf"});
if (!enable_sp) {
// Note: for custom all reduce, prepareAllReduce will replace the original attn_out_buf with
// a new custom_ar_comm buffer. Here we must make sure that attn_out_buf is not released or replaced by
// another buffer before the actual allreduce operation. Otherwise, custom ar will raise an error.
attn_out_buf = device_->prepareAllReduce({std::move(attn_out_buf), ReduceOp::Sum}).buffer;
}
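// The residual branch starts from the pre-decoder residual if one was provided, otherwise from the current hidden states.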
auto residual = pre_decoder_residual ? pre_decoder_residual : hidden;
printBufferData(*residual, "in residual");
BufferPtr residual2 = nullptr;
BufferPtr hidden_to_slice = nullptr; // for sp and overlap comm type 2
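// Pre-attention layernorm: normalize hidden into a fresh buffer while rebuilding the residual,
// folding in any residual / shared-expert output deferred from the previous layer.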
if (layer.pre_layernorm) {
// TODO(wangyin.yx): fuse this clone branch into layernorm(rmsnorm)
residual = last_layer_defered_params.residual ? device_->allocateBufferLike(*hidden, AllocationType::DEVICE, {"residual"})
: device_->clone({*hidden, AllocationType::DEVICE, {"residual"}});
int m_split = device_props_.m_split;
size_t overlap_comm_type = device_props_.overlap_comm_type;
auto pre_layernorm_output = device_->layernorm(LayernormParams(hidden,
residual,
*layer.pre_layernorm,
rtp_llm::mayGetRef(last_layer_defered_params.residual),
rtp_llm::mayGetRef(last_layer_defered_params.shared_expert_output),
std::nullopt,
0.f,
description_.layernorm_eps,
!(enable_sp && overlap_comm_type == 2),
false,
description_.norm_type,
description_.act_qscheme,
layer_id > 0,
false));
if (enable_sp && layer_id == 0) {
if (overlap_comm_type == 1 && m_split > 0) {
vector<int> selected_indices;
selected_indices.reserve(rank_pad_token_num);
size_t m = inputs.pad_token_num;
size_t m_chunk = m / m_split;
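// When m > 128, round the chunk size up to the next multiple of 128.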
if (m > 128) {
m_chunk = (m / m_split + 127) & ~127;
}
size_t tp_rank = device_props_.tp_rank;
size_t round = m_chunk / device_props_.tp_size;
size_t offset = tp_rank * round;
for (size_t i = 0; i < rank_pad_token_num; i++) {
selected_indices.push_back((i / round) * m_chunk + i % round + offset);
}
// printBufferData(*vector2Buffer(selected_indices), "selected_indices");
residual = device_->select({*residual, *device_->clone({*vector2Buffer(selected_indices)})});
} else {
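// Otherwise each rank takes a contiguous slice of the residual; the full buffer is kept in
// hidden_to_slice, presumably so the sliced view's underlying storage stays alive.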
hidden_to_slice = residual;
residual = residual->slice(rank_pad_token_num * device_props_.tp_rank, rank_pad_token_num);
}
}
hidden = std::move(pre_layernorm_output.output);
} else if (last_layer_defered_params.residual || last_layer_defered_params.shared_expert_output) {
// NOTE(wangyin): this branch is not used for now, might be erroneous
residual = device_->clone({*hidden, AllocationType::DEVICE, {"residual"}});
auto prev_ffn_layernorm_output = device_->layernorm({
hidden,
nullptr,
std::nullopt, // post_ffn_layernorm_weights
rtp_llm::mayGetRef(last_layer_defered_params.residual),
rtp_llm::mayGetRef(last_layer_defered_params.shared_expert_output),
});
hidden = std::move(prev_ffn_layernorm_output.output);
}
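// Attach this layer's KV cache buffers (and their quantization scales, if present) for the attention kernels.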
if (k_cache_buffer_ && attention_common_inputs.kv_cache) {
attention_common_inputs.kv_cache->k_cache_buffer = k_cache_buffer_->index(layer_id);
attention_common_inputs.kv_cache->v_cache_buffer = v_cache_buffer_->index(layer_id);
if (k_scale_buffer_) {
attention_common_inputs.kv_cache->k_scale_buffer = k_scale_buffer_->index(layer_id);
attention_common_inputs.kv_cache->v_scale_buffer = v_scale_buffer_->index(layer_id);
}
}
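// Per-layer LoRA weights for the attention layer, if a LoRA model input is provided.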
if (lora_model_input) {
attention_common_inputs.lora_input = lora_model_input->getAttentionLayerLoraInput(layer_id);
}
AttentionLayerOutput attn_output;
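// Assemble the attention layer inputs. When attn_fuse_add_residual is set, the residual add is
// fused into the attention output projection and skipped in the post layernorm below.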
auto attn_params = AttentionLayerParams({
layer_id,
*hidden,
move(attn_out_buf),
description_.attention_conf,
layer.self_attention_weights,
attention_common_inputs,
device_props_.attn_fuse_add_residual ? (OptionalConstBufferRef)*residual : nullopt,
{description_.layernorm_eps, description_.norm_type},
description_.act_qscheme,
enable_sp,
inputs.pad_token_num
});
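// Use the MLA-specific attention path when MLA is enabled and the device is not running it as plain MHA.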
if (description_.attention_conf.use_mla && device_->mla_ops_type != rtp_llm::MlaOpsType::MHA) {
attn_output = device_->mlaAttentionLayer(attn_params);
} else {
attn_output = device_->attentionLayer(attn_params);
}
auto attn_hidden = std::move(attn_output.hidden_states);
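// With tensor parallelism (and no sequence parallelism), sum the per-rank partial attention outputs across ranks.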
if (device_props_.tp_size > 1 && !enable_sp) {
// Note: for custom all reduce, allReduce will allocate a new buffer and replace the original attn_hidden with it
auto wrapper = DevicePerfWrapper(device_, "allReduce, sizeBytes=%ld", (long)attn_hidden->sizeBytes());
attn_hidden = device_->allReduce({std::move(attn_hidden), ReduceOp::Sum}).buffer;
}
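// Optionally scale the attention output before it enters the residual add.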
if (residual_scale_) {
attn_hidden = device_->multiply({*residual_scale_, *attn_hidden});
}
printBufferData(*attn_hidden, "layer_" + to_string(layer_id) + "_attn_output");
if (layer.post_layernorm) {
// attn_hidden = attn_hidden + residual
// hidden = layernorm(attn_hidden)
printBufferData(*residual, "before post layernorm residual");
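// Fused post layernorm: add the residual (unless already fused into attention) and the attention
// output bias to attn_hidden, keep that sum as the next residual, and write the normalized result to hidden.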
auto post_layernorm_params = LayernormParams(attn_hidden,
attn_hidden,
rtp_llm::mayGetRef(layer.post_layernorm),
device_props_.attn_fuse_add_residual ? nullopt : (OptionalConstBufferRef)*residual,
nullopt,
rtp_llm::mayGetRef(layer.self_attention_weights.output_weight->bias),
0.f,
description_.layernorm_eps,
false,
description_.post_layernorm,
description_.norm_type,
description_.act_qscheme,
false,
true);
auto post_layernorm_output = device_->layernorm(post_layernorm_params);
hidden = std::move(post_layernorm_output.output);
attn_hidden = std::move(post_layernorm_output.before_norm_output);
residual = attn_hidden;
printBufferData(*residual, "after post layernorm residual");
} else {
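// No post layernorm: pass the attention output on as a second residual for the FFN block.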
residual2 = attn_hidden;
}
printBufferData(*hidden, "layer_" + to_string(layer_id) + "_ffn_input");
return {hidden, residual, residual2};
}