maga_transformer/cpp/devices/Weights.h

#pragma once

#include "maga_transformer/cpp/core/Buffer.h"

#include <optional>
#include <memory>
#include <unordered_map>
#include <thread>
#include <shared_mutex>

namespace rtp_llm {

struct LayerNormWeights {
    ConstBufferPtr gamma = nullptr;
    ConstBufferPtr beta  = nullptr;
    // optional static quantization scale and its reciprocal
    ConstBufferPtr static_scale            = nullptr;
    ConstBufferPtr static_scale_reciprocal = nullptr;

    LayerNormWeights() = default;

    LayerNormWeights(ConstBufferPtr& gamma, ConstBufferPtr& beta) :
        gamma(std::move(gamma)), beta(std::move(beta)) {}
    LayerNormWeights(BufferPtr& gamma, BufferPtr& beta) :
        gamma(std::move(gamma)), beta(std::move(beta)) {}
    LayerNormWeights(ConstBufferPtr& gamma,
                     ConstBufferPtr& beta,
                     ConstBufferPtr& static_scale,
                     ConstBufferPtr& static_scale_reciprocal) :
        gamma(std::move(gamma)),
        beta(std::move(beta)),
        static_scale(std::move(static_scale)),
        static_scale_reciprocal(std::move(static_scale_reciprocal)) {}
    LayerNormWeights(BufferPtr& gamma,
                     BufferPtr& beta,
                     BufferPtr& static_scale,
                     BufferPtr& static_scale_reciprocal) :
        gamma(std::move(gamma)),
        beta(std::move(beta)),
        static_scale(std::move(static_scale)),
        static_scale_reciprocal(std::move(static_scale_reciprocal)) {}
};

typedef std::shared_ptr<const LayerNormWeights> LayerNormWeightsPtr;

struct DenseWeights {
    ConstBufferPtr kernel = nullptr;
    ConstBufferPtr bias   = nullptr;

    DenseWeights() = default;

    DenseWeights(BufferPtr& kernel) : kernel(std::move(kernel)) {}
    DenseWeights(ConstBufferPtr& kernel) : kernel(std::move(kernel)) {}
    DenseWeights(ConstBufferPtr& kernel, ConstBufferPtr& bias) :
        kernel(std::move(kernel)), bias(std::move(bias)) {}
    DenseWeights(BufferPtr& kernel, BufferPtr& bias) :
        kernel(std::move(kernel)), bias(std::move(bias)) {}
};

typedef std::shared_ptr<DenseWeights> DenseWeightsPtr;

struct AttentionLayerWeights {
    std::shared_ptr<const LayerNormWeights> pre_attention_layernorm;
    std::shared_ptr<const DenseWeights>     qkv_weight;
    std::shared_ptr<const LayerNormWeights> attention_layernorm;
    std::shared_ptr<const LayerNormWeights> q_norm_weight;
    std::shared_ptr<const LayerNormWeights> k_norm_weight;
    std::shared_ptr<const DenseWeights>     output_weight;
    std::shared_ptr<const DenseWeights>     static_quant_weight;
    std::shared_ptr<const DenseWeights>     static_scale_reciprocal_weight;
    std::shared_ptr<const DenseWeights>     smoother_weight;
    std::shared_ptr<const DenseWeights>     shift_weight;
    std::shared_ptr<const DenseWeights>     linear_bias_slopes_weight;

    // MLA weights
    std::shared_ptr<const DenseWeights>     fusedqkrope_weight;
    std::shared_ptr<const DenseWeights>     fusedqkrope_no_lora_weight;
    std::shared_ptr<const DenseWeights>     q_b_weight;
    std::shared_ptr<const DenseWeights>     kv_a_weight;
    std::shared_ptr<const DenseWeights>     k_nope_weight;
    std::shared_ptr<const DenseWeights>     k_rope_weight;
    std::shared_ptr<const DenseWeights>     v_weight;
    std::shared_ptr<const LayerNormWeights> q_a_norm_weight;
    std::shared_ptr<const LayerNormWeights> kv_a_norm_weight;

    // MLA decode weights
    std::shared_ptr<const DenseWeights>     kc_weight;
    std::shared_ptr<const DenseWeights>     vc_weight;

    // rope cos/sin cache
    ConstBufferPtr rope_cos_sin_cache;
};

struct FfnLayerWeights {
    std::shared_ptr<const DenseWeights> up_weight;
    std::shared_ptr<const DenseWeights> gate_weight;
    std::shared_ptr<DenseWeights>       moe_gate_weight;
    std::shared_ptr<const DenseWeights> down_weight;
    std::shared_ptr<DenseWeights>       moe_down_weight;
    std::shared_ptr<const DenseWeights> gate_up_weight;
    std::shared_ptr<const DenseWeights> moe_gating_weight;

    std::shared_ptr<const DenseWeights> smoother_weight;
    ConstBufferPtr                      act_scale;

    std::shared_ptr<const DenseWeights> intermediate_weight2_static_scale_weight;
    std::shared_ptr<const DenseWeights> intermediate_weight2_static_scale_reciprocal_weight;
    std::shared_ptr<const DenseWeights> intermediate_weight3_static_scale_weight;
    std::shared_ptr<const DenseWeights> intermediate_weight3_static_scale_reciprocal_weight;

    // These fields are for the Qwen MoE model.
    // See https://github.com/huggingface/transformers/blob/0f67ba1d741d65b07d549daf4ee157609ce4f9c1/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L803
    std::shared_ptr<FfnLayerWeights>    shared_expert;
    std::shared_ptr<const DenseWeights> shared_expert_gate;

    ConstBufferPtr                      e_score_correction_bias;  // noaux_tc
    ConstBufferPtr                      log2phy;
    ConstBufferPtr                      logic_expert_cnt;
};

struct LayerWeights {
    std::shared_ptr<const LayerNormWeights> pre_layernorm;
    AttentionLayerWeights                   self_attention_weights;
    std::shared_ptr<const LayerNormWeights> post_layernorm;
    std::shared_ptr<const DenseWeights>     post_layernorm_quant_scale;
    FfnLayerWeights                         ffn_weights;
    std::shared_ptr<const LayerNormWeights> post_ffn_layernorm;

    // MTP (multi-token prediction) weights
    std::shared_ptr<const LayerNormWeights> enorm;
    std::shared_ptr<const LayerNormWeights> hnorm;
    std::shared_ptr<const DenseWeights>     eh_proj;
    std::shared_ptr<const LayerNormWeights> mtp_final_layernorm;
};

// TODO: This Weights class might be refactored into a complete model description
// which includes more info like norm type, activation type, etc.
struct Weights {
    std::shared_ptr<const DenseWeights>     embedding;
    std::shared_ptr<const DenseWeights>     prefix_encoder_embedding;
    std::shared_ptr<const LayerNormWeights> pre_decoder_layernorm;
    std::shared_ptr<const DenseWeights>     position_encoding;
    std::shared_ptr<const DenseWeights>     token_type_embedding;
    std::vector<LayerWeights>               layers;
    std::shared_ptr<const LayerNormWeights> final_layernorm;
    std::shared_ptr<const DenseWeights>     linear_bias_slopes;
    std::shared_ptr<const DenseWeights>     lm_head;
};

using WeightsPtr = std::shared_ptr<const Weights>;

} // namespace rtp_llm
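
/* ---------------------------------------------------------------------------
 * Illustrative usage sketch (not part of the original header): a minimal
 * example of how a loader might assemble a Weights object from device
 * buffers. The helper makeBuffer() is a hypothetical stand-in for however
 * ConstBufferPtr instances are actually produced in this codebase; only the
 * struct fields and constructors themselves come from the definitions above.
 *
 *   using namespace rtp_llm;
 *
 *   WeightsPtr buildWeights(size_t num_layers) {
 *       auto weights = std::make_shared<Weights>();
 *
 *       // Model-level tensors.
 *       ConstBufferPtr emb = makeBuffer("embedding");  // hypothetical loader
 *       weights->embedding = std::make_shared<const DenseWeights>(emb);
 *
 *       ConstBufferPtr gamma = makeBuffer("final_layernorm.gamma");
 *       ConstBufferPtr beta  = makeBuffer("final_layernorm.beta");
 *       weights->final_layernorm = std::make_shared<const LayerNormWeights>(gamma, beta);
 *
 *       // Per-layer attention / FFN weights.
 *       weights->layers.resize(num_layers);
 *       for (size_t i = 0; i < num_layers; ++i) {
 *           ConstBufferPtr qkv = makeBuffer("layers." + std::to_string(i) + ".qkv");
 *           weights->layers[i].self_attention_weights.qkv_weight =
 *               std::make_shared<const DenseWeights>(qkv);
 *       }
 *
 *       // shared_ptr<Weights> converts implicitly to WeightsPtr (shared_ptr<const Weights>).
 *       return weights;
 *   }
 * ------------------------------------------------------------------------- */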