// maga_transformer/cpp/th_op/GptInitParameter.cc

#include "maga_transformer/cpp/th_op/GptInitParameter.h" #include "maga_transformer/cpp/th_op/GptInitParameterRegister.h" namespace rtp_llm { SpecialTokens::SpecialTokens() {} GptInitParameter::GptInitParameter() {} GptInitParameter::GptInitParameter(int64_t head_num, int64_t size_per_head, int64_t num_layers, int64_t max_seq_len, int64_t vocab_size, int64_t hidden_size): head_num_(head_num), size_per_head_(size_per_head), num_layers_(num_layers), hidden_size_(hidden_size), max_seq_len_(max_seq_len), vocab_size_(vocab_size) { } void GptInitParameter::insertMultiTaskPromptTokens(std::string task_id, std::vector<int64_t> tokens_id) { std::vector<int> new_tokens_id; // to convert tokens of type int64_t to type int32_t for (auto token_id : tokens_id) { new_tokens_id.push_back(token_id); } multi_task_prompt_tokens_[task_id] = new_tokens_id; } void GptInitParameter::setLayerNormType() { layernorm_type_ = getLayerNormType(layernorm_type_str_); } void GptInitParameter::setNormType() { norm_type_ = getNormType(norm_type_str_); } void GptInitParameter::setTaskType(std::string task) { if (task == "DENSE_EMBEDDING") { task_type_ = TaskType::DENSE_EMBEDDING; } else if (task == "ALL_EMBEDDING") { task_type_ = TaskType::ALL_EMBEDDING; } else if (task == "SPARSE_EMBEDDING") { task_type_ = TaskType::SPARSE_EMBEDDING; } else if (task == "COLBERT_EMBEDDING") { task_type_ = TaskType::COLBERT_EMBEDDING; } else if (task == "LANGUAGE_MODEL") { task_type_ = TaskType::LANGUAGE_MODEL; } else if (task == "SEQ_CLASSIFICATION") { task_type_ = TaskType::SEQ_CLASSIFICATION; } else if (task == "RERANKER") { task_type_ = TaskType::RERANKER; } else if (task == "LINEAR_SOFTMAX") { task_type_ = TaskType::LINEAR_SOFTMAX; } else if (task == "BGE_M3") { task_type_ = TaskType::BGE_M3; } else { RTP_LLM_CHECK_WITH_INFO(false, "unkown task type: " + task); } } void GptInitParameter::setActivationType() { activation_type_ = getActivationType(activation_type_str_); } void GptInitParameter::setKvCacheDataType() { kv_cache_data_type_ = getDataType(kv_cache_data_type_str_); } bool GptInitParameter::isGatedActivation() const { return rtp_llm::isGatedActivation(activation_type_); } bool GptInitParameter::isKvCacheQuant() const { return kv_cache_data_type_ == DataType::TYPE_FP8_E4M3 || kv_cache_data_type_ == DataType::TYPE_INT8; } void QuantAlgo::setQuantAlgo(const std::string &quant_method, int64_t bits, int64_t group_size) { if (quant_method == "gptq") { quant_method_ = GptQ; weight_bits_ = bits; group_size_ = group_size; } else if (quant_method == "awq") { quant_method_ = Awq; weight_bits_ = bits; group_size_ = group_size; } else if (quant_method == "weight_only_per_col") { quant_method_ = WeightOnlyPerCol; weight_bits_ = bits; if (weight_bits_ != 8) { throw std::invalid_argument("invalid weight_bits: " + std::to_string(weight_bits_)); } } else if (quant_method == "smooth_quant") { quant_method_ = SmoothQuant; weight_bits_ = 8; } else if (quant_method == "omni_quant") { quant_method_ = OmniQuant; weight_bits_ = 8; } else if (quant_method == "pertensor_quant") { quant_method_ = PerTensorQuant; weight_bits_ = 8; } else if (quant_method == "fp8") { quant_method_ = FP8Quant; weight_bits_ = 8; group_size_ = group_size; } else { throw std::invalid_argument("unknown quant_method: " + quant_method); } if (weight_bits_ != 4 && weight_bits_ != 8) { throw std::invalid_argument("invalid weight_bits: " + std::to_string(weight_bits_)); } if (group_size_ != 0 && group_size_ != 64 && group_size_ != 128) { throw std::invalid_argument("invalid 
group_size: " + std::to_string(group_size_)); } } RopeConfig GptInitParameter::getRopeConfig() const { RopeConfig rope_config; rope_config.style = (RopeStyle)rotary_embedding_style_; rope_config.dim = rotary_embedding_dim_; rope_config.base = rotary_embedding_base_; rope_config.scale = rotary_embedding_scale_; rope_config.max_pos = org_embedding_max_pos_; rope_config.factor1 = rotary_factor1_; rope_config.factor2 = rotary_factor2_; rope_config.mscale = rotary_embedding_mscale_; rope_config.offset = rotary_embedding_offset_; rope_config.index_factor = position_id_len_factor_; if (rope_config.style == RopeStyle::Mrope) { rope_config.mrope_dim1 = mrope_section_[0]; rope_config.mrope_dim2 = mrope_section_[1]; rope_config.mrope_dim3 = mrope_section_[2]; } return rope_config; } // is not pd-sep bool GptInitParameter::isPDFusion() const { return !pd_separation_ && !use_cache_store_; } // is prefill in p-d sep bool GptInitParameter::isPrefillRole() const { return pd_separation_ && use_cache_store_; } // is decode in p-d sep bool GptInitParameter::isDecodeRole() const { return !pd_separation_ && use_cache_store_; } void registerGptInitParameter(py::module m) { py::enum_<MlaOpsType>(m, "MlaOpsType") .value("AUTO", MlaOpsType::AUTO) .value("MHA", MlaOpsType::MHA) .value("FLASH_INFER", MlaOpsType::FLASH_INFER) .value("FLASH_MLA", MlaOpsType::FLASH_MLA); py::enum_<EplbMode>(m, "EplbMode") .value("NONE", EplbMode::NONE) .value("STATS", EplbMode::STATS) .value("EPLB", EplbMode::EPLB) .value("ALL", EplbMode::ALL); #define DEF_PROPERTY(name) .def_readwrite(#name, &RoleSpecialTokens::name##_) #define REGISTER_PROPERTYS \ DEF_PROPERTY(token_ids) \ DEF_PROPERTY(eos_token_ids) pybind11::class_<RoleSpecialTokens>(m, "RoleSpecialTokens") .def(pybind11::init<>()) REGISTER_PROPERTYS; #undef DEF_PROPERTY #undef REGISTER_PROPERTYS #define DEF_PROPERTY(name) .def_readwrite(#name, &SpecialTokens::name##_) #define REGISTER_PROPERTYS \ DEF_PROPERTY(bos_token_id) \ DEF_PROPERTY(eos_token_id) \ DEF_PROPERTY(decoder_start_token_id) \ DEF_PROPERTY(user) \ DEF_PROPERTY(assistant) \ DEF_PROPERTY(system) \ DEF_PROPERTY(stop_words_id_list) \ DEF_PROPERTY(stop_words_str_list) \ DEF_PROPERTY(pad_token_id) pybind11::class_<SpecialTokens>(m, "SpecialTokens") .def(pybind11::init<>()) REGISTER_PROPERTYS; #undef DEF_PROPERTY #undef REGISTER_PROPERTYS pybind11::class_<QuantAlgo>(m, "QuantAlgo") .def(pybind11::init<>()) // quant_pre_scales .def("setQuantAlgo", &QuantAlgo::setQuantAlgo, py::arg("quant_method"), py::arg("bits"), py::arg("group_size")) .def("isWeightOnlyPerCol", &QuantAlgo::isWeightOnlyPerCol) .def("isGptq", &QuantAlgo::isGptq) .def("isAwq", &QuantAlgo::isAwq) .def("isSmoothQuant", &QuantAlgo::isSmoothQuant) .def("isOmniQuant", &QuantAlgo::isOmniQuant) .def("isPerTensorQuant", &QuantAlgo::isPerTensorQuant) .def("isFp8", &QuantAlgo::isFp8) .def("isQuant", &QuantAlgo::isQuant) .def("isGroupwise", &QuantAlgo::isGroupwise) .def("getGroupSize", &QuantAlgo::getGroupSize) .def("getWeightBits", &QuantAlgo::getWeightBits) .def("getActivationBits", &QuantAlgo::getActivationBits) .def(py::pickle( [](const QuantAlgo& quant_algo) { return py::make_tuple(int(quant_algo.getQuantMethod()), int(quant_algo.getWeightBits()), int(quant_algo.getGroupSize()), int(quant_algo.getActivationBits())); } , [](py::tuple t){ return QuantAlgo(QuantMethod(t[0].cast<int>()), t[1].cast<int>(), t[2].cast<int>()); })); #define DEF_PROPERTY(name, member) .def_readwrite(#name, &GptInitParameter::member) #define REGISTER_PROPERTYS \ DEF_PROPERTY(head_num, 
#define DEF_PROPERTY(name, member) .def_readwrite(#name, &GptInitParameter::member)
#define REGISTER_PROPERTYS \
    DEF_PROPERTY(head_num, head_num_) \
    DEF_PROPERTY(head_num_kv, head_num_kv_) \
    DEF_PROPERTY(size_per_head, size_per_head_) \
    DEF_PROPERTY(max_seq_len, max_seq_len_) \
    DEF_PROPERTY(vocab_size, vocab_size_) \
    DEF_PROPERTY(input_vocab_size, input_vocab_size_) \
    DEF_PROPERTY(hidden_size, hidden_size_) \
    DEF_PROPERTY(type_vocab_size, type_vocab_size_) \
    DEF_PROPERTY(embedding_size, embedding_size_) \
    DEF_PROPERTY(gen_num_per_circle, gen_num_per_circle_) \
    DEF_PROPERTY(inter_size, inter_size_) \
    DEF_PROPERTY(inter_padding_size, inter_padding_size_) \
    DEF_PROPERTY(moe_inter_padding_size, moe_inter_padding_size_) \
    DEF_PROPERTY(is_sparse_head, is_sparse_head_) \
    DEF_PROPERTY(layer_head_num, layer_head_num_) \
    DEF_PROPERTY(layer_head_num_kv, layer_head_num_kv_) \
    DEF_PROPERTY(layer_inter_size, layer_inter_size_) \
    DEF_PROPERTY(layer_inter_padding_size, layer_inter_padding_size_) \
    DEF_PROPERTY(num_layers, num_layers_) \
    DEF_PROPERTY(layer_num, num_layers_) \
    DEF_PROPERTY(num_valid_layer, num_valid_layer_) \
    DEF_PROPERTY(expert_num, expert_num_) \
    DEF_PROPERTY(moe_k, moe_k_) \
    DEF_PROPERTY(moe_normalize_expert_scale, moe_normalize_expert_scale_) \
    DEF_PROPERTY(moe_style, moe_style_) \
    DEF_PROPERTY(moe_layer_index, moe_layer_index_) \
    DEF_PROPERTY(scoring_func, scoring_func_) \
    DEF_PROPERTY(layernorm_eps, layernorm_eps_) \
    /* In Python the following fields are plain strings used for branch conditions. */ \
    /* Every time one of these strings changes, the corresponding set*Type()        */ \
    /* function must be called to refresh the parsed value.                         */ \
    DEF_PROPERTY(layernorm_type, layernorm_type_str_) \
    DEF_PROPERTY(norm_type, norm_type_str_) \
    DEF_PROPERTY(activation_type, activation_type_str_) \
    DEF_PROPERTY(rotary_embedding_dim, rotary_embedding_dim_) \
    DEF_PROPERTY(kv_cache_data_type, kv_cache_data_type_str_) \
    DEF_PROPERTY(rotary_embedding_style, rotary_embedding_style_) \
    DEF_PROPERTY(position_ids_style, position_ids_style_) \
    DEF_PROPERTY(position_id_len_factor, position_id_len_factor_) \
    DEF_PROPERTY(rotary_embedding_base, rotary_embedding_base_) \
    DEF_PROPERTY(rotary_embedding_scale, rotary_embedding_scale_) \
    DEF_PROPERTY(org_embedding_max_pos, org_embedding_max_pos_) \
    DEF_PROPERTY(rotary_factor1, rotary_factor1_) \
    DEF_PROPERTY(rotary_factor2, rotary_factor2_) \
    DEF_PROPERTY(mrope_section, mrope_section_) \
    DEF_PROPERTY(input_embedding_scalar, input_embedding_scalar_) \
    DEF_PROPERTY(residual_scalar, residual_scalar_) \
    DEF_PROPERTY(use_norm_input_residual, use_norm_input_residual_) \
    DEF_PROPERTY(use_norm_attn_out_residual, use_norm_attn_out_residual_) \
    DEF_PROPERTY(data_type, data_type_) \
    DEF_PROPERTY(has_positional_encoding, has_positional_encoding_) \
    DEF_PROPERTY(has_pre_decoder_layernorm, has_pre_decoder_layernorm_) \
    DEF_PROPERTY(has_post_decoder_layernorm, has_post_decoder_layernorm_) \
    DEF_PROPERTY(has_moe_norm, has_moe_norm_) \
    DEF_PROPERTY(logit_scale, logit_scale_) \
    DEF_PROPERTY(has_lm_head, has_lm_head_) \
    DEF_PROPERTY(use_attention_linear_bias, use_attention_linear_bias_) \
    DEF_PROPERTY(use_fp32_to_compute_logit, use_fp32_to_compute_logit_) \
    DEF_PROPERTY(add_bias_linear, add_bias_linear_) \
    DEF_PROPERTY(tokenizer_path, tokenizer_path_) \
    DEF_PROPERTY(ckpt_path, ckpt_path_) \
    DEF_PROPERTY(pre_seq_len, pre_seq_len_) \
    DEF_PROPERTY(prefix_projection, prefix_projection_) \
    DEF_PROPERTY(using_hf_sampling, using_hf_sampling_) \
    DEF_PROPERTY(max_generate_batch_size, max_generate_batch_size_) \
    DEF_PROPERTY(max_context_batch_size, max_context_batch_size_) \
    DEF_PROPERTY(special_tokens, special_tokens_) \
    DEF_PROPERTY(quant_algo, quant_algo_) \
    DEF_PROPERTY(use_logn_attn, use_logn_attn_) \
    DEF_PROPERTY(q_scaling, q_scaling_) \
    DEF_PROPERTY(qk_norm, qk_norm_) \
    DEF_PROPERTY(use_cross_attn, use_cross_attn_) \
    DEF_PROPERTY(cross_attn_input_len, cross_attn_input_len_) \
    DEF_PROPERTY(is_multimodal, is_multimodal_) \
    DEF_PROPERTY(mm_sep_tokens, mm_sep_tokens_) \
    DEF_PROPERTY(include_sep_tokens, include_sep_tokens_) \
    DEF_PROPERTY(mm_position_ids_style, mm_position_ids_style_) \
    DEF_PROPERTY(pre_allocate_op_mem, pre_allocate_op_mem_) \
    DEF_PROPERTY(seq_size_per_block, seq_size_per_block_) \
    DEF_PROPERTY(block_nums, block_nums_) \
    DEF_PROPERTY(scheduler_reserve_resource_ratio, scheduler_reserve_resource_ratio_) \
    DEF_PROPERTY(kv_cache_mem_mb, kv_cache_mem_mb_) \
    DEF_PROPERTY(reserve_runtime_mem_mb, reserve_runtime_mem_mb_) \
    DEF_PROPERTY(reuse_cache, reuse_cache_) \
    DEF_PROPERTY(enable_partial_fallback, enable_partial_fallback_) \
    DEF_PROPERTY(enable_fast_gen, enable_fast_gen_) \
    DEF_PROPERTY(warm_up, warm_up_) \
    DEF_PROPERTY(warm_up_with_loss, warm_up_with_loss_) \
    DEF_PROPERTY(fast_gen_max_context_len, fast_gen_max_context_len_) \
    DEF_PROPERTY(is_causal, is_causal_) \
    DEF_PROPERTY(nccl_ip, nccl_ip_) \
    DEF_PROPERTY(tp_nccl_port, tp_nccl_port_) \
    DEF_PROPERTY(dp_tp_nccl_port, dp_tp_nccl_port_) \
    DEF_PROPERTY(ffn_tp_nccl_port, ffn_tp_nccl_port_) \
    DEF_PROPERTY(model_rpc_port, model_rpc_port_) \
    DEF_PROPERTY(http_port, http_port_) \
    DEF_PROPERTY(tp_size, tp_size_) \
    DEF_PROPERTY(tp_rank, tp_rank_) \
    DEF_PROPERTY(dp_size, dp_size_) \
    DEF_PROPERTY(dp_rank, dp_rank_) \
    DEF_PROPERTY(ffn_tp_size, ffn_tp_size_) \
    DEF_PROPERTY(ffn_tp_rank, ffn_tp_rank_) \
    DEF_PROPERTY(enable_sp, enable_sp_) \
    DEF_PROPERTY(world_size, world_size_) \
    DEF_PROPERTY(cache_store_listen_port, cache_store_listen_port_) \
    DEF_PROPERTY(cache_store_connect_port, cache_store_connect_port_) \
    DEF_PROPERTY(cache_store_rdma_connect_port, cache_store_rdma_connect_port_) \
    DEF_PROPERTY(cache_store_rdma_listen_port, cache_store_rdma_listen_port_) \
    DEF_PROPERTY(worker_port_offset, worker_port_offset_) \
    DEF_PROPERTY(worker_addrs, worker_addrs_) \
    DEF_PROPERTY(worker_grpc_addrs, worker_grpc_addrs_) \
    DEF_PROPERTY(remote_rpc_server_port, remote_rpc_server_port_) \
    DEF_PROPERTY(pd_separation, pd_separation_) \
    DEF_PROPERTY(use_cache_store, use_cache_store_) \
    DEF_PROPERTY(cache_store_rdma_mode, cache_store_rdma_mode_) \
    DEF_PROPERTY(prefill_retry_times, prefill_retry_times_) \
    DEF_PROPERTY(prefill_retry_timeout_ms, prefill_retry_timeout_ms_) \
    DEF_PROPERTY(prefill_max_wait_timeout_ms, prefill_max_wait_timeout_ms_) \
    DEF_PROPERTY(decode_retry_times, decode_retry_times_) \
    DEF_PROPERTY(decode_retry_timeout_ms, decode_retry_timeout_ms_) \
    DEF_PROPERTY(decode_polling_kv_cache_step_ms, decode_polling_kv_cache_step_ms_) \
    DEF_PROPERTY(decode_use_async_load_cache, decode_use_async_load_cache_) \
    DEF_PROPERTY(rdma_connect_retry_times, rdma_connect_retry_times_) \
    DEF_PROPERTY(pd_sep_enable_fallback, pd_sep_enable_fallback_) \
    DEF_PROPERTY(load_balance_policy_name, load_balance_policy_name_) \
    DEF_PROPERTY(sync_status_interval_ms, sync_status_interval_ms_) \
    DEF_PROPERTY(load_cache_timeout_ms, load_cache_timeout_ms_) \
    DEF_PROPERTY(max_rpc_timeout_ms, max_rpc_timeout_ms_) \
    DEF_PROPERTY(ep_size, ep_size_) \
    DEF_PROPERTY(ep_rank, ep_rank_) \
    DEF_PROPERTY(use_kvcache, use_kvcache_) \
    DEF_PROPERTY(use_expert_attention, use_expert_attention_) \
    DEF_PROPERTY(local_rank, local_rank_) \
    DEF_PROPERTY(rotary_embedding_mscale, rotary_embedding_mscale_) \
    DEF_PROPERTY(rotary_embedding_offset, rotary_embedding_offset_) \
    DEF_PROPERTY(use_mla, use_mla_) \
    DEF_PROPERTY(mla_ops_type, mla_ops_type_) \
    DEF_PROPERTY(q_lora_rank, q_lora_rank_) \
    DEF_PROPERTY(kv_lora_rank, kv_lora_rank_) \
    DEF_PROPERTY(nope_head_dim, nope_head_dim_) \
    DEF_PROPERTY(rope_head_dim, rope_head_dim_) \
    DEF_PROPERTY(v_head_dim, v_head_dim_) \
    DEF_PROPERTY(moe_n_group, moe_n_group_) \
    DEF_PROPERTY(moe_topk_group, moe_topk_group_) \
    DEF_PROPERTY(softmax_extra_scale, softmax_extra_scale_) \
    DEF_PROPERTY(vit_separation, vit_separation_) \
    DEF_PROPERTY(enable_speculative_decoding, enable_speculative_decoding_) \
    DEF_PROPERTY(model_name, model_name_) \
    DEF_PROPERTY(deepseek_rope_mscale, deepseek_rope_mscale_) \
    DEF_PROPERTY(deepseek_mscale_all_dim, deepseek_mscale_all_dim_) \
    DEF_PROPERTY(reverse_e_h_norm, reverse_e_h_norm_) \
    DEF_PROPERTY(enable_eplb, enable_eplb_) \
    DEF_PROPERTY(phy_exp_num, phy_exp_num_) \
    DEF_PROPERTY(eplb_update_time, eplb_update_time_) \
    DEF_PROPERTY(eplb_mode, eplb_mode_) \
    DEF_PROPERTY(py_eplb, py_eplb_)

    pybind11::class_<GptInitParameter>(m, "GptInitParameter")
        .def(pybind11::init<int64_t,  // head_num
                            int64_t,  // size_per_head
                            int64_t,  // num_layers
                            int64_t,  // max_seq_len
                            int64_t,  // vocab_size
                            int64_t   // hidden_size
                            >(),
             py::arg("head_num"),
             py::arg("size_per_head"),
             py::arg("num_layers"),
             py::arg("max_seq_len"),
             py::arg("vocab_size"),
             py::arg("hidden_size"))
        .def("insertMultiTaskPromptTokens",
             &GptInitParameter::insertMultiTaskPromptTokens,
             py::arg("task_id"),
             py::arg("tokens_id"))
        .def("setLayerNormType", &GptInitParameter::setLayerNormType)
        .def("setNormType", &GptInitParameter::setNormType)
        .def("setActivationType", &GptInitParameter::setActivationType)
        .def("setTaskType", &GptInitParameter::setTaskType, py::arg("task"))
        .def("setKvCacheDataType", &GptInitParameter::setKvCacheDataType)
        .def("isGatedActivation", &GptInitParameter::isGatedActivation)
        .def("isKvCacheQuant", &GptInitParameter::isKvCacheQuant)
        REGISTER_PROPERTYS;
}

}  // namespace rtp_llm
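// Illustrative only (not part of the original source): a rough sketch of the
// expected Python-side flow once registerGptInitParameter has been called on the
// target module. The numeric values are assumptions picked for illustration.
//
//   params = GptInitParameter(head_num=32, size_per_head=128, num_layers=24,
//                             max_seq_len=2048, vocab_size=32000, hidden_size=4096)
//   params.setTaskType("LANGUAGE_MODEL")  # one of the task strings handled in setTaskType above
//   params.norm_type = "rmsnorm"          # example string; accepted names come from getNormType()
//   params.setNormType()                  # must be re-run whenever norm_type changes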