# maga_transformer/config/generate_config.py
import copy
import hashlib
import os
from pydantic import BaseModel
from typing import Any, Dict, List, Optional, Union
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from maga_transformer.utils.util import check_with_info
from maga_transformer.utils.check_util import (check_optional, is_list_positive_integer,
                                               is_list_positive_integer_list, is_positive_integer,
                                               is_union_positive_integer, is_union_positive_number)
from maga_transformer.config.exceptions import FtRuntimeException, ExceptionType


class RequestFormat:
    RAW = 'raw'
    CHAT_API = 'chatapi'


class GenerateConfig(BaseModel):
max_new_tokens: int = 1000
    # used only by the Qwen agent function-call path to check the max input token count
max_input_tokens: int = 32000
max_thinking_tokens: int = 32000
end_think_token_ids: List[int] = []
in_think_mode: bool = False
num_beams: int = 1
    # 0 means num_return_sequences is disabled; when it is enabled,
    # the output format of the results differs from the default.
num_return_sequences: int = 0
top_k: Union[List[int], int] = 0
top_p: Union[List[float], float] = 1.0
temperature: Union[List[float], float] = 1.0
repetition_penalty: Union[List[float], float] = 1.0
min_new_tokens: Union[List[int], int] = 0
no_repeat_ngram_size: Optional[Union[List[int], int]] = None
random_seed: Optional[Union[List[int], int]] = None
top_p_decay: Optional[Union[List[float], float]] = None
top_p_min: Optional[Union[List[float], float]] = None
top_p_reset_ids: Optional[Union[List[int],int]] = None
stop_words_str: List[str] = []
stop_words_list: List[List[int]] = []
bad_words_list: Optional[Union[List[List[List[int]]], List[List[int]]]] = None
eos_token_id: Optional[Union[List[int],int]] = None
pad_token_id: Optional[Union[List[int],int]] = None
bos_token_id: Optional[Union[List[int],int]] = None
using_hf_sampling: bool = False
print_stop_words: bool = False
timeout_ms: Optional[int] = -1
chat_id: Optional[str] = None
task_id: Optional[Union[str, int]] = None
request_format: str = RequestFormat.RAW
    # calculate_loss mode: 0 = do not calculate; 1 = sum over tokens; 2 = per token
calculate_loss: int = 0
return_logits: bool = False
return_incremental: bool = False
return_hidden_states: bool = False
select_tokens_str: List[str] = []
select_tokens_id: List[int] = []
return_input_ids: bool = False
return_output_ids: bool = False
md5_value: str = ""
custom_prop: str = "{}"
sp_advice_prompt: str = ""
sp_advice_prompt_token_ids: List[int] = []
sp_edit: bool = False
force_disable_sp_run: bool = False
return_cum_log_probs: bool = False
return_all_probs: bool = False
return_softmax_probs: bool = False
can_use_pd_separation: bool = True
gen_timeline: bool = False
# lora
adapter_name: Optional[Union[str, List[str]]] = None
is_streaming: bool = False
    # Whether tool_call-specific tags such as <tool_call> may be passed through as
    # content, to improve the user experience when tool_call parsing fails.
    tool_call_message_extract_strategy: str = "default"  # default / skip_on_failure
global_request_id: int = -1

    def gen_hash_value(self):
        # zero out volatile fields so that functionally identical configs hash equal
        cp = copy.copy(self)
        cp.max_new_tokens = 0
        cp.chat_id = None
        cp.random_seed = None
        cp.md5_value = ""
        cp.timeout_ms = -1
        self.md5_value = hashlib.md5(str(cp).encode()).hexdigest()

    def is_same(self, config: 'GenerateConfig') -> bool:
return self.md5_value == config.md5_value
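
    # A minimal usage sketch (hypothetical, not part of the original module): two
    # configs that differ only in the volatile fields zeroed out by gen_hash_value()
    # should compare equal via is_same().
    #
    #   cfg_a = GenerateConfig(top_k=5, random_seed=1, chat_id="a")
    #   cfg_b = GenerateConfig(top_k=5, random_seed=2, chat_id="b")
    #   cfg_a.gen_hash_value()
    #   cfg_b.gen_hash_value()
    #   assert cfg_a.is_same(cfg_b)  # seed / chat_id are excluded from the hash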

    def update(self, new: Dict[str, Any]):
for key, value in new.items():
if hasattr(self, key):
setattr(self, key, value)

    def update_and_pop(self, new: Dict[str, Any]) -> Dict[str, Any]:
to_remove: List[str] = []
for key, value in new.items():
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)
return {k: v for k, v in new.items() if k not in to_remove}
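
    # A hypothetical usage sketch (not in the original module): known keys are
    # applied to the config in place, unknown keys are returned to the caller.
    #
    #   cfg = GenerateConfig()
    #   rest = cfg.update_and_pop({"top_k": 8, "not_a_field": 1})
    #   assert cfg.top_k == 8 and rest == {"not_a_field": 1}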

    @staticmethod
def create_generate_config(generate_config: Dict[str, Any], **kwargs: Any) -> 'GenerateConfig':
generate_config.update(kwargs)
try:
config = GenerateConfig(**generate_config)
        except Exception as e:
            raise FtRuntimeException(ExceptionType.ERROR_GENERATE_CONFIG_FORMAT,
                                     f"generate_config validate failed: {str(e)}") from e
config.validate()
return config
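
    # A hypothetical construction sketch (not in the original module): build a
    # validated config from a request dict, with kwargs taking precedence.
    #
    #   cfg = GenerateConfig.create_generate_config(
    #       {"top_k": 1, "max_new_tokens": 128}, top_k=4)
    #   assert cfg.top_k == 4  # kwargs override the request dict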

    def convert_select_tokens(self, vocab_size: int, tokenizer):
        for token_str in self.select_tokens_str:
            self.select_tokens_id += tokenizer.encode(token_str)
        if not all(0 <= token_id < vocab_size for token_id in self.select_tokens_id):
            raise FtRuntimeException(
                ExceptionType.ERROR_INPUT_FORMAT_ERROR,
                f"token_id in select_tokens_id {self.select_tokens_id} should be "
                f"non-negative and less than vocab_size {vocab_size}")

    def add_special_tokens(self, special_tokens: Any):
        # assumes the stop_words_list / stop_words_str passed in carry no batch dimension
        self.stop_words_list += special_tokens.stop_words_id_list
        self.stop_words_str += special_tokens.stop_words_str_list

    def add_thinking_params(self, tokenizer):
        think_mode = bool(int(os.environ.get("THINK_MODE", "0")))
        end_think_token_id = int(os.environ.get("THINK_END_TOKEN_ID", "-1"))
        self.end_think_token_ids = [end_think_token_id] if end_think_token_id != -1 else []
        if think_mode and tokenizer and end_think_token_id == -1:
            # allow escape sequences such as "\n" inside the env var
            think_end_tag: str = os.environ.get("THINK_END_TAG", "</think>\n\n").encode('utf-8').decode('unicode_escape')
            if isinstance(tokenizer, PreTrainedTokenizerBase):
                tokenized_result: List[int] = tokenizer.encode(think_end_tag, add_special_tokens=False)
            else:
                tokenized_result = tokenizer.encode(think_end_tag)
            self.end_think_token_ids = tokenized_result
        self.in_think_mode = think_mode and len(self.end_think_token_ids) > 0
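
    # A hypothetical environment-driven sketch (not in the original module): with
    # THINK_MODE=1 and THINK_END_TOKEN_ID set, no tokenizer call is needed.
    #
    #   os.environ["THINK_MODE"] = "1"
    #   os.environ["THINK_END_TOKEN_ID"] = "151649"   # value is illustrative only
    #   cfg = GenerateConfig()
    #   cfg.add_thinking_params(tokenizer=None)
    #   assert cfg.in_think_mode and cfg.end_think_token_ids == [151649]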

    def add_stop_ids_from_str(self, tokenizer):
        if tokenizer is None:
            return
        ids_list = []
        for word in self.stop_words_str:
            if isinstance(tokenizer, PreTrainedTokenizerBase):
                token_id = tokenizer.convert_tokens_to_ids(word)
                if isinstance(token_id, int):
                    ids_list.append([token_id])
                elif isinstance(token_id, list):
                    ids_list.append(token_id)
                else:
                    ids_list.append(tokenizer.encode(word, add_special_tokens=True))
            else:
                ids_list.append(tokenizer.encode(word))
        # skip entries that are already present
        for item in ids_list:
            if item not in self.stop_words_list:
                self.stop_words_list.append(item)
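
    # A hypothetical sketch with a toy tokenizer (not in the original module):
    # stop word strings are tokenized and merged into stop_words_list without
    # duplicating entries that are already present.
    #
    #   class ToyTokenizer:
    #       def encode(self, s):
    #           return [ord(c) % 100 for c in s]
    #
    #   cfg = GenerateConfig(stop_words_str=["ab"], stop_words_list=[[97, 98]])
    #   cfg.add_stop_ids_from_str(ToyTokenizer())
    #   assert cfg.stop_words_list == [[97, 98]]  # duplicate is skipped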

    def validate(self):
        try:
            check_with_info(is_union_positive_integer(self.top_k),
                            f"top_k {self.top_k} has wrong data type")
            check_with_info(is_union_positive_number(self.top_p),
                            f"top_p {self.top_p} has wrong data type")
            check_with_info(is_union_positive_integer(self.min_new_tokens),
                            f"min_new_tokens {self.min_new_tokens} has wrong data type")
            check_with_info(is_union_positive_number(self.repetition_penalty),
                            f"repetition_penalty {self.repetition_penalty} has wrong data type")
            check_with_info(is_positive_integer(self.max_new_tokens),
                            f"max_new_tokens {self.max_new_tokens} has wrong data type")
            check_with_info(is_positive_integer(self.num_beams),
                            f"num_beams {self.num_beams} has wrong data type")
            check_with_info(is_positive_integer(self.num_return_sequences),
                            f"num_return_sequences {self.num_return_sequences} has wrong data type")
            check_with_info(is_union_positive_number(self.temperature),
                            f"temperature {self.temperature} has wrong data type")
            check_with_info(check_optional(is_union_positive_integer, self.no_repeat_ngram_size),
                            f"no_repeat_ngram_size {self.no_repeat_ngram_size} has wrong data type")
            check_with_info(check_optional(is_union_positive_integer, self.random_seed),
                            f"random_seed {self.random_seed} has wrong data type")
            check_with_info(check_optional(is_union_positive_number, self.top_p_decay),
                            f"top_p_decay {self.top_p_decay} has wrong data type")
            check_with_info(check_optional(is_union_positive_number, self.top_p_min),
                            f"top_p_min {self.top_p_min} has wrong data type")
            check_with_info(check_optional(is_union_positive_integer, self.top_p_reset_ids),
                            f"top_p_reset_ids {self.top_p_reset_ids} has wrong data type")
            check_with_info(check_optional(is_union_positive_integer, self.eos_token_id),
                            f"eos_token_id {self.eos_token_id} has wrong data type")
            check_with_info(check_optional(is_union_positive_integer, self.pad_token_id),
                            f"pad_token_id {self.pad_token_id} has wrong data type")
            check_with_info(check_optional(is_union_positive_integer, self.bos_token_id),
                            f"bos_token_id {self.bos_token_id} has wrong data type")
            check_with_info(is_list_positive_integer_list(self.stop_words_list),
                            f"stop_words_list {self.stop_words_list} has wrong data type")
            check_with_info(is_union_positive_integer(self.sp_advice_prompt_token_ids),
                            f"sp_advice_prompt_token_ids {self.sp_advice_prompt_token_ids} has wrong data type")
            if self.in_think_mode:
                check_with_info(is_positive_integer(self.max_thinking_tokens),
                                f"max_thinking_tokens {self.max_thinking_tokens} has wrong data type")
                check_with_info(is_list_positive_integer(self.end_think_token_ids),
                                f"end_think_token_ids {self.end_think_token_ids} has wrong data type")
            calculate_loss_list = [0, 1, 2]
            check_with_info(self.calculate_loss in calculate_loss_list,
                            f"calculate_loss in generate_config can only be in {calculate_loss_list}, "
                            f"but it's {self.calculate_loss}")
        except Exception as e:
            raise FtRuntimeException(ExceptionType.ERROR_INPUT_FORMAT_ERROR, str(e)) from e
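

# A hypothetical sketch (not in the original module): validate() wraps any check
# failure in an FtRuntimeException carrying ERROR_INPUT_FORMAT_ERROR.
#
#   cfg = GenerateConfig(calculate_loss=3)
#   try:
#       cfg.validate()
#   except FtRuntimeException:
#       pass  # 3 is outside the allowed [0, 1, 2]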