maga_transformer/cpp/models/GptModel.h

#pragma once

#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/core/Event.h"
#include "maga_transformer/cpp/devices/DeviceBase.h"
#include "maga_transformer/cpp/devices/OpData.h"
#include "maga_transformer/cpp/devices/Weights.h"
#include "maga_transformer/cpp/stats/ExpertStats.h"
#include "maga_transformer/cpp/cache/CacheManager.h"

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace rtp_llm {

struct GptModelDescription {
    rtp_llm::AttentionConfigs attention_conf;
    rtp_llm::FfnConfigs       ffn_conf;
    rtp_llm::NormType         norm_type;
    rtp_llm::QScheme          act_qscheme            = rtp_llm::QScheme::NoQuantize;
    double                    layernorm_eps          = 1e-5;
    size_t                    vocab_size             = 0;
    bool                      post_layernorm         = false;
    double                    input_embedding_scalar = 1;
    double                    residual_scalar        = 1;
    bool                      reverse_e_h_norm       = false;
};

struct GptModelInitParams {
    rtp_llm::DeviceBase*                             device;
    const rtp_llm::Weights                           weights;
    const GptModelDescription                        description;
    const std::optional<CacheManager::KVCacheBuffer> kv_cache_buffer;
};

// A batch consists of two parts: a context batch and a decoder batch.
// The context batch holds requests processing their initial prompt tokens (prefill);
// the decoder batch holds requests generating tokens incrementally.
// ids and lengths are int32_t.
struct GptModelInputs {
    // input_lengths holds the original input length of each request,
    // shape [decoder_batch_size + context_batch_size], int32.
    // sequence_lengths holds the current sequence length of each incremental decoding request,
    // shape [decoder_batch_size], int32.
    mutable rtp_llm::BufferPtr combo_tokens;      // [cumulated_seq_len]
    rtp_llm::BufferPtr         input_lengths;     // [batch_size]
    rtp_llm::BufferPtr         sequence_lengths;  // [decoder_batch_size]
    rtp_llm::BufferPtr         lm_output_indexes; // [context_batch_size]
    rtp_llm::BufferPtr         prefix_lengths;    // [context_batch_size]

    rtp_llm::BufferPtr combo_tokens_type_ids; // [cumulated_seq_len]
    rtp_llm::BufferPtr combo_position_ids;    // [cumulated_seq_len]

    // for mtp model
    rtp_llm::BufferPtr last_hidden_states;

    // for tp sync
    rtp_llm::BufferPtr lora_ids;           // [batch_size]
    rtp_llm::BufferPtr lora_input_lengths; // [batch_size]

    // no need for tp sync
    rtp_llm::lora::LoraModelInputPtr lora_model_input;

    rtp_llm::BufferPtr attention_mask; // [batch_size, seq_len, seq_len]

    rtp_llm::BufferPtr kv_cache_block_id; // [batch_size, block_nums], kv cache block id

    std::optional<std::vector<rtp_llm::BufferPtr>> multimodal_features; // all features in the gathered stream are stored here
    rtp_llm::BufferPtr text_tokens_mask; // text part in multimodal input tokens, [cumulated_seq_len]
    rtp_llm::BufferPtr mm_features_locs; // feature indexes

    rtp_llm::BufferPtr request_id;            // int64, [context_batch_size]
    rtp_llm::BufferPtr request_pd_separation; // bool, [context_batch_size]
    rtp_llm::BufferPtr cache_keys;            // [context_batch_size]
    size_t             k_block_size;
    size_t             v_block_size;
    size_t             scale_block_size;
    size_t             seq_size_per_block;
    bool               pd_separation = false;

    bool need_all_logits = false;
    bool warmup          = false;

public:
    std::string debugString() const;
};

enum GptModelInputIndex : size_t {
    comboTokens,
    inputLengths,
    sequenceLengths,
    prefixLengths,
    maxBlocksPerBatch,
    lmOutputIndexes,
    comboPositionIds,
    loraIds,
    loraInputLengths,
    textTokensMask,
    mmFeaturesLocs,
    mmFeaturesNum,  // number of mm features
    mmFeaturesSize, // hidden_size of mm features
    mmFeaturesDtype,
    needAllLogits,
    mtpHiddenStates,
    mtpHiddenStatesDtype,
    gptModelInputLength
};

void dpAndTpSyncModelInputs(GptModelInputs& inputs, rtp_llm::DeviceBase* device);

struct GptModelOutputs {
    rtp_llm::BufferPtr logits;
    rtp_llm::BufferPtr hidden_states;
    rtp_llm::BufferPtr all_hidden_states;
    rtp_llm::BufferPtr all_logits;
    rtp_llm::BufferPtr softmax_result;

    mutable rtp_llm::BufferPtr scatter_logits;
    mutable rtp_llm::BufferPtr scatter_hidden_states;

    std::shared_ptr<void> captured_values;
};

using LoraMap = std::unordered_map<std::string, rtp_llm::ConstBufferPtr>;

struct GptLayerOutputs {
    rtp_llm::BufferPtr hidden;
    rtp_llm::BufferPtr pre_decoder_residual;
};

struct MicroBatchInfo {
    size_t prefill_num;
    size_t decoder_num;
};

struct MicroBatchPlan {
    bool                        enable = false;
    std::vector<MicroBatchInfo> batch_infos;
};

struct LayerMicroBatchInputs {
    rtp_llm::BufferPtr             hidden;
    rtp_llm::BufferPtr             pre_decoder_residual;
    rtp_llm::AttentionCommonInputs attention_common_inputs;
    bool                           fake = false;
};

struct GptLayerInputs {
    rtp_llm::BufferPtr                 hidden;
    rtp_llm::BufferPtr                 pre_decoder_residual;
    rtp_llm::AttentionCommonInputs     attention_common_inputs;
    const rtp_llm::DataType            dtype;
    std::vector<LayerMicroBatchInputs> micro_batch_inputs;
    bool                               enable_sp     = false;
    size_t                             token_num     = 0;
    size_t                             pad_token_num = 0;
};

struct AttentionBlockOutputs {
    rtp_llm::BufferPtr hidden;
    rtp_llm::BufferPtr residual;
    rtp_llm::BufferPtr residual2;
};

struct EpFfnInputs {
    rtp_llm::BufferPtr           hidden;
    rtp_llm::BufferPtr           residual;
    rtp_llm::BufferPtr           shared_expert_output;
    rtp_llm::FfnLayerParams      moe_ffn_params;
    rtp_llm::MoeGateSelectOutput gate_output;
    rtp_llm::MoeDispatchOutput   dispatch_output;
};

struct EpFfnOutputs {
    rtp_llm::BufferPtr        hidden;
    rtp_llm::MoeCombineOutput combine_output;
    rtp_llm::DeviceHookPtr    comm_barrier_hook;
};

struct LastLayerDeferedParams {
    rtp_llm::BufferPtr                               residual;
    rtp_llm::BufferPtr                               shared_expert_output;
    std::optional<rtp_llm::MoeCombineOutput>         combine_output;
    std::shared_ptr<const rtp_llm::LayerNormWeights> post_ffn_layernorm_weights;
    rtp_llm::DeviceHookPtr                           comm_barrier_hook;
};

class GptModel {
public:
    GptModel(const GptModelInitParams& params);
    virtual ~GptModel() {}

    virtual GptModelOutputs forward(const GptModelInputs& inputs);

protected:
    rtp_llm::AttentionCommonInputs prepareAttentionInputs(
        const GptModelInputs& inputs,
        rtp_llm::DataType attn_dtype,
        rtp_llm::BufferPtr combo_position_ids);

    MicroBatchPlan planMicroBatches(const GptModelInputs& inputs);

    std::vector<LayerMicroBatchInputs> prepareMicroBatchInputs(
        const GptModelInputs& model_inputs,
        const rtp_llm::BufferPtr& hidden,
        const rtp_llm::BufferPtr& pre_decoder_residual,
        const rtp_llm::DataType attn_dtype,
        const MicroBatchPlan& micro_batch_plan);

    virtual rtp_llm::BufferPtr embeddingPost(const rtp_llm::BufferPtr& hidden_states,
                                             const GptModelInputs& inputs);

    rtp_llm::BufferPtr tpSyncEmbeddingOrLogits(const rtp_llm::BufferPtr& buffer);

    GptLayerInputs forwardPreLayers(const GptModelInputs& inputs);

    GptLayerOutputs forwardGptLayer(
        GptLayerInputs inputs,
        const int32_t layer_id,
        rtp_llm::lora::LoraModelInputPtr lora_model_input);

    AttentionBlockOutputs forwardAttentionBlock(
        const GptLayerInputs& inputs,
        const int32_t layer_id,
        rtp_llm::lora::LoraModelInputPtr lora_model_input,
        const LastLayerDeferedParams& last_layer_defered_params = {});

    // These methods are dedicated to MoE EP micro batching.
    GptLayerOutputs forwardMicroBatchedLayers(const GptLayerInputs& layer_inputs,
                                              const GptModelInputs& inputs);

    std::vector<GptLayerInputs> forwardPrefillMicroBatchedLayers(std::vector<GptLayerInputs> inputs);

    std::vector<GptLayerInputs> forwardDecodeMicroBatchedLayers(std::vector<GptLayerInputs> inputs);

    EpFfnInputs forwardAttentionAndMoeGate(
        const GptLayerInputs& inputs,
        LastLayerDeferedParams& last_layer_defered_params,
        const int32_t layer_id,
        const size_t micro_batch_idx);

    GptLayerOutputs forwardMoeFfn(const GptLayerOutputs& inputs, const int32_t layer_id);

    GptModelOutputs forwardPostLayers(
        const rtp_llm::BufferPtr hidden,
        const bool has_context_request,
        const bool need_all_logits,
        const rtp_llm::BufferPtr lm_output_indexes,
        bool enable_sp,
        size_t token_num,
        const GptModelInputs& inputs);

    void prepareExpertStats(const size_t layer_id, rtp_llm::FfnLayerParams& ffn_layer_params);

    void cleanExpertStats();

protected:
    rtp_llm::DeviceBase*             device_;
    const rtp_llm::DeviceProperties  device_props_;
    const size_t                     layer_num_;
    const GptModelDescription        description_;
    rtp_llm::BufferPtr               k_cache_buffer_;
    rtp_llm::BufferPtr               v_cache_buffer_;
    rtp_llm::BufferPtr               k_scale_buffer_;
    rtp_llm::BufferPtr               v_scale_buffer_;
    rtp_llm::BufferPtr               residual_scale_fp32_;
    rtp_llm::BufferPtr               residual_scale_;
    rtp_llm::DeviceHookPtr           last_comm_hook_;

public:
    rtp_llm::Weights            weights_;
    rtp_llm::OverallExpertStats overall_expert_stats_;
};

}  // namespace rtp_llm
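
For orientation, a minimal usage sketch of the interface declared above. It is not part of GptModel.h: the helper name runForwardOnce is hypothetical, and it assumes the serving engine has already created a DeviceBase, loaded the checkpoint into Weights and a GptModelDescription, and filled a GptModelInputs with device buffers (buffer allocation goes through DeviceBase and is not shown here).

// Illustrative sketch only, not part of the header above.
#include "maga_transformer/cpp/models/GptModel.h"

#include <optional>

namespace rtp_llm {

// Hypothetical helper: builds a GptModel and runs a single forward pass.
GptModelOutputs runForwardOnce(rtp_llm::DeviceBase*       device,
                               const rtp_llm::Weights&    weights,
                               const GptModelDescription& description,
                               const GptModelInputs&      inputs) {
    // kv_cache_buffer is optional; pass std::nullopt when no paged KV cache is attached.
    GptModelInitParams init_params{device, weights, description, std::nullopt};
    GptModel           model(init_params);

    // Per the shape comments in GptModelInputs, the combined batch is laid out with
    // decoder (incremental) requests first and context (prefill) requests after them,
    // with combo_tokens holding all token ids back to back.
    return model.forward(inputs);
}

}  // namespace rtp_llm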