// maga_transformer/cpp/models/GptModel.cc
#include "maga_transformer/cpp/models/GptModel.h"
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/core/Types.h"
#include "maga_transformer/cpp/devices/OpData.h"
#include "maga_transformer/cpp/core/BufferHelper.h"
#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
#include "maga_transformer/cpp/models_weight/W.h"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "maga_transformer/cpp/utils/StringUtil.h"
#include "maga_transformer/cpp/devices/utils/DevicePerfWrapper.h"
#include <algorithm>
#include <memory>
using namespace std;
using namespace rtp_llm;
namespace rtp_llm {
string GptModelInputs::debugString() const {
if (!Logger::getEngineLogger().isDebugMode()) {
return "";
}
std::stringstream debug_string;
debug_string << "GptModelInputs { " << "combo_tokens: " << combo_tokens->debugStringWithData<int32_t>()
<< ", input_lengths: " << input_lengths->debugStringWithData<int32_t>()
<< ", sequence_lengths: " << sequence_lengths->debugStringWithData<int32_t>()
<< ", prefix_lengths: " << prefix_lengths->debugStringWithData<int32_t>();
if (combo_position_ids) {
debug_string << ", combo_position_ids: " << combo_position_ids->debugStringWithData<int32_t>();
}
if (lora_ids) {
debug_string << ", lora_ids: " << lora_ids->debugStringWithData<int32_t>();
}
if (lora_input_lengths) {
debug_string << ", lora_input_lengths: " << lora_input_lengths->debugStringWithData<int32_t>();
}
if (kv_cache_block_id) {
debug_string << ", kv_cache_block_id: " << kv_cache_block_id->debugStringWithData<int32_t>();
}
if (attention_mask) {
debug_string << ", attention_mask: " << attention_mask->debugString();
}
if (request_id) {
debug_string << ", request_id: " << request_id->debugStringWithData<int64_t>();
}
if (request_pd_separation) {
debug_string << ", request_pd_separation: " << request_pd_separation->debugStringWithData<bool>();
}
if (cache_keys) {
debug_string << ", cache_keys: " << cache_keys->debugStringWithData<int64_t>();
}
debug_string << ", k block_size: " << k_block_size;
debug_string << ", v block_size: " << v_block_size;
debug_string << ", pd_separation: " << pd_separation;
debug_string << "}";
return debug_string.str();
}
GptModel::GptModel(const GptModelInitParams& params):
device_(params.device),
device_props_(params.device->getDeviceProperties()),
layer_num_(params.weights.layers.size()),
description_(params.description),
weights_(params.weights) {
if (params.kv_cache_buffer) {
k_cache_buffer_ = params.kv_cache_buffer->k_blocks;
v_cache_buffer_ = params.kv_cache_buffer->v_blocks;
if (params.kv_cache_buffer->k_scale) {
k_scale_buffer_ = params.kv_cache_buffer->k_scale;
v_scale_buffer_ = params.kv_cache_buffer->v_scale;
}
}
if (abs(description_.residual_scalar - 1.0) > 1e-6) {
residual_scale_fp32_ = device_->clone({*vector2Buffer(vector<float>{(float)description_.residual_scalar})});
residual_scale_ = residual_scale_fp32_;
}
if (params.description.ffn_conf.moe_configs.has_value()) {
auto moe_conf = params.description.ffn_conf.moe_configs.value();
overall_expert_stats_ = device_->createMoeExpertStates(
{layer_num_, moe_conf.ep_size, moe_conf.expert_num, moe_conf.expert_num + moe_conf.extra_expert_num});
}
}
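// Computes the exclusive prefix sum of per-request sequence lengths (plus prefix
// lengths when provided) into cu_seqlens, and optionally fills padding_offset with the
// number of padding slots preceding each token when rows are padded to max_seq_len.
// Worked example for the loop below:
//   sequence_length = {3, 5}, prefix_length = nullptr, batch_size = 2, max_seq_len = 5
//   -> cu_seqlens     = {0, 3, 8}
//   -> padding_offset = {0, 0, 0, 2, 2, 2, 2, 2}  // request 1's tokens skip the 2 pad slots of request 0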
void getPaddingOffsetAndCuSeqLens(int32_t* padding_offset,
int32_t* cu_seqlens,
const int32_t* sequence_length,
const int32_t* prefix_length,
const int32_t batch_size,
const int32_t max_seq_len)
{
    // cumulative sum of (prefix +) sequence lengths; also record per-token padding offsets
int32_t total_seq_len = 0;
int32_t cum_offset = 0;
int32_t index = 0;
for (int32_t i = 0; i < batch_size; i++) {
int32_t seq_len = sequence_length[i];
if (prefix_length) {
seq_len += prefix_length[i];
}
cu_seqlens[i] = total_seq_len;
if (padding_offset) {
for (int32_t j = 0; j < seq_len; j++) {
padding_offset[index] = cum_offset;
index++;
}
}
cum_offset += max_seq_len - seq_len;
total_seq_len += seq_len;
}
cu_seqlens[batch_size] = total_seq_len;
}
void checkKvBlocksShape(const BufferPtr& input_kv_offset) {
if (!input_kv_offset) {
return;
}
RUNTIME_ASSERT_OP_ARG(
input_kv_offset->shape().size() == 2,
"kv_cache_blocks shape should be [batch_size, block_length].");
}
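// Gathers a TP-sharded 2D buffer from all ranks and lays it out so each row carries the
// full hidden dimension. Shape flow as implemented below:
//   local shard           : [rows, cols]
//   gathered (rank-major) : [tp_size, rows, cols]
//   transposed + reshaped : [rows, cols * tp_size]
// e.g. tp_size = 2 and a local buffer of [4, 64] yield a [4, 128] result.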
BufferPtr GptModel::tpSyncEmbeddingOrLogits(const BufferPtr& buffer) {
const auto tp_size = device_props_.tp_size;
const auto tp_rank = device_props_.tp_rank;
const auto buffer_shape = buffer->shape();
const auto local_size = buffer->size();
auto all_data = device_->allocateBuffer({buffer->type(), {buffer_shape[0], buffer_shape[1] * tp_size}});
auto buffer_view = buffer->reshape({buffer->size()});
auto all_data_1d = all_data->reshape({all_data->size()});
device_->copy({all_data_1d.view(local_size * tp_rank, local_size), buffer_view});
device_->allGather({{all_data}});
auto ret = device_->transpose({all_data->reshape({tp_size, buffer_shape[0], buffer_shape[1]})});
ret->updateShape({buffer_shape[0], buffer_shape[1] * tp_size});
return ret;
}
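// Builds the attention inputs shared by all layers for this step. Batch layout
// convention used throughout this file: the first sequence_lengths->shape()[0] entries
// of input_lengths are decode requests, the remaining entries are context (prefill)
// requests. Illustrative values:
//   input_lengths = {1, 1, 7, 4}, sequence_lengths = {12, 30}
//   -> decoder_batch_size = 2, context_batch_size = 2, context_max_seq_len = 7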
rtp_llm::AttentionCommonInputs GptModel::prepareAttentionInputs(
const GptModelInputs& inputs,
rtp_llm::DataType attn_dtype,
rtp_llm::BufferPtr combo_position_ids)
{
AttentionCommonInputs attention_inputs({
device_->clone({*inputs.input_lengths}),
device_->clone({*inputs.sequence_lengths}),
});
attention_inputs.position_ids = combo_position_ids;
attention_inputs.warmup = inputs.warmup;
if (!inputs.warmup && inputs.pd_separation) {
        RTP_LLM_CHECK_WITH_INFO(inputs.input_lengths && inputs.prefix_lengths && inputs.kv_cache_block_id, "failed to get information for pd separation store cache");
CacheStoreInputs cache_store_inputs({inputs.input_lengths, inputs.prefix_lengths, inputs.kv_cache_block_id});
attention_inputs.cache_store_inputs = cache_store_inputs;
}
if (inputs.kv_cache_block_id) {
checkKvBlocksShape(inputs.kv_cache_block_id);
KvCacheInfo kv_cache;
kv_cache.layer_num = layer_num_;
kv_cache.kv_cache_block_id = device_->clone({*inputs.kv_cache_block_id, AllocationType::DEVICE, {"kv_cache_block_id"}});
attention_inputs.kv_cache = kv_cache;
}
const auto& input_lengths = inputs.input_lengths;
const auto& sequence_lengths = inputs.sequence_lengths;
const auto& prefix_lengths = inputs.prefix_lengths;
const auto decoder_batch_size = sequence_lengths->shape()[0];
const auto context_batch_size = input_lengths->shape()[0] - decoder_batch_size;
const auto max_context_seq_len = context_batch_size ? *std::max_element(
input_lengths->data<int32_t>() + decoder_batch_size,
input_lengths->data<int32_t>() + decoder_batch_size + context_batch_size) : 0;
    RTP_LLM_CHECK_WITH_INFO(!prefix_lengths || prefix_lengths->size() == context_batch_size, "prefix_lengths size %ld is not equal to context batch size %ld.", prefix_lengths->size(), context_batch_size);
attention_inputs.max_prefix_length = context_batch_size && prefix_lengths ? *std::max_element(
prefix_lengths->data<int32_t>(),
prefix_lengths->data<int32_t>() + prefix_lengths->size()) : 0;
const auto max_decoder_seq_len = decoder_batch_size ? *std::max_element(
sequence_lengths->data<int32_t>(),
sequence_lengths->data<int32_t>() + decoder_batch_size) : 0;
std::vector<int32_t> cu_seqlens_data(context_batch_size + 1);
std::vector<int32_t> padding_offset_data(inputs.combo_tokens->shape()[0]);
getPaddingOffsetAndCuSeqLens(
padding_offset_data.data(),
cu_seqlens_data.data(),
input_lengths->dataWithOffset<int32_t>(decoder_batch_size),
nullptr,
context_batch_size,
max_context_seq_len);
// RUNTIME_ASSERT_OP_ARG(
// (cu_seqlens_data[context_batch_size] + decoder_batch_size == inputs.combo_tokens->shape()[0]),
// "combo_tokens is not consistent with input lengths, "
// "there are %d tokens in context plus %ld tokens in decoder batch, but got %ld input tokens.",
// cu_seqlens_data[context_batch_size], decoder_batch_size, inputs.combo_tokens->shape()[0]);
attention_inputs.cu_seqlens = device_->clone(
{*vector2Buffer(cu_seqlens_data), AllocationType::DEVICE, {"cu_seqlens"}});
if (attention_inputs.max_prefix_length) {
attention_inputs.prefix_prompt_lengths = device_->clone(*prefix_lengths);
std::vector<int32_t> cu_kv_seqlens_data(context_batch_size + 1);
getPaddingOffsetAndCuSeqLens(
nullptr,
cu_kv_seqlens_data.data(),
input_lengths->dataWithOffset<int32_t>(decoder_batch_size),
prefix_lengths->data<int32_t>(),
context_batch_size,
max_context_seq_len);
attention_inputs.cu_kv_seqlens = device_->clone(
{*vector2Buffer(cu_kv_seqlens_data), AllocationType::DEVICE, {"cu_kv_seqlens"}});
} else {
attention_inputs.cu_kv_seqlens = attention_inputs.cu_seqlens;
}
attention_inputs.padding_offset = device_->clone(
{*vector2Buffer(padding_offset_data), AllocationType::DEVICE, {"padding_offset"}});
attention_inputs.decoder_batch_size = decoder_batch_size;
attention_inputs.context_batch_size = context_batch_size;
attention_inputs.context_max_seq_len = max_context_seq_len;
attention_inputs.decoder_max_seq_len = max_decoder_seq_len;
attention_inputs.context_token_num = cu_seqlens_data[context_batch_size];
if (weights_.linear_bias_slopes) {
attention_inputs.linear_bias_slopes = weights_.linear_bias_slopes->kernel;
}
RTP_LLM_LOG_DEBUG("prepare model run sequence lengths: %s, input_lengths: %s, kv cache: %s, context batch size: %ld, decoder batch size: %ld",
inputs.sequence_lengths->debugStringWithData<int32_t>().c_str(),
inputs.input_lengths->debugStringWithData<int32_t>().c_str(),
inputs.kv_cache_block_id ? inputs.kv_cache_block_id->debugString().c_str() : "NULL",
context_batch_size, decoder_batch_size);
auto prep_output = device_->prepareModelRun({
description_.attention_conf,
inputs.prefix_lengths,
inputs.sequence_lengths,
inputs.input_lengths,
inputs.kv_cache_block_id,
attention_inputs.kv_cache ? attention_inputs.kv_cache->kv_cache_block_id : nullptr,
attn_dtype,
context_batch_size,
decoder_batch_size,
(bool)k_cache_buffer_,
attention_inputs.max_prefix_length > 0,
(bool)weights_.linear_bias_slopes
});
if (inputs.cache_keys) {
vector<int64_t> cache_keys_vec = rtp_llm::buffer2vector<int64_t>(*inputs.cache_keys);
attention_inputs.cache_keys = transVectorToString(cache_keys_vec);
}
attention_inputs.decode_flash_infer_attn_params.swap(prep_output.decode_flash_infer_attn_params);
attention_inputs.prefill_flash_infer_attn_params.swap(prep_output.prefill_flash_infer_attn_params);
attention_inputs.request_id = inputs.request_id;
attention_inputs.request_pd_separation = inputs.request_pd_separation;
attention_inputs.k_block_size = inputs.k_block_size;
attention_inputs.v_block_size = inputs.v_block_size;
attention_inputs.scale_block_size = inputs.scale_block_size;
attention_inputs.pd_separation = inputs.pd_separation;
if (context_batch_size && prep_output.need_mask) {
attention_inputs.attention_mask = device_->attentionMask({
inputs.input_lengths->view(decoder_batch_size, context_batch_size),
*inputs.prefix_lengths,
attn_dtype,
description_.attention_conf.mask_type == rtp_llm::AttentionMaskType::causalMask
});
}
return attention_inputs;
}
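// Decides whether layer-level micro batching is used and how the batch is split in two.
// Pure-prefill or pure-decode batches are split roughly in half; mixed batches,
// single-layer models and tiny batches fall back to the normal path.
// Illustrative plans ({prefill_num, decoder_num} per micro batch):
//   5 decode requests, 0 context -> {true, {{0, 3}, {0, 2}}}
//   4 context requests, 0 decode -> {true, {{2, 0}, {2, 0}}}
//   mixed context + decode       -> {false, {}}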
MicroBatchPlan GptModel::planMicroBatches(const GptModelInputs& inputs) {
if (!int(device_props_.enable_layer_micro_batch)) {
RTP_LLM_LOG_DEBUG("micro batch disable when enable_layer_micro_batch is false");
return {false, {}};
}
if (layer_num_ == 1) {
return {false, {}};
}
const auto& input_lengths = inputs.input_lengths;
const auto& sequence_lengths = inputs.sequence_lengths;
const auto decoder_batch_size = sequence_lengths->shape()[0];
const auto context_batch_size = input_lengths->shape()[0] - decoder_batch_size;
if (decoder_batch_size + context_batch_size < 2) {
RTP_LLM_LOG_DEBUG("micro batch disable when batch size %ld is less than 2", decoder_batch_size + context_batch_size);
return {false, {}};
}
    // TODO: design a better split strategy that considers the computational workload of each request
    // disable micro batching if both context and decoder queries exist.
if (context_batch_size && decoder_batch_size) {
RTP_LLM_LOG_INFO("split context in micro batch 0, decode in micro batch 1 disabled!");
return {false, {}};
}
const auto batch_size_to_split = context_batch_size ? context_batch_size : decoder_batch_size;
const auto micro_batch_0_size = (batch_size_to_split + 1) / 2;
const auto micro_batch_1_size = batch_size_to_split - micro_batch_0_size;
RTP_LLM_LOG_INFO("split micro batch size %ld, %ld", micro_batch_0_size, micro_batch_1_size);
return context_batch_size ? MicroBatchPlan{true, {{micro_batch_0_size, 0}, {micro_batch_1_size, 0}}}
: MicroBatchPlan{true, {{0, micro_batch_0_size}, {0, micro_batch_1_size}}};
}
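// Slices the model inputs and hidden states into per-micro-batch views according to the
// plan above. When the plan is disabled, everything goes into micro batch 0 and a
// one-token fake query is used as micro batch 1, presumably to keep the two-micro-batch
// execution (and its collective communication) uniform across ranks.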
vector<LayerMicroBatchInputs> GptModel::prepareMicroBatchInputs(
const GptModelInputs& inputs,
const BufferPtr& hidden,
const BufferPtr& pre_decoder_residual,
const rtp_llm::DataType attn_dtype,
const MicroBatchPlan& micro_batch_plan)
{
vector<LayerMicroBatchInputs> micro_batch_inputs;
size_t sliced_token_idx = 0;
size_t sliced_batch_idx = 0; // for input_lengths and kv cache block id
size_t decode_batch_idx = 0; // for sequence_lengths
size_t prefill_batch_idx = 0; // for lm_output_indexes and prefix_lengths
if (!micro_batch_plan.enable) {
RTP_LLM_LOG_DEBUG("micro batch disable when enable is false, use fake");
// we put everything into the first micro batch, and send empty query to the second micro batch
auto attention_common_inputs = prepareAttentionInputs(inputs, attn_dtype, nullptr);
micro_batch_inputs.push_back({hidden, pre_decoder_residual, attention_common_inputs});
// The fake query
GptModelInputs fake_inputs;
fake_inputs.kv_cache_block_id = nullptr;
fake_inputs.combo_tokens = inputs.combo_tokens->slice(0, 1);
fake_inputs.input_lengths = device_->allocateBuffer({DataType::TYPE_INT32, {1}, AllocationType::HOST});
fake_inputs.input_lengths->data<int32_t>()[0] = 1;
fake_inputs.sequence_lengths = device_->allocateBuffer({DataType::TYPE_INT32, {0}, AllocationType::HOST});
fake_inputs.prefix_lengths = device_->allocateBuffer({DataType::TYPE_INT32, {1}, AllocationType::HOST});
fake_inputs.prefix_lengths->data<int32_t>()[0] = 0;
auto fake_hidden = device_->allocateBuffer({hidden->type(), {1, hidden->shape()[1]}});
auto attention_common_inputs_fake = prepareAttentionInputs(fake_inputs, attn_dtype, nullptr);
micro_batch_inputs.push_back({move(fake_hidden), nullptr, move(attention_common_inputs_fake), true});
} else {
        // TODO(wangyin.yx): refactor this splitting method, extract common code
for (size_t i = 0; i < micro_batch_plan.batch_infos.size(); ++i) {
const auto& p_micro_batch_size = micro_batch_plan.batch_infos[i].prefill_num;
const auto& d_micro_batch_size = micro_batch_plan.batch_infos[i].decoder_num;
RUNTIME_ASSERT_OP_ARG(!(p_micro_batch_size && d_micro_batch_size),
"one micro batch can not contain both p and d tokens, but got %ld and %ld",
p_micro_batch_size, d_micro_batch_size);
RTP_LLM_LOG_DEBUG("micro batch index %ld, prefill size %ld, decode size %ld",
i, p_micro_batch_size, d_micro_batch_size);
if (d_micro_batch_size) {
GptModelInputs micro_model_inputs = inputs;
RTP_LLM_LOG_DEBUG("d slice from %ld %ld %ld", sliced_token_idx, sliced_batch_idx, decode_batch_idx);
micro_model_inputs.combo_tokens = inputs.combo_tokens->slice(sliced_token_idx, d_micro_batch_size);
micro_model_inputs.input_lengths = inputs.input_lengths->slice(sliced_batch_idx, d_micro_batch_size);
micro_model_inputs.sequence_lengths = inputs.sequence_lengths->slice(decode_batch_idx, d_micro_batch_size);
micro_model_inputs.attention_mask = inputs.attention_mask ? inputs.attention_mask->slice(sliced_batch_idx, d_micro_batch_size) : nullptr;
micro_model_inputs.kv_cache_block_id = inputs.kv_cache_block_id->slice(sliced_batch_idx, d_micro_batch_size);
auto micro_hidden = hidden->slice(sliced_token_idx, d_micro_batch_size);
auto micro_pre_decoder_residual = pre_decoder_residual ? pre_decoder_residual->slice(sliced_token_idx, d_micro_batch_size) : nullptr;
auto attention_common_inputs = prepareAttentionInputs(micro_model_inputs, attn_dtype, nullptr);
micro_batch_inputs.push_back({
move(micro_hidden), move(micro_pre_decoder_residual), move(attention_common_inputs)});
sliced_token_idx += d_micro_batch_size;
sliced_batch_idx += d_micro_batch_size;
decode_batch_idx += d_micro_batch_size;
RTP_LLM_LOG_DEBUG("micro batch %ld sliced decode, batch idx %ld, token idx %ld",
i, sliced_batch_idx, sliced_token_idx);
} else {
GptModelInputs micro_model_inputs = inputs;
RTP_LLM_LOG_DEBUG("p slice from %ld %ld %ld", sliced_token_idx, sliced_batch_idx, prefill_batch_idx);
micro_model_inputs.input_lengths = inputs.input_lengths->slice(sliced_batch_idx, p_micro_batch_size);
micro_model_inputs.kv_cache_block_id = inputs.kv_cache_block_id->slice(sliced_batch_idx, p_micro_batch_size);
micro_model_inputs.lm_output_indexes = inputs.lm_output_indexes->slice(prefill_batch_idx, p_micro_batch_size);
micro_model_inputs.prefix_lengths = inputs.prefix_lengths->slice(prefill_batch_idx, p_micro_batch_size);
micro_model_inputs.attention_mask = inputs.attention_mask ? inputs.attention_mask->slice(sliced_batch_idx, p_micro_batch_size) : nullptr;
auto slice_token_num = std::accumulate(
micro_model_inputs.input_lengths->data<int32_t>(),
micro_model_inputs.input_lengths->data<int32_t>() + p_micro_batch_size,
0);
micro_model_inputs.combo_tokens = inputs.combo_tokens->slice(sliced_token_idx, slice_token_num);
micro_model_inputs.request_id = inputs.request_id ? inputs.request_id->slice(sliced_batch_idx, p_micro_batch_size) : nullptr;
micro_model_inputs.request_pd_separation = inputs.request_pd_separation ? inputs.request_pd_separation->slice(sliced_batch_idx, p_micro_batch_size) : nullptr;
micro_model_inputs.cache_keys = inputs.cache_keys ? inputs.cache_keys->slice(sliced_batch_idx, p_micro_batch_size) : nullptr;
auto micro_hidden = hidden->slice(sliced_token_idx, slice_token_num);
auto micro_pre_decoder_residual = pre_decoder_residual ? pre_decoder_residual->slice(sliced_token_idx, slice_token_num) : nullptr;
auto attention_common_inputs = prepareAttentionInputs(micro_model_inputs, attn_dtype, nullptr);
micro_batch_inputs.push_back({
move(micro_hidden), move(micro_pre_decoder_residual), move(attention_common_inputs)});
sliced_token_idx += slice_token_num;
sliced_batch_idx += p_micro_batch_size;
prefill_batch_idx += p_micro_batch_size;
RTP_LLM_LOG_DEBUG("micro batch %ld sliced context, batch idx %ld, token idx %ld",
i, sliced_batch_idx, sliced_token_idx);
}
}
}
return micro_batch_inputs;
}
rtp_llm::BufferPtr GptModel::embeddingPost(const BufferPtr& hidden_states, const GptModelInputs& inputs) {
return hidden_states;
}
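// Pre-layer stage: embedding lookup, TP sync, optional pre-decoder layernorm and
// multimodal embedding, plus sequence-parallel (SP) padding. With SP enabled the token
// count is padded up to a multiple of tp_size * max(1, m_split),
// e.g. tp_size = 4, m_split = 2, token_num = 13 -> pad_mod_num = 8, pad_token_num = 16.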
GptLayerInputs GptModel::forwardPreLayers(const GptModelInputs& inputs) {
DevicePerfWrapper wrapper(device_, "forwardPreLayers");
bool enable_sp = device_->getDeviceProperties().enable_sp;
size_t token_num = inputs.combo_tokens->shape()[0];
size_t pad_token_num = token_num;
size_t pad_mod_num = device_props_.tp_size * max((size_t)1, device_props_.m_split);
if (token_num <= pad_mod_num) {
enable_sp = false;
}
if (enable_sp && token_num % pad_mod_num != 0) {
pad_token_num = token_num + (pad_mod_num - token_num % pad_mod_num);
BufferPtr combo_tokens = inputs.combo_tokens;
BufferPtr pad_combo_tokens = device_->allocateBuffer({combo_tokens->type(), {pad_token_num}, AllocationType::HOST},{"pad_combo_tokens"});
device_->bufMemset(*pad_combo_tokens, 0);
device_->copy({pad_combo_tokens->view(0, token_num), *combo_tokens});
inputs.combo_tokens = pad_combo_tokens;
printBufferData(*combo_tokens, {"combo_tokens"});
printBufferData(*pad_combo_tokens, {"pad_combo_tokens"});
}
const auto combo_tokens = device_->clone(
{*inputs.combo_tokens, AllocationType::DEVICE, {"combo_tokens"}});
const auto& embedding_table = weights_.embedding->kernel;
const BufferPtr combo_position_ids = inputs.combo_position_ids ? device_->clone({*inputs.combo_position_ids}): nullptr;
const BufferPtr combo_tokens_type_ids = inputs.combo_tokens_type_ids ? device_->clone({*inputs.combo_tokens_type_ids}): nullptr;
const BufferPtr text_tokens_mask = inputs.multimodal_features ?
device_->clone({*inputs.text_tokens_mask, AllocationType::DEVICE, {"text_tokens_mask"}}) : nullptr;
const BufferPtr mm_feature_locs = inputs.mm_features_locs ? inputs.mm_features_locs: nullptr;
// word embedding lookup
auto hidden = device_->embeddingLookup({
*combo_tokens, *embedding_table, description_.input_embedding_scalar,
text_tokens_mask ? (OptionalConstBufferRef)*text_tokens_mask : nullopt,
combo_position_ids ? (OptionalConstBufferRef)*combo_position_ids: nullopt,
weights_.position_encoding ? (OptionalConstBufferRef)*weights_.position_encoding->kernel: nullopt,
combo_tokens_type_ids ? (OptionalConstBufferRef)*combo_tokens_type_ids: nullopt,
weights_.token_type_embedding ? (OptionalConstBufferRef)*weights_.token_type_embedding->kernel: nullopt});
if (residual_scale_fp32_ && residual_scale_->type() != hidden->type()) {
residual_scale_ = device_->convert({residual_scale_fp32_, hidden->type()});
}
if (device_props_.tp_size > 1) {
hidden = tpSyncEmbeddingOrLogits(hidden);
}
hidden = embeddingPost(hidden, inputs);
auto hidden_dtype = hidden->type();
auto attn_dtype = hidden_dtype;
if (description_.act_qscheme == QScheme::Qfp8PerTensor) {
attn_dtype = DataType::TYPE_QFP8_E4M3;
}
// pre layernorm
BufferPtr pre_decoder_residual = nullptr;
if (description_.act_qscheme != QScheme::NoQuantize && weights_.pre_decoder_layernorm) {
pre_decoder_residual = device_->allocateBufferLike(*hidden);
}
printBufferData(*hidden, "before decoder layernorm hidden");
if (weights_.pre_decoder_layernorm) {
auto decoder_input = device_->layernorm(LayernormParams(hidden,
pre_decoder_residual,
*weights_.pre_decoder_layernorm,
nullopt,
nullopt,
nullopt,
0.f,
description_.layernorm_eps,
true,
pre_decoder_residual != nullptr,
description_.norm_type,
description_.act_qscheme));
hidden = std::move(decoder_input.output);
}
if (inputs.multimodal_features) {
hidden = device_->multimodalEmbedding({
hidden,
(OptionalConstVecBufferPtrRef)inputs.multimodal_features,
mm_feature_locs ? (OptionalConstBufferRef)*mm_feature_locs: nullopt
});
}
printBufferData(*hidden, "input_hidden");
if (device_props_.overlap_comm_type == 2) {
const auto& layer0 = weights_.layers[0];
        RTP_LLM_CHECK_WITH_INFO(description_.act_qscheme == QScheme::NoQuantize || description_.act_qscheme == QScheme::Qint8PerToken || description_.act_qscheme == QScheme::Qfp8PerTensor,
                                "ring p2p overlap only supports bf16/fp16, w8a8 (int8 per token) or fp8 per tensor");
const size_t max_batch_seq_len = autil::EnvUtil::getEnv("MAX_CONTEXT_BATCH_SIZE", 1) * device_->initParams().max_seq_len;
const size_t attn_rs_hidden = layer0.self_attention_weights.output_weight->kernel->shape()[1];
const size_t ffn_rs_hidden = layer0.ffn_weights.down_weight->kernel->shape()[1];
const size_t attn_ag_hidden = layer0.self_attention_weights.qkv_weight->kernel->shape()[0];
const size_t ffn_ag_hidden = layer0.ffn_weights.gate_weight->kernel->shape()[0];
DataType rs_output_type = hidden->type();
DataType ag_input_type = attn_dtype;
bool enable_per_token_scale = description_.act_qscheme == QScheme::Qint8PerToken;
bool enable_ffn_tp = enable_sp && device_props_.ffn_tp_size > 1;
device_->prepareCommBuffer({max_batch_seq_len, attn_rs_hidden, ffn_rs_hidden, attn_ag_hidden, ffn_ag_hidden, rs_output_type, ag_input_type, enable_per_token_scale, enable_ffn_tp});
}
auto micro_batch_plan = planMicroBatches(inputs);
if (int(device_props_.enable_layer_micro_batch)) {
auto micro_batch_inputs = prepareMicroBatchInputs(
inputs, hidden, pre_decoder_residual, attn_dtype, micro_batch_plan);
return {move(hidden), move(pre_decoder_residual), AttentionCommonInputs(), hidden_dtype, micro_batch_inputs, enable_sp, token_num, pad_token_num};
} else {
// prepare resources for all layers
auto attention_common_inputs = prepareAttentionInputs(inputs, attn_dtype, combo_position_ids);
return {move(hidden), move(pre_decoder_residual), move(attention_common_inputs), hidden_dtype, {}, enable_sp, token_num, pad_token_num};
}
}
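// Prefill-side micro-batched layer loop. For every MoE layer, attention + gate + dispatch
// run for both micro batches first; expert FFN, shared experts and combine are then
// interleaved so the combine communication of one micro batch can overlap with the
// computation of the other (last_comm_hook_ carries the pending barrier between steps).
// Dense layers simply run per micro batch without this pipelining.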
vector<GptLayerInputs> GptModel::forwardPrefillMicroBatchedLayers(vector<GptLayerInputs> micro_batch_layer_inputs) {
std::vector<LastLayerDeferedParams> last_layer_defered_params(micro_batch_layer_inputs.size());
for (int32_t i = 0; i < layer_num_; ++i) {
const auto& layer = weights_.layers[i];
bool moe_layer = weights_.layers[i].ffn_weights.moe_gate_weight != nullptr;
// dense layer does not need micro batching.
if (!moe_layer) {
for (auto& layer_input : micro_batch_layer_inputs) {
auto layer_outputs = forwardGptLayer(layer_input, i, nullptr);
layer_input.hidden = move(layer_outputs.hidden);
}
continue;
}
std::vector<EpFfnInputs> ep_inputs;
for (size_t micro_batch_idx = 0; micro_batch_idx < micro_batch_layer_inputs.size(); ++micro_batch_idx) {
auto& layer_input = micro_batch_layer_inputs[micro_batch_idx];
auto batch_ep_input = forwardAttentionAndMoeGate(
layer_input, last_layer_defered_params[micro_batch_idx], i, micro_batch_idx);
ep_inputs.push_back(move(batch_ep_input));
}
std::vector<EpFfnOutputs> ep_outputs;
for (size_t micro_batch_idx = 0; micro_batch_idx < micro_batch_layer_inputs.size(); ++micro_batch_idx) {
DevicePerfWrapper wrapper(device_, "mb_moe_layer_" + std::to_string(i) + "_idx_" + std::to_string(micro_batch_idx));
// auto& layer_input = micro_batch_layer_inputs[micro_batch_idx];
auto& batch_ep_input = ep_inputs[micro_batch_idx];
const auto& ffn_params = batch_ep_input.moe_ffn_params;
const auto& dispatched_output = batch_ep_input.dispatch_output;
const auto& moe_conf = ffn_params.configs.moe_configs.value();
auto hidden_states = dispatched_output.hidden;
auto moe_ffn_params = FfnLayerParams(
{*hidden_states, ffn_params.configs, ffn_params.weights, ffn_params.residual, ffn_params.qscheme});
prepareExpertStats(i, moe_ffn_params);
hidden_states = device_->moeFfn(
moe_ffn_params,
{dispatched_output.expert_ids, dispatched_output.expert_scales, dispatched_output.deep_ep_ll_output}
).hidden_states;
            // run shared experts here so they overlap with the combine communication
if (micro_batch_idx) {
auto shared_expert_output = device_->moeSharedExpert(ep_inputs[micro_batch_idx].moe_ffn_params).hidden_states;
ep_inputs[micro_batch_idx].shared_expert_output = shared_expert_output;
}
if (last_comm_hook_) {
last_comm_hook_->hook_sync();
last_comm_hook_ = nullptr;
}
printBufferData(*hidden_states, "layer_" + to_string(i) + "_combine_input");
auto combine_out = device_->epCombine({
hidden_states,
dispatched_output.indices,
ffn_params.output,
dispatched_output.input_split_sizes,
dispatched_output.output_split_sizes,
moe_conf,
ffn_params.input.shape()[0],
device_props_.enable_comm_overlap,
dispatched_output.deep_ep_output,
dispatched_output.deep_ep_ll_output,
std::make_shared<MoeGateSelectOutput>(batch_ep_input.gate_output),
dispatched_output.expert_ids,
dispatched_output.expert_scales,
});
printBufferData(*combine_out.all_output, "layer_" + to_string(i) + "_combine_output");
auto hook = nullptr;
if (combine_out.comm_barrier_hook) {
last_comm_hook_ = move(combine_out.comm_barrier_hook);
} else {
RTP_LLM_LOG_DEBUG("no combine barrier for layer %ld, micro batch %ld", i, micro_batch_idx);
}
auto output = combine_out.all_output;
ep_outputs.push_back(EpFfnOutputs({output, move(combine_out), move(hook)}));
}
for (size_t micro_batch_idx = 0; micro_batch_idx < micro_batch_layer_inputs.size(); ++micro_batch_idx) {
// last layer: add residual and shared expert output
auto& layer_input = micro_batch_layer_inputs[micro_batch_idx];
auto& batch_ep_input = ep_inputs[micro_batch_idx];
auto& batch_ep_output = ep_outputs[micro_batch_idx];
if (i == layer_num_ - 1) {
if (last_comm_hook_) {
last_comm_hook_->hook_sync();
last_comm_hook_ = nullptr;
}
auto output = batch_ep_output.hidden;
output = device_->gatherCombineOutput(batch_ep_output.combine_output).hidden_states;
printBufferData(*output, "layer_" + to_string(i) + "_ffn_output");
auto ffn_layernorm_output = device_->layernorm({
output,
nullptr,
rtp_llm::mayGetRef(layer.post_ffn_layernorm),
rtp_llm::mayGetRef(batch_ep_input.residual),
rtp_llm::mayGetRef(batch_ep_input.shared_expert_output),
nullopt,
1.0f,
description_.layernorm_eps,
true,
description_.post_layernorm,
description_.norm_type,
QScheme::NoQuantize
});
layer_input.hidden = move(ffn_layernorm_output.output);
printBufferData(*layer_input.hidden, "layer_" + to_string(i) + "_final_hidden");
} else {
                // not last layer: defer the residual add and shared expert output to the next layer
last_layer_defered_params[micro_batch_idx].residual = batch_ep_input.residual;
last_layer_defered_params[micro_batch_idx].shared_expert_output = batch_ep_input.shared_expert_output;
last_layer_defered_params[micro_batch_idx].post_ffn_layernorm_weights = layer.post_ffn_layernorm;
if (last_layer_defered_params[micro_batch_idx].combine_output) {
last_layer_defered_params[micro_batch_idx].combine_output.value().params.expert_ids = nullptr;
last_layer_defered_params[micro_batch_idx].combine_output.value().params.expert_scales = nullptr;
last_layer_defered_params[micro_batch_idx].combine_output = nullopt;
}
last_layer_defered_params[micro_batch_idx].combine_output = move(batch_ep_output.combine_output);
last_layer_defered_params[micro_batch_idx].comm_barrier_hook = move(batch_ep_output.comm_barrier_hook);
layer_input.hidden = move(batch_ep_output.hidden);
}
}
}
return micro_batch_layer_inputs;
}
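// Decode-side micro-batched layer loop. Each MoE layer registers its dispatch/gate state
// via setMoEInsertion so the device can run the deferred expert FFN + combine while the
// other micro batch computes; results are collected with stealMoEInsertionRet and folded
// in as deferred residuals. The last layer's pending MoE work is flushed after the loop
// via computeInsertedMoE.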
vector<GptLayerInputs> GptModel::forwardDecodeMicroBatchedLayers(vector<GptLayerInputs> micro_batch_layer_inputs) {
std::vector<LastLayerDeferedParams> last_layer_defered_params_vec(micro_batch_layer_inputs.size());
for (int32_t i = 0; i < layer_num_; ++i) {
const auto& layer = weights_.layers[i];
bool moe_layer = layer.ffn_weights.moe_gate_weight != nullptr;
// dense layer does not need micro batching.
if (!moe_layer) {
for (auto& layer_input : micro_batch_layer_inputs) {
auto layer_outputs = forwardGptLayer(layer_input, i, nullptr);
layer_input.hidden = move(layer_outputs.hidden);
}
continue;
}
for (size_t micro_batch_idx = 0; micro_batch_idx < micro_batch_layer_inputs.size(); ++micro_batch_idx) {
auto& layer_input = micro_batch_layer_inputs[micro_batch_idx];
auto& last_layer_defered_params = last_layer_defered_params_vec[micro_batch_idx];
auto last_layer_moe_ret = device_->stealMoEInsertionRet();
RUNTIME_ASSERT_OP_ARG(
bool(last_layer_defered_params.shared_expert_output) == bool(last_layer_moe_ret),
"moe insertion return should only be null if no previous layer.");
if (last_layer_defered_params.combine_output) {
last_layer_defered_params.combine_output = nullopt;
}
last_layer_defered_params.combine_output = last_layer_moe_ret
? std::optional<rtp_llm::MoeCombineOutput>(last_layer_moe_ret->combine_output)
: nullopt;
auto ep_input = forwardAttentionAndMoeGate(layer_input, last_layer_defered_params, i, micro_batch_idx);
// set moe insertion params
device_->setMoEInsertion(MoEInsertionParams(
ep_input.dispatch_output,
ep_input.moe_ffn_params,
std::make_shared<MoeGateSelectOutput>(ep_input.gate_output),
ep_input.hidden->shape()[0]
));
last_layer_defered_params.residual = ep_input.residual;
last_layer_defered_params.post_ffn_layernorm_weights = layer.post_ffn_layernorm;
            // run the shared expert and defer adding its output
auto shared_expert_output = device_->moeSharedExpert(ep_input.moe_ffn_params).hidden_states;
last_layer_defered_params.shared_expert_output = shared_expert_output;
}
}
// deal with last layer
auto mb0_moe_insertion_ret = device_->stealMoEInsertionRet();
last_layer_defered_params_vec[0].combine_output = nullopt;
last_layer_defered_params_vec[0].combine_output = mb0_moe_insertion_ret->combine_output;
// last layer last micro batch
device_->computeInsertedMoE();
auto moe_insertion_ret = device_->stealMoEInsertionRet();
moe_insertion_ret->combine_output.comm_barrier_hook->hook_sync();
last_layer_defered_params_vec.back().combine_output = nullopt;
last_layer_defered_params_vec.back().combine_output = move(moe_insertion_ret->combine_output);
for (size_t micro_batch_idx = 0; micro_batch_idx < micro_batch_layer_inputs.size(); ++micro_batch_idx) {
auto& layer_input = micro_batch_layer_inputs[micro_batch_idx];
auto& last_layer_defered_params = last_layer_defered_params_vec[micro_batch_idx];
auto output = device_->gatherCombineOutput(last_layer_defered_params.combine_output.value()).hidden_states;
auto ffn_layernorm_output = device_->layernorm({
output,
nullptr,
rtp_llm::mayGetRef(last_layer_defered_params.post_ffn_layernorm_weights),
rtp_llm::mayGetRef(last_layer_defered_params.residual),
rtp_llm::mayGetRef(last_layer_defered_params.shared_expert_output),
nullopt,
1.0f,
description_.layernorm_eps,
true,
description_.post_layernorm,
description_.norm_type,
QScheme::NoQuantize
});
layer_input.hidden = move(ffn_layernorm_output.output);
printBufferData(*layer_input.hidden, "mb_" + to_string(micro_batch_idx) + "_final_hidden");
}
return micro_batch_layer_inputs;
}
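// Dispatches to the prefill or decode micro-batched implementation and copies the
// per-micro-batch hidden states back into the original contiguous hidden buffer.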
GptLayerOutputs GptModel::forwardMicroBatchedLayers(
const GptLayerInputs& layer_inputs, const GptModelInputs& inputs)
{
std::vector<GptLayerInputs> micro_batch_layer_inputs;
for (auto& micro_batch_input : layer_inputs.micro_batch_inputs) {
micro_batch_layer_inputs.push_back({
micro_batch_input.hidden,
micro_batch_input.pre_decoder_residual,
micro_batch_input.attention_common_inputs,
layer_inputs.dtype
});
}
if (device_props_.enable_layer_micro_batch == MicroBatchType::DS_PREFILL) {
micro_batch_layer_inputs = forwardPrefillMicroBatchedLayers(micro_batch_layer_inputs);
} else if (device_props_.enable_layer_micro_batch == MicroBatchType::DS_DECODE) {
micro_batch_layer_inputs = forwardDecodeMicroBatchedLayers(micro_batch_layer_inputs);
} else {
RUNTIME_ASSERT_OP_ARG(false, "micro batch type %d is not supported", int(device_props_.enable_layer_micro_batch));
}
const auto& hidden = layer_inputs.hidden;
size_t copy_from_token_idx = 0;
if (!layer_inputs.micro_batch_inputs[1].fake) {
for (size_t i = 0; i < micro_batch_layer_inputs.size(); ++i) {
const auto& micro_batch_hidden = micro_batch_layer_inputs[i].hidden;
const auto micro_batch_token_num = micro_batch_hidden->shape()[0];
const auto target_hidden = hidden->slice(copy_from_token_idx, micro_batch_token_num);
device_->copy({*target_hidden, *micro_batch_hidden});
copy_from_token_idx += micro_batch_token_num;
}
printBufferData(*hidden, "micor_batched_final_hidden");
} else {
device_->copy({*hidden, *(micro_batch_layer_inputs[0].hidden)});
printBufferData(*hidden, "non-micor_batched_final_hidden");
}
return {hidden, nullptr};
}
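// Single-layer forward for the non-micro-batched path:
// attention block -> FFN (dense or MoE) -> optional FFN-TP all-reduce -> residual scale
// -> post-FFN add-residual + layernorm.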
GptLayerOutputs GptModel::forwardGptLayer(
GptLayerInputs inputs,
const int32_t layer_id,
rtp_llm::lora::LoraModelInputPtr lora_model_input)
{
auto pre_decoder_residual = inputs.pre_decoder_residual;
auto attention_block_output = forwardAttentionBlock(inputs, layer_id, lora_model_input);
auto hidden = move(attention_block_output.hidden);
auto residual = move(attention_block_output.residual);
auto residual2 = move(attention_block_output.residual2);
const auto& layer = weights_.layers[layer_id];
printBufferData(*hidden, "layer_" + to_string(layer_id) + "_ffn_input");
bool enable_sp = inputs.enable_sp;
size_t rank_pad_token_num = enable_sp ? inputs.pad_token_num / device_props_.tp_size : hidden->shape()[0];
auto ffn_output_buf = device_->allocateBuffer({inputs.dtype, {rank_pad_token_num, hidden->shape()[1]}}, {"ffn_out_buf"});
if (!enable_sp) {
        // Note: for custom all reduce, prepareAllReduce will replace the original ffn_output_buf with
        // a new custom_ar_comm buffer. We must make sure that ffn_output_buf is not released or replaced by
        // another buffer before the actual allreduce operation. Otherwise, custom ar will raise an error.
ffn_output_buf = device_->prepareAllReduce({std::move(ffn_output_buf), ReduceOp::Sum}).buffer;
}
auto ffn_layer_params = FfnLayerParams({*hidden, description_.ffn_conf,
layer.ffn_weights,
device_props_.ffn_fuse_add_residual ? (OptionalConstBufferRef)*residual : nullopt,
description_.act_qscheme,
std::move(ffn_output_buf),
enable_sp,
layer_num_ == 1});
// expert stats
prepareExpertStats(layer_id, ffn_layer_params);
if (lora_model_input) {
ffn_layer_params.lora_input = lora_model_input->getFfnLayerLoraInput(layer_id);
}
auto ffn_output = device_->ffnLayer(ffn_layer_params);
hidden = ffn_output.hidden_states;
if (device_props_.ffn_tp_size > 1 && !layer.ffn_weights.moe_gating_weight && !enable_sp) {
// Note: for custom all reduce, allReduce will allocate a new buffer and replace the original attn_hidden with it
auto wrapper = DevicePerfWrapper(device_, "post_ffn_all_reduce, sizeBytes=%ld", (long)hidden->sizeBytes());
hidden = device_->allReduce({std::move(hidden), ReduceOp::Sum, false, ParallelMode::FFN_TP}).buffer;
}
if (residual_scale_) {
hidden = device_->multiply({*residual_scale_, *hidden});
}
printBufferData(*hidden, "layer_" + to_string(layer_id) + "_ffn_output");
// TODO: maybe move this layernorm to ffn layer
auto ffn_layernorm_output = device_->layernorm(LayernormParams(hidden,
pre_decoder_residual,
rtp_llm::mayGetRef(layer.post_ffn_layernorm),
device_props_.ffn_fuse_add_residual ? nullopt : (OptionalConstBufferRef)*residual,
(residual2 == nullptr) ? nullopt : (OptionalConstBufferRef)*residual2,
rtp_llm::mayGetRef(WEIGHT_MAY_GET_BIAS(layer.ffn_weights.down_weight)),
1.0f,
description_.layernorm_eps,
true,
description_.post_layernorm,
description_.norm_type,
((layer_id == layer_num_ - 1) || (!layer.post_ffn_layernorm)) ? QScheme::NoQuantize: description_.act_qscheme));
hidden = std::move(ffn_layernorm_output.output);
printBufferData(*hidden, "layer_" + to_string(layer_id) + "_final_hidden");
return {hidden, pre_decoder_residual};
}
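// Attention half of a layer: optional deferred combine output from the previous MoE
// layer, pre-layernorm (with SP slicing of the residual on layer 0), attention
// (MHA or MLA), TP all-reduce, residual scaling and post-attention layernorm.
// Returns the FFN input together with the residual(s) that the FFN-side layernorm
// adds back later.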
AttentionBlockOutputs GptModel::forwardAttentionBlock(
const GptLayerInputs& inputs,
const int32_t layer_id,
rtp_llm::lora::LoraModelInputPtr lora_model_input,
const LastLayerDeferedParams& last_layer_defered_params)
{
auto hidden = inputs.hidden;
auto pre_decoder_residual = inputs.pre_decoder_residual;
auto attention_common_inputs = move(inputs.attention_common_inputs);
DevicePerfWrapper wrapper(device_, "attention_block_layer_" + std::to_string(layer_id) + "_bs_" + std::to_string(hidden->shape()[0]));
if (last_layer_defered_params.combine_output) {
printBufferData(*(last_layer_defered_params.combine_output.value().all_output), "layer_" + to_string(layer_id - 1) + "_combine_output_defered");
hidden = device_->gatherCombineOutput(last_layer_defered_params.combine_output.value()).hidden_states;
}
attention_common_inputs.layer_id = layer_id;
const auto& layer = weights_.layers[layer_id];
bool enable_sp = inputs.enable_sp;
    // here hidden->type() may be int8, so use the dtype of the embedding lookup result instead
size_t rank_pad_token_num = enable_sp ? inputs.pad_token_num / device_props_.tp_size : hidden->shape()[0];
BufferPtr attn_out_buf = device_->allocateBuffer({inputs.dtype, {rank_pad_token_num, hidden->shape()[1]}}, {"attn_out_buf"});
if (!enable_sp) {
// Note: for custom all reduce, prepareAllReduce will replace the original attn_out_buf with
// a new custom_ar_comm buffer. Here we must make sure that attn_out_buf is not released or replaced by
// other buffer before the actual allreduce operations. Otherwise, it will raise an error in custom ar.
attn_out_buf = device_->prepareAllReduce({std::move(attn_out_buf), ReduceOp::Sum}).buffer;
}
auto residual = pre_decoder_residual ? pre_decoder_residual : hidden;
printBufferData(*residual, "in residual");
BufferPtr residual2 = nullptr;
BufferPtr hidden_to_slice = nullptr; // for sp and overlap comm type 2
if (layer.pre_layernorm) {
// TODO(wangyin.yx): fuse this clone branch into layernorm(rmsnorm)
residual = last_layer_defered_params.residual ? device_->allocateBufferLike(*hidden, AllocationType::DEVICE, {"residual"})
: device_->clone({*hidden, AllocationType::DEVICE, {"residual"}});
int m_split = device_props_.m_split;
size_t overlap_comm_type = device_props_.overlap_comm_type;
auto pre_layernorm_output = device_->layernorm(LayernormParams(hidden,
residual,
*layer.pre_layernorm,
rtp_llm::mayGetRef(last_layer_defered_params.residual),
rtp_llm::mayGetRef(last_layer_defered_params.shared_expert_output),
std::nullopt,
0.f,
description_.layernorm_eps,
(enable_sp && overlap_comm_type == 2) ? false : true,
false,
description_.norm_type,
description_.act_qscheme,
layer_id > 0 ? true: false,
false));
if (enable_sp && layer_id == 0) {
if (overlap_comm_type == 1 && m_split > 0) {
vector<int> selected_indices;
selected_indices.reserve(rank_pad_token_num);
size_t m = inputs.pad_token_num;
size_t m_chunk = m / m_split;
if (m > 128) {
m_chunk = (m / m_split + 127) & ~127;
}
size_t tp_rank = device_props_.tp_rank;
size_t round = m_chunk / device_props_.tp_size;
size_t offset = tp_rank * round;
for (size_t i = 0; i < rank_pad_token_num; i++) {
selected_indices.push_back( (i / round) * m_chunk + i % round + offset);
}
// printBufferData(*vector2Buffer(selected_indices), "selected_indices");
residual = device_->select({*residual, *device_->clone({*vector2Buffer(selected_indices)})});
} else {
hidden_to_slice = residual;
residual = residual->slice(rank_pad_token_num * device_props_.tp_rank, rank_pad_token_num);
}
}
hidden = std::move(pre_layernorm_output.output);
} else if (last_layer_defered_params.residual || last_layer_defered_params.shared_expert_output) {
        // NOTE(wangyin): this branch is not used for now and might be erroneous
residual = device_->clone({*hidden, AllocationType::DEVICE, {"residual"}});
auto prev_ffn_layernorm_output = device_->layernorm({
hidden,
nullptr,
std::nullopt, // post_ffn_layernorm_weights
rtp_llm::mayGetRef(last_layer_defered_params.residual),
rtp_llm::mayGetRef(last_layer_defered_params.shared_expert_output),
});
hidden = std::move(prev_ffn_layernorm_output.output);
}
if (k_cache_buffer_ && attention_common_inputs.kv_cache) {
attention_common_inputs.kv_cache->k_cache_buffer = k_cache_buffer_->index(layer_id);
attention_common_inputs.kv_cache->v_cache_buffer = v_cache_buffer_->index(layer_id);
if (k_scale_buffer_) {
attention_common_inputs.kv_cache->k_scale_buffer = k_scale_buffer_->index(layer_id);
attention_common_inputs.kv_cache->v_scale_buffer = v_scale_buffer_->index(layer_id);
}
}
if (lora_model_input) {
attention_common_inputs.lora_input = lora_model_input->getAttentionLayerLoraInput(layer_id);
}
AttentionLayerOutput attn_output;
auto attn_params = AttentionLayerParams({
layer_id,
*hidden,
move(attn_out_buf),
description_.attention_conf,
layer.self_attention_weights,
attention_common_inputs,
device_props_.attn_fuse_add_residual ? (OptionalConstBufferRef)*residual : nullopt,
{description_.layernorm_eps, description_.norm_type},
description_.act_qscheme,
enable_sp,
inputs.pad_token_num
});
if (description_.attention_conf.use_mla && device_->mla_ops_type != rtp_llm::MlaOpsType::MHA) {
attn_output = device_->mlaAttentionLayer(attn_params);
} else {
attn_output = device_->attentionLayer(attn_params);
}
auto attn_hidden = std::move(attn_output.hidden_states);
if (device_props_.tp_size > 1 && !enable_sp) {
// Note: for custom all reduce, allReduce will allocate a new buffer and replace the original attn_hidden with it
auto wrapper = DevicePerfWrapper(device_, "allReduce, sizeBytes=%ld", (long)attn_hidden->sizeBytes());
attn_hidden = device_->allReduce({std::move(attn_hidden), ReduceOp::Sum}).buffer;
}
if (residual_scale_) {
attn_hidden = device_->multiply({*residual_scale_, *attn_hidden});
}
printBufferData(*attn_hidden, "layer_" + to_string(layer_id) + "_attn_output");
if (layer.post_layernorm) {
// attn_hidden = attn_hidden + residual
// hidden = layernorm(attn_hidden)
printBufferData(*residual, "before post layernorm residual");
auto post_layernorm_params = LayernormParams(attn_hidden,
attn_hidden,
rtp_llm::mayGetRef(layer.post_layernorm),
device_props_.attn_fuse_add_residual ? nullopt : (OptionalConstBufferRef)*residual,
nullopt,
rtp_llm::mayGetRef(layer.self_attention_weights.output_weight->bias),
0.f,
description_.layernorm_eps,
false,
description_.post_layernorm,
description_.norm_type,
description_.act_qscheme,
false,
true);
auto post_layernorm_output = device_->layernorm(post_layernorm_params);
hidden = std::move(post_layernorm_output.output);
attn_hidden = std::move(post_layernorm_output.before_norm_output);
residual = attn_hidden;
printBufferData(*residual, "after post layernorm residual");
} else {
residual2 = attn_hidden;
}
printBufferData(*hidden, "layer_" + to_string(layer_id) + "_ffn_input");
return {hidden, residual, residual2};
}
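// Micro-batched MoE front half: attention block, MoE gate selection and expert-parallel
// dispatch. The expert FFN and combine are executed later by the callers above so that
// dispatch/combine communication can overlap with other work.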
EpFfnInputs GptModel::forwardAttentionAndMoeGate(
const GptLayerInputs& inputs,
LastLayerDeferedParams& last_layer_defered_params,
const int32_t layer_id,
const size_t micro_batch_idx)
{
auto hidden = inputs.hidden;
auto pre_decoder_residual = inputs.pre_decoder_residual;
const auto& layer = weights_.layers[layer_id];
DevicePerfWrapper wrapper(device_, "mb_forwardGptLayer_" + std::to_string(layer_id) + "_bs_" + std::to_string(hidden->shape()[0]));
auto attention_block_output = forwardAttentionBlock(inputs, layer_id, nullptr, last_layer_defered_params);
hidden = move(attention_block_output.hidden);
auto residual = move(attention_block_output.residual);
auto residual2 = move(attention_block_output.residual2);
printBufferData(*hidden, "layer_" + to_string(layer_id) + "_ffn_input");
auto ffn_output_buf = device_->allocateBuffer({inputs.dtype, hidden->shape()}, {"ffn_out_buf"});
auto ffn_layer_params = FfnLayerParams({*hidden, description_.ffn_conf,
layer.ffn_weights,
device_props_.ffn_fuse_add_residual ? (OptionalConstBufferRef)*residual : nullopt,
description_.act_qscheme,
std::move(ffn_output_buf)});
prepareExpertStats(layer_id, ffn_layer_params);
MoeGateSelectOutput gate_output = device_->moeGateSelect(ffn_layer_params);
RTP_LLM_LOG_DEBUG("call layer %ld micro batch ep dispatch batch size = %ld", layer_id, hidden->shape()[0]);
BufferPtr shared_expert_output = nullptr;
printBufferData(ffn_layer_params.input, "layer_" + to_string(layer_id) + "_ep_dispatch_input");
printBufferData(*gate_output.expert_ids, "layer_" + to_string(layer_id) + "_expert_ids");
if (gate_output.expert_scales) {
printBufferData(*gate_output.expert_scales, "layer_" + to_string(layer_id) + "_expert_scales");
}
if (device_props_.enable_layer_micro_batch == MicroBatchType::DS_PREFILL) {
// shared expert when overlapping combine
if (micro_batch_idx == 0) {
shared_expert_output = device_->moeSharedExpert(ffn_layer_params).hidden_states;
last_layer_defered_params.shared_expert_output = shared_expert_output;
}
if (last_comm_hook_) {
last_comm_hook_->hook_sync();
last_comm_hook_ = nullptr;
}
} else {
// call combine hook sync
const auto& previous_moe_ret = device_->getMoEInsertionRet();
if (previous_moe_ret && previous_moe_ret->combine_output.comm_barrier_hook) {
previous_moe_ret->combine_output.comm_barrier_hook->hook_sync();
}
}
MoeDispatchOutput dispatched_output = device_->epDispatch({
ffn_layer_params.input,
*gate_output.expert_ids,
*gate_output.expert_scales,
description_.ffn_conf.moe_configs.value(),
device_props_.enable_comm_overlap,
description_.act_qscheme,
ffn_layer_params.expert_stats
});
printBufferData(*dispatched_output.hidden, "layer_" + to_string(layer_id) + "_dispatch_output");
RTP_LLM_LOG_DEBUG("call layer %ld micro batch ep dispatch done.", layer_id, hidden->shape()[0]);
if (device_props_.enable_layer_micro_batch == MicroBatchType::DS_PREFILL) {
if (dispatched_output.comm_barrier_hook) {
last_comm_hook_ = move(dispatched_output.comm_barrier_hook);
} else {
RTP_LLM_LOG_DEBUG("no dispatch barrier for layer %ld, micro batch %ld", layer_id, inputs.micro_batch_inputs.size());
}
}
return {hidden, residual, shared_expert_output, move(ffn_layer_params), move(gate_output), move(dispatched_output)};
}
GptLayerOutputs GptModel::forwardMoeFfn(const GptLayerOutputs& inputs, const int32_t layer_id) {
return inputs;
}
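// Post-layer stage: optional SP all-gather, final layernorm, optional MTP hidden-state
// export and LM head. When overlap_comm_type == 1 the all-gather is chunked, with each
// chunk rounded up to a multiple of 128 tokens (for m > 128):
//   m_chunk = (m / m_split + 127) & ~127   // e.g. m = 1000, m_split = 4 -> m_chunk = 256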
GptModelOutputs GptModel::forwardPostLayers(
rtp_llm::BufferPtr input,
const bool has_context_request,
const bool need_all_logits,
const rtp_llm::BufferPtr lm_output_indexes,
bool enable_sp,
size_t token_num,
const GptModelInputs& inputs)
{
DevicePerfWrapper wrapper(device_, "forwardPostLayers");
BufferPtr all_gather_output = nullptr;
if (enable_sp && device_props_.tp_size > 1) {
all_gather_output = device_->allocateBuffer({input->type(), {input->shape()[0] * device_props_.tp_size, input->shape()[1]}}, {"all_gather_output"});
size_t m = all_gather_output->shape()[0];
int m_split = device_props_.m_split;
size_t overlap_comm_type = device_props_.overlap_comm_type;
if (overlap_comm_type == 1 && m_split > 0) {
size_t token_idx = 0;
size_t ag_token_idx = 0;
size_t m_chunk = m / m_split;
if (m > 128) {
m_chunk = (m / m_split + 127) & ~127;
}
while (token_idx < m) {
const auto micro_batch_tokens = std::min(m - token_idx, m_chunk);
const auto ag_micro_batch_tokens = micro_batch_tokens / device_props_.tp_size;
auto micro_batch_recv_buffer = all_gather_output->slice(token_idx, micro_batch_tokens);
auto micro_ag_send_buffer = input->slice(ag_token_idx, ag_micro_batch_tokens);
device_->allGather({{micro_batch_recv_buffer}, ParallelMode::TP, {micro_ag_send_buffer}, false});
token_idx += micro_batch_tokens;
ag_token_idx += ag_micro_batch_tokens;
}
} else {
device_->allGather({{all_gather_output}, ParallelMode::TP, {input}, false});
}
size_t pad_mod_num = device_props_.tp_size * max((size_t)1, device_props_.m_split);
if (token_num % pad_mod_num != 0) {
input = device_->clone({all_gather_output->view(0, token_num), AllocationType::DEVICE});
} else {
input = all_gather_output;
}
}
auto hidden = input;
if (weights_.final_layernorm) {
auto final_layernorm = device_->layernorm(LayernormParams(hidden,
nullptr,
rtp_llm::mayGetRef(weights_.final_layernorm),
nullopt,
nullopt,
nullopt,
0.f,
description_.layernorm_eps,
true,
false,
description_.norm_type));
hidden = std::move(final_layernorm.output);
}
printBufferData(*hidden, "final_hidden");
if (device_->getDeviceProperties().is_mtp) {
const auto decoder_batch_size = inputs.sequence_lengths->shape()[0];
const auto context_batch_size = inputs.input_lengths->shape()[0] - decoder_batch_size;
device_->writeHiddenStatesStore({inputs.pd_separation,
inputs.warmup,
context_batch_size,
decoder_batch_size,
inputs.request_pd_separation,
inputs.request_id,
hidden,
lm_output_indexes});
}
const auto& lm_head = weights_.lm_head;
if (lm_head) {
        // gather the hidden states of the last token of each request
printBufferData(*lm_output_indexes, "lm_output_indexes");
auto last_hidden = has_context_request && !need_all_logits
? device_->select({*hidden, *device_->clone({*lm_output_indexes})})
: hidden;
printBufferData(*last_hidden, "last_hidden");
auto logits = device_->gemm(GemmParams(
*last_hidden, *(lm_head->kernel), nullopt, nullptr,
rtp_llm::DataType::TYPE_FP32, TransposeOperation::NONE, TransposeOperation::TRANSPOSE));
printBufferData(*logits, "logits");
if (device_props_.tp_size > 1) {
logits = tpSyncEmbeddingOrLogits(logits);
}
// TODO(xinfei.sxf) calculate softmax_result
rtp_llm::BufferPtr softmax_result;
        // logits is too large, so it is not printed by default
// printBufferData(*logits, "logits");
if (need_all_logits) {
auto last_logits = device_->select({*logits, *device_->clone({*lm_output_indexes})});
return {std::move(last_logits), std::move(last_hidden), std::move(hidden), std::move(logits), std::move(softmax_result)};
}
return {std::move(logits), std::move(last_hidden), std::move(hidden), nullptr, std::move(softmax_result)};
} else {
return {nullptr, nullptr, std::move(hidden)};
}
}
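// Top-level forward: pre-layers (embedding + attention prep), then either the
// micro-batched layer path or the plain per-layer loop, then post-layers
// (final layernorm + LM head).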
GptModelOutputs GptModel::forward(const GptModelInputs& inputs) {
cleanExpertStats();
auto layer_inputs = forwardPreLayers(inputs);
GptLayerOutputs layer_outputs;
if (int(device_props_.enable_layer_micro_batch)) {
RUNTIME_ASSERT_OP_ARG(layer_inputs.micro_batch_inputs.size(), "no micro batch inputs when enabled");
layer_outputs = forwardMicroBatchedLayers(layer_inputs, inputs);
} else {
for (int32_t i = 0; i < layer_num_; ++i) {
layer_outputs = forwardGptLayer(layer_inputs, i, inputs.lora_model_input);
layer_inputs.hidden = layer_outputs.hidden;
layer_inputs.pre_decoder_residual = layer_outputs.pre_decoder_residual;
}
}
auto outputs = forwardPostLayers(
layer_outputs.hidden,
inputs.input_lengths->shape()[0] != inputs.sequence_lengths->shape()[0],
inputs.need_all_logits,
inputs.lm_output_indexes,
layer_inputs.enable_sp,
layer_inputs.token_num,
inputs);
    // make sure CPU buffers outlive GPU execution
outputs.captured_values = make_shared<GptLayerInputs>(layer_inputs);
return outputs;
}
void GptModel::prepareExpertStats(const size_t layer_id,
rtp_llm::FfnLayerParams& ffn_layer_params) {
OptionalExpertStats layer_expert_stats = nullopt;
if (overall_expert_stats_.log_exp_num != 0) {
layer_expert_stats = ExpertStats({layer_id,
overall_expert_stats_.ep_size,
overall_expert_stats_.log_exp_num,
overall_expert_stats_.phy_exp_num,
overall_expert_stats_.stats_buf});
}
ffn_layer_params.expert_stats = layer_expert_stats;
}
void GptModel::cleanExpertStats() {
if (overall_expert_stats_.log_exp_num != 0) {
device_->cleanMoeExpertStates(overall_expert_stats_);
}
}
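// Broadcasts model inputs from TP rank 0 to the other ranks in two phases:
//   1) a fixed-size int32 "shape hints" buffer (plus per-feature shapes for multimodal
//      inputs) so non-zero ranks can allocate matching buffers;
//   2) the actual input buffers themselves.
// No-op when tp_size <= 1.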
void dpAndTpSyncModelInputs(GptModelInputs &inputs, rtp_llm::DeviceBase* device) {
if (device->getDeviceProperties().tp_size <= 1) {
return;
}
const size_t shape_hints_size = GptModelInputIndex::gptModelInputLength;
auto shape_hints = device->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {shape_hints_size}, rtp_llm::AllocationType::HOST});
auto shape_hints_ptr = shape_hints->data<int32_t>();
shape_hints_ptr[GptModelInputIndex::comboTokens] = inputs.combo_tokens.get() ? inputs.combo_tokens->size() : 0;
shape_hints_ptr[GptModelInputIndex::inputLengths] = inputs.input_lengths.get() ? inputs.input_lengths->size() : 0;
shape_hints_ptr[GptModelInputIndex::sequenceLengths] = inputs.sequence_lengths.get() ? inputs.sequence_lengths->size() : 0;
shape_hints_ptr[GptModelInputIndex::prefixLengths] = inputs.prefix_lengths.get() ? inputs.prefix_lengths->size() : 0;
shape_hints_ptr[GptModelInputIndex::maxBlocksPerBatch] = inputs.kv_cache_block_id.get() ? inputs.kv_cache_block_id->shape()[1] : 0;
shape_hints_ptr[GptModelInputIndex::lmOutputIndexes] = inputs.lm_output_indexes.get() ? inputs.lm_output_indexes->size() : 0;
shape_hints_ptr[GptModelInputIndex::comboPositionIds] = inputs.combo_position_ids.get() ? inputs.combo_position_ids->size() : 0;
shape_hints_ptr[GptModelInputIndex::loraIds] = inputs.lora_ids.get() ? inputs.lora_ids->size() : 0;
shape_hints_ptr[GptModelInputIndex::loraInputLengths] = inputs.lora_input_lengths.get() ? inputs.lora_input_lengths->size() : 0;
shape_hints_ptr[GptModelInputIndex::textTokensMask] = inputs.text_tokens_mask.get() ? inputs.text_tokens_mask->size() : 0;
shape_hints_ptr[GptModelInputIndex::mmFeaturesLocs] = inputs.mm_features_locs.get() ? inputs.mm_features_locs->size() : 0;
shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] = inputs.multimodal_features.has_value() ? inputs.multimodal_features.value().size() : 0;
shape_hints_ptr[GptModelInputIndex::mmFeaturesSize] = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] ? inputs.multimodal_features.value()[0]->shape()[1] : 0;
shape_hints_ptr[GptModelInputIndex::mmFeaturesDtype] = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum] ? (std::uint8_t)inputs.multimodal_features.value()[0]->type() : 0;
shape_hints_ptr[GptModelInputIndex::needAllLogits] = inputs.need_all_logits;
shape_hints_ptr[GptModelInputIndex::mtpHiddenStates] = inputs.last_hidden_states.get() ? inputs.last_hidden_states->size() : 0;
shape_hints_ptr[GptModelInputIndex::mtpHiddenStatesDtype] = shape_hints_ptr[GptModelInputIndex::mtpHiddenStates] ? (std::uint8_t)inputs.last_hidden_states->type() : 0;
device->broadcast({{shape_hints}, 0});
device->syncCommunication(false);
device->syncAndCheck();
// multimodal features shape broadcast
rtp_llm::BufferPtr mm_features_shape;
int32_t* mm_features_shape_ptr = nullptr;
inputs.need_all_logits = shape_hints_ptr[GptModelInputIndex::needAllLogits];
const size_t mm_features_num = shape_hints_ptr[GptModelInputIndex::mmFeaturesNum];
if (mm_features_num) {
mm_features_shape =
device->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::mmFeaturesNum]}, rtp_llm::AllocationType::HOST});
mm_features_shape_ptr = mm_features_shape->data<int32_t>();
for (auto i = 0; i < mm_features_num; ++i) {
mm_features_shape_ptr[i] = inputs.multimodal_features.has_value() ? inputs.multimodal_features.value()[i]->shape()[0] : 0;
}
device->broadcast({{mm_features_shape}, 0});
device->syncCommunication(false);
device->syncAndCheck();
}
auto max_blocks = (size_t)shape_hints_ptr[GptModelInputIndex::maxBlocksPerBatch];
auto combo_position_ids_size = shape_hints_ptr[GptModelInputIndex::comboPositionIds];
auto text_tokens_mask_size = shape_hints_ptr[GptModelInputIndex::textTokensMask];
auto mm_features_locs_size = shape_hints_ptr[GptModelInputIndex::mmFeaturesLocs];
auto hidden_states_size = shape_hints_ptr[GptModelInputIndex::mtpHiddenStates];
if (device->getDeviceProperties().tp_rank) {
auto context_batch_size = (size_t)shape_hints_ptr[GptModelInputIndex::prefixLengths];
inputs.combo_tokens = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::comboTokens]}, rtp_llm::AllocationType::HOST});
inputs.input_lengths = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::inputLengths]}, rtp_llm::AllocationType::HOST});
inputs.sequence_lengths = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::sequenceLengths]}, rtp_llm::AllocationType::HOST});
inputs.prefix_lengths = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {context_batch_size}, rtp_llm::AllocationType::HOST});
if (max_blocks != 0) {
inputs.kv_cache_block_id = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32,
{(size_t)shape_hints_ptr[GptModelInputIndex::inputLengths], max_blocks}, rtp_llm::AllocationType::HOST});
inputs.cache_keys = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT64, {context_batch_size, max_blocks}, rtp_llm::AllocationType::HOST});
}
inputs.request_id = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT64, {context_batch_size}, rtp_llm::AllocationType::HOST});
inputs.request_pd_separation = device->allocateBuffer(
{rtp_llm::DataType::TYPE_BOOL, {context_batch_size}, rtp_llm::AllocationType::HOST});
inputs.lm_output_indexes = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::lmOutputIndexes]}, rtp_llm::AllocationType::HOST});
if (combo_position_ids_size) {
inputs.combo_position_ids = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)combo_position_ids_size}, rtp_llm::AllocationType::HOST});
}
if (shape_hints_ptr[GptModelInputIndex::loraIds]) {
inputs.lora_ids = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::loraIds]}, rtp_llm::AllocationType::HOST});
}
if (shape_hints_ptr[GptModelInputIndex::loraInputLengths]) {
inputs.lora_input_lengths = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)shape_hints_ptr[GptModelInputIndex::loraInputLengths]}, rtp_llm::AllocationType::HOST});
}
if (shape_hints_ptr[GptModelInputIndex::mtpHiddenStates]) {
auto hidden_states_dim0 = (size_t)shape_hints_ptr[GptModelInputIndex::comboTokens];
auto hidden_states_dim1 = (size_t)hidden_states_size / hidden_states_dim0;
RTP_LLM_CHECK(hidden_states_size % hidden_states_dim0 == 0);
inputs.last_hidden_states = device->allocateBuffer(
{(rtp_llm::DataType)shape_hints_ptr[GptModelInputIndex::mtpHiddenStatesDtype],
{hidden_states_dim0, hidden_states_dim1},
rtp_llm::AllocationType::DEVICE});
}
if (text_tokens_mask_size) {
inputs.text_tokens_mask = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)text_tokens_mask_size}, rtp_llm::AllocationType::HOST});
}
if (mm_features_locs_size) {
inputs.mm_features_locs = device->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)mm_features_locs_size}, rtp_llm::AllocationType::HOST});
}
if (mm_features_num) {
std::vector<rtp_llm::BufferPtr> mm_features;
for (auto mm_index = 0; mm_index < mm_features_num; ++mm_index) {
mm_features.emplace_back(
device->allocateBuffer(
{(rtp_llm::DataType)shape_hints_ptr[GptModelInputIndex::mmFeaturesDtype],
{(size_t)mm_features_shape_ptr[mm_index], (size_t)shape_hints_ptr[GptModelInputIndex::mmFeaturesSize]},
rtp_llm::AllocationType::DEVICE}));
}
inputs.multimodal_features = std::move(mm_features);
}
}
std::vector<rtp_llm::BufferPtr> buffers;
buffers.emplace_back(inputs.combo_tokens);
buffers.emplace_back(inputs.input_lengths);
buffers.emplace_back(inputs.sequence_lengths);
buffers.emplace_back(inputs.prefix_lengths);
if (max_blocks) {
buffers.emplace_back(inputs.kv_cache_block_id);
buffers.emplace_back(inputs.cache_keys);
}
buffers.emplace_back(inputs.request_id);
buffers.emplace_back(inputs.request_pd_separation);
buffers.emplace_back(inputs.lm_output_indexes);
if (combo_position_ids_size) {
buffers.emplace_back(inputs.combo_position_ids);
}
buffers.emplace_back(inputs.lora_ids);
buffers.emplace_back(inputs.lora_input_lengths);
if (text_tokens_mask_size) {
buffers.emplace_back(inputs.text_tokens_mask);
}
if (mm_features_locs_size) {
buffers.emplace_back(inputs.mm_features_locs);
}
if (mm_features_num) {
for (auto& mm_feature: inputs.multimodal_features.value()) {
buffers.emplace_back(mm_feature);
}
}
if (hidden_states_size) {
buffers.emplace_back(inputs.last_hidden_states);
}
device->broadcast({buffers, 0});
device->syncAndCheck();
}
} // namespace rtp_llm