in maga_transformer/cpp/models/GptModel.cc [403:516]
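// forwardPreLayers: everything that runs once before the transformer layer loop --
// optional token padding for sequence parallelism, the (possibly multimodal-aware)
// embedding lookup, the pre-decoder layernorm, communication-buffer setup for
// comm/compute overlap, and preparation of attention or micro-batch inputs.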
GptLayerInputs GptModel::forwardPreLayers(GptModelInputs& inputs) {
    DevicePerfWrapper wrapper(device_, "forwardPreLayers");
    bool enable_sp = device_->getDeviceProperties().enable_sp;
    size_t token_num = inputs.combo_tokens->shape()[0];
    size_t pad_token_num = token_num;
    size_t pad_mod_num = device_props_.tp_size * max((size_t)1, device_props_.m_split);
    if (token_num <= pad_mod_num) {
        enable_sp = false;
    }
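    // Sequence parallelism partitions the token dimension across tp_size * m_split
    // chunks, so the token count must be a multiple of pad_mod_num. Pad the token
    // ids with zeros on the host and copy the real tokens into the front.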
    if (enable_sp && token_num % pad_mod_num != 0) {
        pad_token_num = token_num + (pad_mod_num - token_num % pad_mod_num);
        BufferPtr combo_tokens = inputs.combo_tokens;
        BufferPtr pad_combo_tokens = device_->allocateBuffer(
            {combo_tokens->type(), {pad_token_num}, AllocationType::HOST}, {"pad_combo_tokens"});
        device_->bufMemset(*pad_combo_tokens, 0);
        device_->copy({pad_combo_tokens->view(0, token_num), *combo_tokens});
        inputs.combo_tokens = pad_combo_tokens;
        printBufferData(*combo_tokens, "combo_tokens");
        printBufferData(*pad_combo_tokens, "pad_combo_tokens");
    }
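    // Copy the (possibly padded) token ids and any optional position / token-type
    // ids to the device; the text-token mask is only cloned when multimodal
    // features will be merged in below.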
    const auto combo_tokens = device_->clone(
        {*inputs.combo_tokens, AllocationType::DEVICE, {"combo_tokens"}});
    const auto& embedding_table = weights_.embedding->kernel;
    const BufferPtr combo_position_ids = inputs.combo_position_ids ? device_->clone({*inputs.combo_position_ids}) : nullptr;
    const BufferPtr combo_tokens_type_ids = inputs.combo_tokens_type_ids ? device_->clone({*inputs.combo_tokens_type_ids}) : nullptr;
    const BufferPtr text_tokens_mask = inputs.multimodal_features ?
        device_->clone({*inputs.text_tokens_mask, AllocationType::DEVICE, {"text_tokens_mask"}}) : nullptr;
    const BufferPtr mm_feature_locs = inputs.mm_features_locs ? inputs.mm_features_locs : nullptr;
    // word embedding lookup
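    // (fused with the optional positional encoding, token-type embedding and the
    //  text-token mask that is used when multimodal features are present)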
    auto hidden = device_->embeddingLookup({
        *combo_tokens, *embedding_table, description_.input_embedding_scalar,
        text_tokens_mask ? (OptionalConstBufferRef)*text_tokens_mask : nullopt,
        combo_position_ids ? (OptionalConstBufferRef)*combo_position_ids : nullopt,
        weights_.position_encoding ? (OptionalConstBufferRef)*weights_.position_encoding->kernel : nullopt,
        combo_tokens_type_ids ? (OptionalConstBufferRef)*combo_tokens_type_ids : nullopt,
        weights_.token_type_embedding ? (OptionalConstBufferRef)*weights_.token_type_embedding->kernel : nullopt});
    if (residual_scale_fp32_ && residual_scale_->type() != hidden->type()) {
        residual_scale_ = device_->convert({residual_scale_fp32_, hidden->type()});
    }
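    // Under tensor parallelism, synchronize the embedding output across tp ranks.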
    if (device_props_.tp_size > 1) {
        hidden = tpSyncEmbeddingOrLogits(hidden);
    }
    hidden = embeddingPost(hidden, inputs);
    auto hidden_dtype = hidden->type();
    auto attn_dtype = hidden_dtype;
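    // With fp8-per-tensor activation quantization, attention consumes fp8 (E4M3)
    // inputs, so the attention input dtype can differ from the hidden dtype.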
    if (description_.act_qscheme == QScheme::Qfp8PerTensor) {
        attn_dtype = DataType::TYPE_QFP8_E4M3;
    }
    // pre layernorm
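    // When activations are quantized, keep a separate unquantized residual buffer
    // (pre_decoder_residual) alongside the possibly quantized layernorm output.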
    BufferPtr pre_decoder_residual = nullptr;
    if (description_.act_qscheme != QScheme::NoQuantize && weights_.pre_decoder_layernorm) {
        pre_decoder_residual = device_->allocateBufferLike(*hidden);
    }
    printBufferData(*hidden, "before decoder layernorm hidden");
    if (weights_.pre_decoder_layernorm) {
        auto decoder_input = device_->layernorm(LayernormParams(hidden,
                                                                pre_decoder_residual,
                                                                *weights_.pre_decoder_layernorm,
                                                                nullopt,
                                                                nullopt,
                                                                nullopt,
                                                                0.f,
                                                                description_.layernorm_eps,
                                                                true,
                                                                pre_decoder_residual != nullptr,
                                                                description_.norm_type,
                                                                description_.act_qscheme));
        hidden = std::move(decoder_input.output);
    }
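    // Merge the precomputed multimodal feature embeddings into the hidden states
    // at the positions given by mm_feature_locs.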
    if (inputs.multimodal_features) {
        hidden = device_->multimodalEmbedding({
            hidden,
            (OptionalConstVecBufferPtrRef)inputs.multimodal_features,
            mm_feature_locs ? (OptionalConstBufferRef)*mm_feature_locs : nullopt
        });
    }
    printBufferData(*hidden, "input_hidden");
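    // overlap_comm_type == 2: ring p2p overlap of communication with compute.
    // Pre-allocate the reduce-scatter (rs) / all-gather (ag) buffers, sized from
    // layer 0's attention and FFN weight shapes and the maximum context batch size.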
    if (device_props_.overlap_comm_type == 2) {
        const auto& layer0 = weights_.layers[0];
        RTP_LLM_CHECK_WITH_INFO(
            description_.act_qscheme == QScheme::NoQuantize || description_.act_qscheme == QScheme::Qint8PerToken
                || description_.act_qscheme == QScheme::Qfp8PerTensor,
            "ring p2p overlap only supports bf16/fp16, w8a8 per-token, or fp8 per-tensor");
        const size_t max_batch_seq_len = autil::EnvUtil::getEnv("MAX_CONTEXT_BATCH_SIZE", 1) * device_->initParams().max_seq_len;
        const size_t attn_rs_hidden = layer0.self_attention_weights.output_weight->kernel->shape()[1];
        const size_t ffn_rs_hidden = layer0.ffn_weights.down_weight->kernel->shape()[1];
        const size_t attn_ag_hidden = layer0.self_attention_weights.qkv_weight->kernel->shape()[0];
        const size_t ffn_ag_hidden = layer0.ffn_weights.gate_weight->kernel->shape()[0];
        DataType rs_output_type = hidden->type();
        DataType ag_input_type = attn_dtype;
        bool enable_per_token_scale = description_.act_qscheme == QScheme::Qint8PerToken;
        bool enable_ffn_tp = enable_sp && device_props_.ffn_tp_size > 1;
        device_->prepareCommBuffer({max_batch_seq_len, attn_rs_hidden, ffn_rs_hidden, attn_ag_hidden, ffn_ag_hidden,
                                    rs_output_type, ag_input_type, enable_per_token_scale, enable_ffn_tp});
    }
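    // Decide how the batch is split into micro batches. With layer-level
    // micro-batching, each micro batch carries its own hidden states and attention
    // inputs; otherwise attention inputs are prepared once for the whole batch.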
    auto micro_batch_plan = planMicroBatches(inputs);
    if (device_props_.enable_layer_micro_batch) {
        auto micro_batch_inputs = prepareMicroBatchInputs(
            inputs, hidden, pre_decoder_residual, attn_dtype, micro_batch_plan);
        return {std::move(hidden), std::move(pre_decoder_residual), AttentionCommonInputs(), hidden_dtype,
                micro_batch_inputs, enable_sp, token_num, pad_token_num};
    } else {
        // prepare resources for all layers
        auto attention_common_inputs = prepareAttentionInputs(inputs, attn_dtype, combo_position_ids);
        return {std::move(hidden), std::move(pre_decoder_residual), std::move(attention_common_inputs), hidden_dtype,
                {}, enable_sp, token_num, pad_token_num};
    }
}