GptLayerInputs GptModel::forwardPreLayers()

in maga_transformer/cpp/models/GptModel.cc [403:516]


GptLayerInputs GptModel::forwardPreLayers(const GptModelInputs& inputs) {
    DevicePerfWrapper wrapper(device_, "forwardPreLayers");
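    // Sequence parallelism requires the token dimension to be divisible by tp_size * m_split;
    // fall back to the non-SP path when the batch has no more tokens than that modulus.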
    bool enable_sp = device_->getDeviceProperties().enable_sp;
    size_t token_num = inputs.combo_tokens->shape()[0];
    size_t pad_token_num = token_num;
    size_t pad_mod_num = device_props_.tp_size * max((size_t)1, device_props_.m_split);
    if (token_num <= pad_mod_num) {
        enable_sp = false;
    }
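    // Zero-pad combo_tokens up to the next multiple of pad_mod_num and substitute the padded
    // buffer into inputs so downstream consumers see the padded token count.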
    if (enable_sp && token_num % pad_mod_num != 0) {
        pad_token_num = token_num + (pad_mod_num - token_num % pad_mod_num);
        BufferPtr combo_tokens = inputs.combo_tokens;
        BufferPtr pad_combo_tokens = device_->allocateBuffer({combo_tokens->type(), {pad_token_num}, AllocationType::HOST}, {"pad_combo_tokens"});
        device_->bufMemset(*pad_combo_tokens, 0);
        device_->copy({pad_combo_tokens->view(0, token_num), *combo_tokens});
        inputs.combo_tokens = pad_combo_tokens;
        printBufferData(*combo_tokens, "combo_tokens");
        printBufferData(*pad_combo_tokens, "pad_combo_tokens");
    }
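    // Copy the (possibly padded) token ids to device memory for the embedding lookup.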
    const auto combo_tokens = device_->clone(
        {*inputs.combo_tokens, AllocationType::DEVICE, {"combo_tokens"}});

    const auto& embedding_table = weights_.embedding->kernel;

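    // Optional inputs: position ids and token type ids are cloned to device when provided;
    // the text token mask and feature locations are only present for multimodal requests.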
    const BufferPtr combo_position_ids = inputs.combo_position_ids ? device_->clone({*inputs.combo_position_ids}): nullptr;
    const BufferPtr combo_tokens_type_ids = inputs.combo_tokens_type_ids ? device_->clone({*inputs.combo_tokens_type_ids}): nullptr;

    const BufferPtr text_tokens_mask = inputs.multimodal_features ?
        device_->clone({*inputs.text_tokens_mask, AllocationType::DEVICE, {"text_tokens_mask"}}) : nullptr;
    const BufferPtr mm_feature_locs = inputs.mm_features_locs;

    // word embedding lookup
    auto hidden = device_->embeddingLookup({
            *combo_tokens, *embedding_table, description_.input_embedding_scalar,
            text_tokens_mask ? (OptionalConstBufferRef)*text_tokens_mask : nullopt,
            combo_position_ids ? (OptionalConstBufferRef)*combo_position_ids: nullopt,
            weights_.position_encoding ? (OptionalConstBufferRef)*weights_.position_encoding->kernel: nullopt,
            combo_tokens_type_ids ? (OptionalConstBufferRef)*combo_tokens_type_ids: nullopt,
            weights_.token_type_embedding ? (OptionalConstBufferRef)*weights_.token_type_embedding->kernel: nullopt});
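    // Keep residual_scale_ in the activation dtype, reconverting it from the fp32 master copy
    // whenever the hidden dtype differs.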
    if (residual_scale_fp32_ && residual_scale_->type() != hidden->type()) {
        residual_scale_ = device_->convert({residual_scale_fp32_, hidden->type()});
    }

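    // With tensor parallelism, synchronize the embedding output across ranks.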
    if (device_props_.tp_size > 1) {
        hidden = tpSyncEmbeddingOrLogits(hidden);
    }

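    // Model-specific post-processing of the embedding output before the decoder layers.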
    hidden = embeddingPost(hidden, inputs);

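    // Attention inputs use the hidden dtype unless fp8 per-tensor quantization is enabled,
    // in which case activations are quantized to fp8 (E4M3).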
    auto hidden_dtype = hidden->type();
    auto attn_dtype = hidden_dtype;
    if (description_.act_qscheme == QScheme::Qfp8PerTensor) {
        attn_dtype = DataType::TYPE_QFP8_E4M3;
    }

    // pre layernorm
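    // With quantized activations, a separate buffer keeps the unquantized residual written by
    // the pre-decoder layernorm.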
    BufferPtr pre_decoder_residual = nullptr;
    if (description_.act_qscheme != QScheme::NoQuantize && weights_.pre_decoder_layernorm) {
        pre_decoder_residual = device_->allocateBufferLike(*hidden);
    }
    printBufferData(*hidden, "before decoder layernorm hidden");
    if (weights_.pre_decoder_layernorm) {
        auto decoder_input = device_->layernorm(LayernormParams(hidden,
                                                                pre_decoder_residual,
                                                                *weights_.pre_decoder_layernorm,
                                                                nullopt,
                                                                nullopt,
                                                                nullopt,
                                                                0.f,
                                                                description_.layernorm_eps,
                                                                true,
                                                                pre_decoder_residual != nullptr,
                                                                description_.norm_type,
                                                                description_.act_qscheme));
        hidden = std::move(decoder_input.output);
    }

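    // Scatter multimodal feature embeddings (e.g. vision features) into the hidden states at
    // their token locations.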
    if (inputs.multimodal_features) {
        hidden = device_->multimodalEmbedding({
                hidden,
                (OptionalConstVecBufferPtrRef)inputs.multimodal_features,
                mm_feature_locs ? (OptionalConstBufferRef)*mm_feature_locs: nullopt
        });
    }

    printBufferData(*hidden, "input_hidden");

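    // overlap_comm_type == 2 (ring p2p overlap): pre-allocate reduce-scatter / all-gather
    // communication buffers sized from layer 0's attention and FFN weight shapes.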
    if (device_props_.overlap_comm_type == 2) {
        const auto& layer0 = weights_.layers[0];
        RTP_LLM_CHECK_WITH_INFO(description_.act_qscheme == QScheme::NoQuantize || description_.act_qscheme == QScheme::Qint8PerToken || description_.act_qscheme == QScheme::Qfp8PerTensor,
                "ring p2p overlap only supports bf16/fp16, w8a8, or fp8 per tensor");
        const size_t max_batch_seq_len = autil::EnvUtil::getEnv("MAX_CONTEXT_BATCH_SIZE", 1) * device_->initParams().max_seq_len;
        const size_t attn_rs_hidden = layer0.self_attention_weights.output_weight->kernel->shape()[1];
        const size_t ffn_rs_hidden = layer0.ffn_weights.down_weight->kernel->shape()[1];
        const size_t attn_ag_hidden = layer0.self_attention_weights.qkv_weight->kernel->shape()[0];
        const size_t ffn_ag_hidden = layer0.ffn_weights.gate_weight->kernel->shape()[0];
        DataType rs_output_type = hidden->type();
        DataType ag_input_type = attn_dtype;
        bool enable_per_token_scale = description_.act_qscheme == QScheme::Qint8PerToken;
        bool enable_ffn_tp = enable_sp && device_props_.ffn_tp_size > 1;
        device_->prepareCommBuffer({max_batch_seq_len, attn_rs_hidden, ffn_rs_hidden, attn_ag_hidden, ffn_ag_hidden, rs_output_type, ag_input_type, enable_per_token_scale, enable_ffn_tp});
    }

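    // Plan micro batches; with layer-level micro batching the per-micro-batch inputs are built
    // here, otherwise attention inputs are prepared once for all layers.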
    auto micro_batch_plan = planMicroBatches(inputs);
    if (int(device_props_.enable_layer_micro_batch)) {
        auto micro_batch_inputs = prepareMicroBatchInputs(
            inputs, hidden, pre_decoder_residual, attn_dtype, micro_batch_plan);
        return {move(hidden), move(pre_decoder_residual), AttentionCommonInputs(), hidden_dtype, micro_batch_inputs, enable_sp, token_num, pad_token_num};
    } else {
        // prepare resources for all layers
        auto attention_common_inputs = prepareAttentionInputs(inputs, attn_dtype, combo_position_ids);
        return {move(hidden), move(pre_decoder_residual), move(attention_common_inputs), hidden_dtype, {}, enable_sp, token_num, pad_token_num};
    }
}