maga_transformer/cpp/devices/OpData.h (838 lines of code) (raw):

#pragma once #include "maga_transformer/cpp/devices/Weights.h" #include "maga_transformer/cpp/devices/LoraWeights.h" #include "maga_transformer/cpp/devices/CommonDefines.h" #include "maga_transformer/cpp/utils/activation_types.h" #include "maga_transformer/cpp/utils/RopeConfig.h" #include "maga_transformer/cpp/utils/MlaConfig.h" #include "maga_transformer/cpp/stats/ExpertStats.h" #include "maga_transformer/cpp/core/Event.h" #include "maga_transformer/cpp/core/Buffer.h" #include "maga_transformer/cpp/core/QBuffer.h" #include "maga_transformer/cpp/utils/activation_types.h" #include "maga_transformer/cpp/utils/layernorm_types.h" #include "maga_transformer/cpp/utils/EnumUtils.h" #include "maga_transformer/cpp/utils/StackTrace.h" #include <cstddef> #include <optional> #include <functional> #include <sstream> #include <memory> #include <torch/extension.h> #include <torch/python.h> #include <type_traits> namespace rtp_llm { class GptModelInputs; } namespace rtp_llm { class DeviceBase; enum class OpErrorType { ERROR_NONE, ERROR_INVALID_ARGS, ERROR_RESOURCE_EXHAUSTED, ERROR_UNIMPLEMENTED, ERROR_INTERNAL, ERROR_UNKNOWN, }; enum class ParallelMode { TP = 0, DP = 1, DP_AND_TP = 2, FFN_TP = 3, EP = 4, EPLB = 5 }; enum class DeviceStream { DEFAULT = 0, }; class OpStatus { public: OpStatus(OpErrorType, const std::string& message = "") : error_type(OpErrorType::ERROR_NONE), error_message(message) {} static OpStatus make(OpErrorType error_type, const std::string& error_message = "") { return OpStatus(error_type, error_message); } static OpStatus OK() { return OpStatus(OpErrorType::ERROR_NONE); } bool ok() const { return error_type == OpErrorType::ERROR_NONE; } public: OpErrorType error_type; std::string error_message; }; class OpException : public std::exception { public: OpException(const OpStatus& status) : status_(status) { std::stringstream ss; ss << "OpException[" << (int32_t)status_.error_type << "]: " << status_.error_message << std::endl; RTP_LLM_LOG_INFO("%s", ss.str().c_str()); const auto stack = rtp_llm::getStackTrace(); RTP_LLM_STACKTRACE_LOG_INFO("%s", stack.c_str()); ss << stack; detail_str_ = ss.str(); if (std::getenv("FT_CORE_DUMP_ON_EXCEPTION")) { fflush(stdout); fflush(stderr); abort(); } } const char* what() const noexcept override { return detail_str_.c_str(); } const OpStatus& status() const { return status_; } private: OpStatus status_; mutable std::string detail_str_; }; using OptionalConstBufferRef = std::optional<std::reference_wrapper<const Buffer>>; using OptionalBufferRef = std::optional<std::reference_wrapper<Buffer>>; using OptionalConstVecBufferPtrRef = std::optional<std::reference_wrapper<const std::vector<BufferPtr>>>; using CloneOutput = BufferPtr; struct CloneParams { CloneParams(const Buffer& input, const AllocationType alloc_type = AllocationType::DEVICE, const BufferHints& hints = BufferHints(), bool overlapped = false) : input(input), alloc_type(alloc_type), hints(hints), overlapped(overlapped) {} const Buffer& input; const AllocationType alloc_type; const BufferHints& hints; bool overlapped = false; }; struct SliceParams { const Buffer& input; int64_t dim; int64_t start; int64_t end; int64_t step = 1; }; using SliceOutput = BufferPtr; struct CopyParams { const Buffer& dst; const Buffer& src; bool overlapped = false; const DeviceStream stream = DeviceStream::DEFAULT; void check() const { RTP_LLM_CHECK_WITH_INFO(src.type() == dst.type(), "copy dst[%d] and src[%d] need has same type.", src.type(), dst.type()); RTP_LLM_CHECK_WITH_INFO(src.sizeBytes() == dst.sizeBytes(), "src and dst copy size mismatch: [%s] vs [%s]", src.debugString().c_str(), dst.debugString().c_str()); } }; using SelectOutput = BufferPtr; enum SelectType { LAST = 0, FIRST = 1, }; struct SelectParams { const Buffer& input; const Buffer& index; size_t dim = 0; }; using TransposeOutput = BufferPtr; struct TransposeParams { const Buffer& input; bool overlapped = false; }; using ConvertOutput = BufferPtr; struct ConvertParams { const BufferPtr input; const DataType type; }; using ConcatOutput = BufferPtr; struct ConcatParams { const std::vector<BufferPtr>& inputs; const size_t dim = 0; }; struct SplitOutput { std::vector<BufferPtr> outputs; }; struct SplitParams { const Buffer& input; const std::vector<size_t>& split_sizes; const size_t dim = 0; bool overlapped = false; }; struct LayernormOutput { BufferPtr output; BufferPtr before_norm_output; }; struct AddBiasOutput { BufferPtr output; }; struct LayernormWithStrideParams { BufferPtr input; const std::optional<std::reference_wrapper<const LayerNormWeights>> norm_weight; double eps; NormType norm_type; size_t offset; // do normalize for each group in norm_group size_t norm_group_size; QScheme qscheme = QScheme::NoQuantize; bool in_place = true; }; struct LayernormParams { LayernormParams(BufferPtr input, BufferPtr before_norm_output, const std::optional<std::reference_wrapper<const LayerNormWeights>> norm_weight, OptionalConstBufferRef residual1 = std::nullopt, OptionalConstBufferRef residual2 = std::nullopt, OptionalConstBufferRef bias = std::nullopt, double alpha = 1.0f, double eps = 1e-5, bool is_inplace = true, bool return_normed_output = false, NormType norm_type = NormType::layernorm, QScheme qscheme = QScheme::NoQuantize, bool attn_swap_comm_buffer = false, bool ffn_swap_comm_buffer = false) : input(std::move(input)), before_norm_output(std::move(before_norm_output)), norm_weight(norm_weight), residual1(residual1), residual2(residual2), bias(bias), norm_type(norm_type), alpha(alpha), eps(eps), return_normed_output(return_normed_output), is_inplace(is_inplace), qscheme(qscheme), attn_swap_comm_buffer(attn_swap_comm_buffer), ffn_swap_comm_buffer(ffn_swap_comm_buffer) {}; BufferPtr input; BufferPtr before_norm_output; const std::optional<std::reference_wrapper<const LayerNormWeights>> norm_weight; const OptionalConstBufferRef residual1; const OptionalConstBufferRef residual2; const OptionalConstBufferRef bias; const NormType norm_type; const double alpha; const double eps; const bool return_normed_output; const bool is_inplace; const QScheme qscheme; bool attn_swap_comm_buffer = false; bool ffn_swap_comm_buffer = false; }; enum GemmType : size_t { InvalidGemm = 0, BufferA_BufferB_BufferC_2DGemm, BufferA_BufferB_BufferC_3DGemm, QBufferA_BufferB_BufferC_2DGemm, BufferA_QBufferB_BufferC_2DGemm, QBufferA_QBufferB_BufferC_2DGemm, }; struct AddBiasParams { BufferPtr input; const Buffer& bias; bool inplace = true; }; // D = alpha * op(A) * op(B) + beta * C // shapes of A, B, C, D have two options: [m, k], [k, n], [m, n] / [1, n], [m, n] // or [bs, m, k], [bs, k, n], [bs, m, n], [bs, m, n] where bs is batch_size // D is optional, if not passed, it will be allocated by the op struct GemmParams { GemmParams(const Buffer& A, const Buffer& B, OptionalConstBufferRef C = std::nullopt, BufferPtr D = nullptr, const DataType compute_type = DataType::TYPE_INVALID, TransposeOperation transA = TransposeOperation::NONE, TransposeOperation transB = TransposeOperation::NONE, const ActivationType activationType = ActivationType::Identity, const float alpha = 1.0f, const float beta = 0.0f, int math_sm_count = 0, void* stream = nullptr): A(A), B(B), C(C), D(D), compute_type(compute_type), transA(transA), transB(transB), activationType(activationType), alpha(alpha), beta(beta), math_sm_count(math_sm_count), stream(stream) {} const Buffer& A; const Buffer& B; OptionalConstBufferRef C; BufferPtr D; const DataType compute_type = DataType::TYPE_INVALID; // If passed invalid type, op should infer type const TransposeOperation transA = TransposeOperation::NONE; const TransposeOperation transB = TransposeOperation::NONE; ActivationType activationType = ActivationType::Identity; const float alpha = 1.0f; const float beta = 0.0f; mutable int math_sm_count = 0; void* stream = nullptr; void check() const; GemmType dispatch() const; }; struct GroupedGemmOutput { std::vector<BufferPtr> output; }; // C = alpha * op(A) * op(B) + beta * C // shapes of each A, B, C needs to be [m, k], [k, n], [m, n] struct GroupedGemmParams { const std::vector<BufferPtr>& A; const std::vector<BufferPtr>& B; std::optional<std::vector<BufferPtr>> C = std::nullopt; const float alpha = 1.0f; const float beta = 1.0f; void check() const; }; using MultiplyOutput = BufferPtr; // output = A * B // A: [m], B: [m] or [m, dim_1, ..., dim_n] struct MultiplyParams { const Buffer& A; const Buffer& B; BufferPtr output = nullptr; }; struct EmbeddingLookupParams { const Buffer& combo_tokens; const Buffer& embedding_table; double input_embedding_scalar = 1; OptionalConstBufferRef text_tokens_mask; OptionalConstBufferRef position_ids; OptionalConstBufferRef position_table; OptionalConstBufferRef token_types; OptionalConstBufferRef token_type_table; }; struct KvCacheInfo { int layer_num; BufferPtr kv_cache_block_id; // [batch_size, block_nums], kv cache block offset BufferPtr k_cache_buffer; // [block_nums, head, seq_size_per_block, size_per_head] BufferPtr v_cache_buffer; // [block_nums, head, seq_size_per_block, size_per_head] BufferPtr k_scale_buffer; // [block_nums, head, seq_size_per_block] BufferPtr v_scale_buffer; // [block_nums, head, seq_size_per_block] }; struct MultimodalEmbeddingParams { const BufferPtr& word_embeddings; OptionalConstVecBufferPtrRef multimodal_features; OptionalConstBufferRef multimodal_locs; }; using MultimodalEmbeddingOutput = BufferPtr; struct CacheStoreInputs { BufferPtr input_lengths_host; BufferPtr prefix_lengths_host; BufferPtr host_kv_cache_offset; }; using FlashInferAttnParamsPtr = std::shared_ptr<void>; struct AttentionCommonInputs { // see detailed comments at GptModelInputs ConstBufferPtr input_lengths; // int32_t, [decoder_batch_size + context_batch_size] ConstBufferPtr sequence_lengths; // int32_t, [decoder_batch_size] std::optional<KvCacheInfo> kv_cache; std::optional<CacheStoreInputs> cache_store_inputs; ConstBufferPtr cu_seqlens; ConstBufferPtr cu_kv_seqlens; ConstBufferPtr padding_offset; size_t context_batch_size = 0; size_t decoder_batch_size = 0; size_t context_max_seq_len = 0; size_t decoder_max_seq_len = 0; size_t context_token_num = 0; BufferPtr position_ids; BufferPtr attention_mask; ConstBufferPtr linear_bias_slopes; BufferPtr prefix_prompt_lengths; int32_t max_prefix_length = 0; lora::AttentionLayerLoraInput lora_input; int layer_id = 0; BufferPtr request_id; // [context_batch_size] BufferPtr request_pd_separation; // [context_batch_size] std::vector<std::string> cache_keys; // [context_batch_size] size_t k_block_size = 0; size_t v_block_size = 0; size_t scale_block_size = 0; bool pd_separation = false; bool warmup; FlashInferAttnParamsPtr prefill_flash_infer_attn_params; FlashInferAttnParamsPtr decode_flash_infer_attn_params; }; struct AttentionConfigs { size_t head_num; size_t kv_head_num; size_t size_per_head; size_t hidden_size; // rotary embending config RopeConfig rope_config; //kv cache block size_t tokens_per_block; AttentionMaskType mask_type = noMask; float q_scaling = 1.0f; bool fuse_qkv_add_bias = true; bool use_logn_attn = false; // mla config bool use_mla = false; size_t q_lora_rank; size_t kv_lora_rank; size_t nope_head_dim; size_t rope_head_dim; size_t v_head_dim; // softmax config float softmax_extra_scale = 1.0f; KvCacheDataType kv_cache_dtype = KvCacheDataType::BASE; }; using AttentionModuleOutput = void; struct AttentionModuleParams { const int32_t layer_id; // qkv shape[h_token_num, (head_num + 2 * kv_head_num) * size_per_head] const Buffer& input; Buffer& output; // shape [token_num, size_per_head] AttentionCommonInputs& common; const AttentionLayerWeights& weights; const AttentionConfigs& configs; const QScheme qscheme; }; struct MlaRotaryWriteKVCacheParams { const Buffer& q; BufferPtr fused_dest_q; const Buffer& fused_qkv; const int64_t kv_offset; FlashInferAttnParamsPtr flash_infer_params; // prefill or decode AttentionCommonInputs& common; const AttentionLayerWeights& weights; const AttentionConfigs& configs; const QScheme qscheme; }; struct MlaAttentionModuleParams { const int32_t layer_id; const Buffer& q; const Buffer& fused_qkv; const int64_t kv_offset; BufferPtr qkv_output; // shape [token_num, hidden_size] AttentionCommonInputs& common; const AttentionLayerWeights& weights; const AttentionConfigs& configs; const QScheme qscheme; bool is_prefill = false; }; struct WriteMTPHiddenStatesParams { bool pd_separation = false; bool warmup; size_t context_batch_size = 0; size_t decoder_batch_size = 0; BufferPtr request_pd_separation; BufferPtr request_id; BufferPtr hidden_states; BufferPtr lm_output_indexes; }; struct WriteCacheParams { AttentionCommonInputs& common; const AttentionConfigs& configs; bool mla_kvcache = false; WriteCacheParams(const AttentionModuleParams& params) : common(params.common), configs(params.configs) {} WriteCacheParams(const MlaAttentionModuleParams& params) : common(params.common), configs(params.configs), mla_kvcache(true) {} }; struct AttentionLayerOutput { BufferPtr hidden_states; }; struct LayerNormConfig { double eps; NormType norm_type; }; struct AttentionLayerParams { int32_t layer_id; const Buffer& input; BufferPtr output; const AttentionConfigs& configs; const AttentionLayerWeights& weights; AttentionCommonInputs& common; const OptionalConstBufferRef residual; // for intel xft const LayerNormConfig ln_params; const QScheme qscheme; bool enable_sp; size_t pad_token_num; }; struct MoeConfigs { size_t expert_num; size_t extra_expert_num = 0; size_t top_k; bool normalize_expert_scale = false; int64_t moe_inter_padding_size = 0; bool has_moe_norm = false; size_t ep_rank = 0; size_t ep_size = 1; size_t tp_rank = 0; size_t tp_size = 1; size_t dp_rank = 0; size_t dp_size = 1; int scoring_func = 0; // 0: softmax, 1: sigmoid int topk_group = 1; int n_group = 1; bool enable_eplb = false; // NOTE(yinzhi): not used yet EplbBalanceMethod balance_method = EplbBalanceMethod::EQUAL; }; struct FfnConfigs { ActivationType activation_type; std::optional<MoeConfigs> moe_configs = std::nullopt; }; struct DeepEPDispatchOutput; struct DeepEPDispatchOutputLowLatency; struct MoeGateSelectOutput; struct MoeCombineParams { BufferPtr input; BufferPtr indices; BufferPtr output; std::vector<size_t> input_split_sizes; std::vector<size_t> output_split_sizes; MoeConfigs moe_configs; size_t origin_token_num; bool overlapped = false; std::shared_ptr<DeepEPDispatchOutput> deep_ep_output; std::shared_ptr<DeepEPDispatchOutputLowLatency> deep_ep_ll_output; std::shared_ptr<MoeGateSelectOutput> select_output; BufferPtr expert_ids; BufferPtr expert_scales; bool sp_model = false; }; struct MoeCombineOutput { BufferPtr all_output; BufferPtr scatter_output; MoeCombineParams params; DeviceHookPtr comm_barrier_hook; }; struct FfnLayerOutput { BufferPtr hidden_states; DeviceHookPtr comm_barrier_hook; std::optional<MoeCombineOutput> moe_combine_output; }; struct FfnLayerParams { FfnLayerParams(const Buffer& input, const FfnConfigs& configs, const FfnLayerWeights& weights, const OptionalConstBufferRef residual = std::nullopt, const QScheme qscheme = QScheme::NoQuantize, BufferPtr output = nullptr, bool enable_sp = false, bool sp_model = false): input(input), configs(configs), weights(weights), residual(residual), qscheme(qscheme), output(std::move(output)), enable_sp(enable_sp), sp_model(sp_model) {} const Buffer& input; const FfnConfigs& configs; const FfnLayerWeights& weights; const OptionalConstBufferRef residual; // for intel xft OptionalExpertStats expert_stats = std::nullopt; const QScheme qscheme; BufferPtr output; lora::FfnLayerLoraInput lora_input; bool enable_sp; bool sp_model = false; }; struct MoeDispatchOutput { BufferPtr hidden; BufferPtr expert_ids; BufferPtr expert_scales; BufferPtr indices; const std::vector<size_t> input_split_sizes; const std::vector<size_t> output_split_sizes; const std::vector<BufferPtr> dispatch_src_buffers; // to make them outlive async sendrecv const BufferPtr concated_src_buffers; // to make them outlive async sendrecv const BufferPtr split_dst_buffers; // to make them outlive async sendrecv DeviceHookPtr comm_barrier_hook; std::shared_ptr<DeepEPDispatchOutput> deep_ep_output; std::shared_ptr<DeepEPDispatchOutputLowLatency> deep_ep_ll_output; }; struct MoeGateSelectOutput { BufferPtr expert_ids; BufferPtr expert_scales; std::shared_ptr<DeepEPDispatchOutputLowLatency> deep_ep_ll_output = nullptr; }; struct MoeDispatchParams { MoeDispatchParams(const Buffer& input, const Buffer& expert_ids, const Buffer& expert_scales, const MoeConfigs& moe_configs, bool overlapped = false, const QScheme qscheme = QScheme::NoQuantize, OptionalExpertStats expert_stats = std::nullopt, bool sp_model = false): input(input), expert_ids(expert_ids), expert_scales(expert_scales), moe_configs(moe_configs), overlapped(overlapped), qscheme(qscheme), expert_stats(expert_stats), sp_model(sp_model) {} const Buffer& input; const Buffer& expert_ids; const Buffer& expert_scales; const MoeConfigs& moe_configs; bool overlapped = false; const QScheme qscheme; OptionalExpertStats expert_stats = std::nullopt; bool sp_model = false; }; struct MoeEpPlanParams { BufferPtr expert_ids; BufferPtr expert_scales; const FfnLayerParams& params; bool overlapped = false; }; struct MoeEpPlanOutput { BufferPtr all_token_indices; BufferPtr balanced_expert_ids; std::vector<size_t> input_split_sizes; std::vector<size_t> output_split_sizes; }; struct MoeBalanceOutput { BufferPtr balance_expert_ids; }; struct MoeBalanceParams { const BufferPtr experts_ids_host; const FfnLayerParams& params; }; // for deepseek decode micro batch struct MoEInsertionParams { MoEInsertionParams(const MoeDispatchOutput& dispatched_output, const FfnLayerParams& ffn_params, std::shared_ptr<MoeGateSelectOutput> gate_output, size_t origin_token_num) : dispatched_output(dispatched_output), ffn_params(ffn_params), gate_output(std::move(gate_output)), origin_token_num(origin_token_num) {} MoeDispatchOutput dispatched_output; FfnLayerParams ffn_params; std::shared_ptr<MoeGateSelectOutput> gate_output; size_t origin_token_num; }; struct MoEInsertionReturns { MoeCombineOutput combine_output; }; struct GreedyParams { const Buffer& logits; // [batch_size, vocab_size_padded] const Buffer& input_lengths; // [batch_size] const Buffer& sequence_lengths; // [batch_size] Buffer& token_ids; // [batch_size, max_input_length + 1] const size_t step; const Buffer& top_k; const Buffer& top_p; const Buffer& temperature; OptionalBufferRef random_seed; OptionalBufferRef repetition_penalty; OptionalBufferRef min_lengths; OptionalBufferRef eos_ids; OptionalBufferRef no_repeat_ngram_size; OptionalBufferRef cum_log_probs; OptionalBufferRef output_log_probs; OptionalBufferRef output_all_probs; }; struct GreedyOutput { BufferPtr success; // BufferPtr new_tokens; }; struct BeamSearchParams { const Buffer& logits; Buffer& token_ids; Buffer& input_lengths; Buffer& sequence_lengths; Buffer& cum_log_probs; Buffer& beam_index; }; struct BroadcastParams { const std::vector<BufferPtr>& buffers; const int64_t root; ParallelMode mode = ParallelMode::TP; bool overlapped = false; }; enum class ReduceOp { Sum = 0, Prod = 1, Max = 2, Min = 3, Avg = 4, }; struct PrepareAllReduceParams { const BufferPtr buffer; const ReduceOp op; ParallelMode mode = ParallelMode::TP; }; struct PrepareAllReduceOutput { const BufferPtr buffer; }; struct AllReduceParams { const BufferPtr buffer; const ReduceOp op; bool overlapped = false; ParallelMode mode = ParallelMode::TP; const BufferPtr dest = nullptr; }; struct AllReduceOutput { const BufferPtr buffer; }; struct AllGatherParams { const std::vector<BufferPtr>& recv_buffers; ParallelMode mode = ParallelMode::TP; std::vector<BufferPtr> send_buffers; bool inplace = true; bool overlapped = false; }; struct ReduceScatterParams { const BufferPtr send_buffer; const BufferPtr recv_buffer; const ReduceOp op; ParallelMode mode = ParallelMode::TP; bool overlapped = false; }; struct AllToAllParams { const std::vector<BufferPtr> buffers; const std::vector<size_t> input_split_sizes; const std::vector<size_t> output_split_sizes; bool overlapped = false; ParallelMode mode = ParallelMode::DP_AND_TP; }; struct AllToAllOutput { std::vector<BufferPtr> outputs; BufferPtr concated_input; BufferPtr output_to_split; DeviceHookPtr comm_barrier_hook; }; // output = act(input) + bias struct ActivationParams { ActivationType atype; // can be nullptr for fuse gemm with activation BufferPtr states; const OptionalConstBufferRef bias = std::nullopt; const OptionalConstBufferRef gate = std::nullopt; const OptionalConstBufferRef gate_bias = std::nullopt; const OptionalConstBufferRef act_scale = std::nullopt; BufferPtr output_buffer = nullptr; bool fuse_gate_up = false; ActivationParams(ActivationType atype, BufferPtr states, const OptionalConstBufferRef bias, const OptionalConstBufferRef gate, const OptionalConstBufferRef gate_bias, const OptionalConstBufferRef act_scale, BufferPtr output_buffer = nullptr, bool fuse_gate_up = false): atype(atype), states(states), bias(bias), gate(gate), gate_bias(gate_bias), act_scale(act_scale), output_buffer(output_buffer), fuse_gate_up(fuse_gate_up) {} ActivationParams(ActivationType atype, BufferPtr states): atype(atype), states(states), bias(std::nullopt), gate(std::nullopt), gate_bias(std::nullopt), act_scale(std::nullopt) {}; }; // softmax op is inplace-update, thus output buffer is same as input struct SoftmaxParams { const BufferPtr input; const OptionalConstBufferRef mask = std::nullopt; const OptionalConstBufferRef bias = std::nullopt; float scale = 1.0f; const DataType output_t = DataType::TYPE_INVALID; const OptionalConstBufferRef linear_bias_slopes = std::nullopt; }; struct LossParams { const Buffer& logits; const Buffer& labels; }; using LossOutput = BufferPtr; struct MaskParams { public: const Buffer& input_lengths; const Buffer& prefix_lengths; DataType dtype; bool is_causal; }; using MaskOutput = BufferPtr; struct DevicePrepParams { const AttentionConfigs& configs; const BufferPtr &prefix_lengths; const BufferPtr &sequence_lengths; const BufferPtr &input_lengths; const BufferPtr &kv_cache_block_id; const BufferPtr &kv_cache_block_id_device; DataType attn_dtype = DataType::TYPE_INVALID; size_t context_batch_size = 0; size_t decoder_batch_size = 0; bool has_kv_cache = true; bool diff_qkv_len = false; bool has_alibi_slopes = false; }; struct DevicePrepOutput { bool need_mask = true; FlashInferAttnParamsPtr decode_flash_infer_attn_params; FlashInferAttnParamsPtr prefill_flash_infer_attn_params; }; struct LoraLinearOutput { BufferPtr output; }; struct AllGatherLoraLinearOutput { BufferPtr output; BufferPtr all_gather_recv_buffer; }; struct ReduceScatterLoraLinearOutput { BufferPtr output; BufferPtr reduce_scatter_recv_buffer; }; struct LoraLinearParams { LoraLinearParams(GemmParams& gemm_params, lora::LoraOpInputPtr lora_input = nullptr) : gemm_params(gemm_params), lora_input(lora_input) {} GemmParams& gemm_params; lora::LoraOpInputPtr lora_input; }; struct LoraLinearReduceScatterParams { const LoraLinearParams& lora_linear_params; const BufferPtr& rs_recv_buffer; QScheme qscheme; DataType output_type; ParallelMode mode = ParallelMode::TP; LoraLinearReduceScatterParams(const LoraLinearParams& lora_linear_params, const BufferPtr& rs_recv_buffer, QScheme qscheme, DataType output_type, ParallelMode mode = ParallelMode::TP): lora_linear_params(lora_linear_params), rs_recv_buffer(rs_recv_buffer), qscheme(qscheme), output_type(output_type), mode(mode) {} }; struct AllGatherLoraLinearParams { const LoraLinearParams& lora_linear_params; const BufferPtr& ag_send_buffer; BufferPtr ag_recv_buffer; QScheme qscheme; DataType output_type; ParallelMode mode = ParallelMode::TP; AllGatherLoraLinearParams(const LoraLinearParams& lora_linear_params, const BufferPtr& ag_send_buffer, BufferPtr ag_recv_buffer, QScheme qscheme, DataType output_type, ParallelMode mode = ParallelMode::TP): lora_linear_params(lora_linear_params), ag_send_buffer(ag_send_buffer), ag_recv_buffer(ag_recv_buffer), qscheme(qscheme), output_type(output_type), mode(mode){} }; struct PrepareCommBufferParams { const size_t max_batch_seq_len; const size_t attn_rs_hidden; const size_t ffn_rs_hidden; const size_t attn_ag_hidden; const size_t ffn_ag_hidden; DataType rs_output_type; DataType ag_input_type; bool enable_per_token_scale = false; bool enable_ffn_tp = false; PrepareCommBufferParams(size_t max_batch_seq_len, size_t attn_rs_hidden, size_t ffn_rs_hidden, size_t attn_ag_hidden, size_t ffn_ag_hidden, DataType rs_output_type, DataType ag_input_type, bool enable_per_token_scale = false, bool enable_ffn_tp = false): max_batch_seq_len(max_batch_seq_len), attn_rs_hidden(attn_rs_hidden), ffn_rs_hidden(ffn_rs_hidden), attn_ag_hidden(attn_ag_hidden), ffn_ag_hidden(ffn_ag_hidden), rs_output_type(rs_output_type), ag_input_type(ag_input_type), enable_per_token_scale(enable_per_token_scale), enable_ffn_tp(enable_ffn_tp) {} }; struct LoraLinearWithActivationParams { const LoraLinearParams& lora_linear_params; const ActivationParams& activation_params; LoraLinearWithActivationParams(const LoraLinearParams& lora_linear_params, const ActivationParams& activation_params): lora_linear_params(lora_linear_params), activation_params(activation_params) {} }; struct QuantizeParams { const Buffer& input; DataType qtype; size_t axis; QScheme qscheme; // for soomth quantize OptionalConstBufferRef smoother; OptionalConstBufferRef shift; // for static quantize OptionalConstBufferRef static_scale; OptionalConstBufferRef static_scale_reciprocal; // for groupwise quantize int64_t groupSize; int64_t paddingSize = 0; QuantizeParams(const Buffer& input, DataType qtype, size_t axis, QScheme qscheme, OptionalConstBufferRef smoother, OptionalConstBufferRef shift, OptionalConstBufferRef static_scale, OptionalConstBufferRef static_scale_reciprocal): input(input), qtype(qtype), axis(axis), qscheme(qscheme), smoother(smoother), shift(shift), static_scale(static_scale), static_scale_reciprocal(static_scale_reciprocal), groupSize(64) {} QuantizeParams(const Buffer& input, DataType qtype, size_t axis): input(input), qtype(qtype), axis(axis), qscheme(QScheme::Qint8PerToken), groupSize(64) {} QuantizeParams(const Buffer& input, DataType qtype, size_t axis, int64_t groupSize): input(input), qtype(qtype), axis(axis), qscheme(QScheme::Qint8PerToken), groupSize(groupSize) {} QuantizeParams(const Buffer& input, DataType qtype, size_t axis, QScheme qscheme, int64_t paddingSize = 0): input(input), qtype(qtype), axis(axis), qscheme(qscheme), paddingSize(paddingSize) {} }; } // namespace rtp_llm