maga_transformer/cpp/dataclass/GenerateConfig.h:

#pragma once

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "maga_transformer/cpp/utils/StringUtil.h"
#include "maga_transformer/cpp/tokenizer/Tokenizer.h"
#include "maga_transformer/cpp/th_op/GptInitParameter.h"
#include "autil/legacy/jsonizable.h"

namespace rtp_llm {

// NOTE: The params in generate config should be split into two parts:
// 1. Params that can differ across requests for a single sampler,
//    e.g. top_k, top_p, temperature, repetition_penalty, etc.
// 2. Params that must be the same for a single sampler,
//    e.g. beam_size, max_seq_len, etc.
// For the second part, different samplers must be created for different
// param values, so such requests cannot be batched together for now.
class GenerateConfig : public autil::legacy::Jsonizable {
public:
    int global_request_id    = -1;
    int max_new_tokens       = 8192;
    int min_new_tokens       = 0;
    int num_validate_token   = 0;  // for speculative decoding validation.
    int num_beams            = 1;
    int num_return_sequences = 1;

    // Sampling params; these may differ per request within one batch.
    int   top_k              = 0;
    float top_p              = 1.0;
    float temperature        = 1.0;
    float repetition_penalty = 1.0;

    std::optional<int>         no_repeat_ngram_size;
    std::optional<int>         random_seed;
    std::optional<float>       top_p_decay;
    std::optional<float>       top_p_min;
    std::optional<int>         top_p_reset_ids;
    std::optional<std::string> task_id;

    // Adapter selection: accepts a single name or a list of names (see Jsonize).
    std::string              adapter_name = "";
    std::vector<std::string> adapter_names;

    std::vector<int>         select_tokens_id;
    std::vector<std::string> select_tokens_str;

    int  calculate_loss       = 0;
    bool return_logits        = false;
    bool return_cum_log_probs = false;
    bool return_incremental   = false;
    bool return_hidden_states = false;
    bool return_output_ids    = false;
    bool return_input_ids     = false;
    bool is_streaming         = false;
    int  timeout_ms           = -1;

    // The sp_* fields control speculative decoding.
    bool sp_edit              = false;
    bool force_disable_sp_run = false;
    bool return_all_probs     = false;
    bool return_softmax_probs = false;

    std::vector<std::vector<int>> stop_words_list;
    std::vector<std::string>      stop_words_str;
    bool                          print_stop_words = false;

    std::string      sp_advice_prompt;
    std::vector<int> sp_advice_prompt_token_ids;

    // Prefill/decode (PD) separation scheduling.
    bool can_use_pd_separation = true;
    bool pd_separation         = false;

    // Budget for "thinking" segments, terminated by end_think_token_ids.
    bool             in_think_mode       = false;
    int              max_thinking_tokens = 0;
    std::vector<int> end_think_token_ids;

    bool gen_timeline = false;

    bool top1() {
        return top_k == 1;
    }

    // Merges the model's built-in stop words into this request's stop word lists.
    void addSpecialTokens(const rtp_llm::SpecialTokens& special_tokens) {
        for (const auto& vec : special_tokens.stop_words_id_list_) {
            std::vector<int> tmpVec;
            for (int64_t val : vec) {
                tmpVec.push_back(static_cast<int>(val));
            }
            stop_words_list.push_back(tmpVec);
        }
        const auto& vec = special_tokens.stop_words_str_list_;
        stop_words_str.insert(stop_words_str.begin(), vec.begin(), vec.end());
    }

    // Encodes select_tokens_str into select_tokens_id and validates that every
    // id falls in [0, vocab_size).
    void convertSelectTokens(int vocab_size, std::shared_ptr<Tokenizer> tokenizer) {
        for (const auto& token_str : select_tokens_str) {
            auto vec = tokenizer->encode(token_str);
            select_tokens_id.insert(select_tokens_id.begin(), vec.begin(), vec.end());
        }
        auto areTokensValid = [](const std::vector<int>& select_tokens_id, int vocab_size) {
            return std::all_of(select_tokens_id.begin(), select_tokens_id.end(), [vocab_size](int token_id) {
                return token_id < vocab_size && token_id >= 0;
            });
        };
        if (!areTokensValid(select_tokens_id, vocab_size)) {
            throw std::runtime_error("select_tokens_id must be in [0, vocab_size)");
        }
    }

    std::string debugString() const {
        std::stringstream debug_string;
        debug_string << "GenerateConfig {"
                     << "max_new_tokens:" << max_new_tokens
                     << ", min_new_tokens:" << min_new_tokens
                     << ", num_beams:" << num_beams
                     << ", num_return_sequences:" << num_return_sequences
                     << ", calculate_loss:" << calculate_loss
                     << ", return_logits:" << return_logits
                     << ", return_incremental: " << return_incremental
                     << ", return_hidden_states:" << return_hidden_states
                     << ", return_output_ids:" << return_output_ids
                     << ", return_input_ids:" << return_input_ids
                     << ", is_streaming:" << is_streaming
                     << ", timeout_ms:" << timeout_ms
                     << ", top_k:" << top_k
                     << ", top_p:" << top_p
                     << ", force_disable_sp_run: " << force_disable_sp_run
                     << ", return_all_probs: " << return_all_probs
                     << ", stop_words_list:" << vectorsToString(stop_words_list)
                     << ", can_use_pd_separation: " << can_use_pd_separation
                     << ", pd_separation: " << pd_separation
                     << ", in_think_mode: " << in_think_mode
                     << ", max_thinking_tokens: " << max_thinking_tokens
                     << ", end_think_token_ids: " << vectorToString(end_think_token_ids)
                     << "}";
        return debug_string.str();
    }

    void Jsonize(autil::legacy::Jsonizable::JsonWrapper& json) override {
// The third argument is the default: a missing key keeps the current value.
#define JSONIZE(field) json.Jsonize(#field, field, field)
// Used for de-serialization of optional fields: a missing key leaves the
// optional as std::nullopt instead of assigning a default.
#define JSONIZE_OPTIONAL(field)                   \
    try {                                         \
        using Type = decltype(field)::value_type; \
        Type field##Tmp;                          \
        json.Jsonize(#field, field##Tmp);         \
        field = field##Tmp;                       \
    } catch (autil::legacy::ExceptionBase&) {     \
        if (field.has_value() == false) {         \
            field = std::nullopt;                 \
        }                                         \
    }
        JSONIZE(max_new_tokens);
        JSONIZE(min_new_tokens);
        JSONIZE(num_validate_token);
        JSONIZE(num_beams);
        JSONIZE(num_return_sequences);
        JSONIZE(top_k);
        JSONIZE(top_p);
        JSONIZE(temperature);
        JSONIZE(repetition_penalty);
        JSONIZE_OPTIONAL(no_repeat_ngram_size);
        JSONIZE_OPTIONAL(random_seed);
        JSONIZE_OPTIONAL(top_p_decay);
        JSONIZE_OPTIONAL(top_p_min);
        JSONIZE_OPTIONAL(top_p_reset_ids);
        JSONIZE_OPTIONAL(task_id);
        // "adapter_name" may be either a single string or a list of strings.
        try {
            std::string adapter_name_;
            json.Jsonize("adapter_name", adapter_name_);
            adapter_name = adapter_name_;
        } catch (autil::legacy::ExceptionBase&) {
            try {
                std::vector<std::string> adapter_names_;
                json.Jsonize("adapter_name", adapter_names_);
                adapter_names = adapter_names_;
            } catch (autil::legacy::ExceptionBase&) {
                // Neither form present; keep defaults.
            }
        }
        JSONIZE(select_tokens_id);
        JSONIZE(select_tokens_str);
        JSONIZE(calculate_loss);
        JSONIZE(return_logits);
        JSONIZE(return_incremental);
        JSONIZE(return_hidden_states);
        JSONIZE(return_output_ids);
        JSONIZE(return_input_ids);
        JSONIZE(is_streaming);
        JSONIZE(timeout_ms);
        JSONIZE(stop_words_list);
        JSONIZE(stop_words_str);
        JSONIZE(print_stop_words);
        JSONIZE(sp_edit);
        JSONIZE(force_disable_sp_run);
        JSONIZE(return_all_probs);
        JSONIZE(sp_advice_prompt);
        JSONIZE(sp_advice_prompt_token_ids);
        JSONIZE(in_think_mode);
        JSONIZE(max_thinking_tokens);
        JSONIZE(end_think_token_ids);
#undef JSONIZE
#undef JSONIZE_OPTIONAL
    }
};

}  // namespace rtp_llm
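
Usage sketch (not part of the header): deserializing a GenerateConfig from a JSON request body. This is a minimal sketch assuming autil's usual FromJsonString entry point for Jsonizable types; the keys map one-to-one onto the JSONIZE(...) calls above.

#include <iostream>
#include <string>
#include "autil/legacy/jsonizable.h"
#include "maga_transformer/cpp/dataclass/GenerateConfig.h"

int main() {
    const std::string request_json = R"({
        "max_new_tokens": 256,
        "top_k": 40,
        "top_p": 0.9,
        "temperature": 0.7,
        "random_seed": 42
    })";
    rtp_llm::GenerateConfig config;
    // FromJsonString invokes GenerateConfig::Jsonize in FROM_JSON mode: fields
    // absent from the JSON (e.g. num_beams) keep their in-class defaults, and
    // absent optionals (e.g. top_p_decay) stay std::nullopt via JSONIZE_OPTIONAL.
    autil::legacy::FromJsonString(config, request_json);
    std::cout << config.debugString() << std::endl;
    return 0;
}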
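
The NOTE at the top of the header implies a scheduling constraint: requests whose "part 2" params differ cannot share a sampler. A minimal sketch of what such grouping could look like, keyed on num_beams; groupByNumBeams is hypothetical and not part of this codebase.

#include <map>
#include <vector>
#include "maga_transformer/cpp/dataclass/GenerateConfig.h"

// Hypothetical helper: buckets requests by num_beams so each bucket can be
// served by one sampler. Per-request params (top_k, top_p, temperature,
// repetition_penalty, ...) may still differ freely inside a bucket.
std::map<int, std::vector<const rtp_llm::GenerateConfig*>>
groupByNumBeams(const std::vector<rtp_llm::GenerateConfig>& configs) {
    std::map<int, std::vector<const rtp_llm::GenerateConfig*>> buckets;
    for (const auto& cfg : configs) {
        buckets[cfg.num_beams].push_back(&cfg);
    }
    return buckets;
}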