in maga_transformer/config/gpt_init_model_parameters.py [0:0]
def update_common(self,
                  ckpt_path: str,
                  lora_infos: Optional[Dict[str, str]],
                  ptuning_path: Optional[str],
                  tokenizer_path: str,
                  int8_mode: bool,
                  data_type: WEIGHT_TYPE,
                  max_seq_len: int,
                  seq_size_per_block: int,
                  gen_num_per_circle: int,
                  ref_module: Optional[torch.nn.Module] = None,
                  ref_dict: Optional[Dict[str, torch.Tensor]] = None,
                  parallel_info: ParallelInfo = g_parallel_info,
                  config_mode: ConfigMode = ConfigMode.ComplexMode,
                  gang_info: Optional[GangInfo] = None):
    """Populate common runtime configuration from arguments and environment variables.

    Copies parallelism layout from ``parallel_info``, records checkpoint / tokenizer
    paths, derives quantization and KV-cache data types, and then (unless
    ``config_mode`` is ``SimpleMode``) reads a large set of scheduling, PD-separation
    and cache-store knobs from the process environment, logging each resolved value.

    Args:
        ckpt_path: checkpoint directory used for weight/style/sparse-config updates.
        lora_infos: optional mapping of LoRA adapter name -> path.
        ptuning_path: optional prompt-tuning artifact path.
        tokenizer_path: tokenizer location, stored verbatim.
        int8_mode: when True and no quant algo is configured yet, enables
            weight-only per-column int8 quantization.
        data_type: model weight dtype; stored as its string form.
        max_seq_len: overrides ``self.max_seq_len`` when non-zero; values < 1
            fall back to 1024.
        seq_size_per_block: lower bound for the KV-cache block size (rounded to 2^n).
        gen_num_per_circle: tokens generated per scheduling circle, stored verbatim.
        ref_module: optional reference module for weight comparison.
        ref_dict: optional reference tensors; defaults to an empty dict
            (``None`` sentinel avoids the shared mutable-default pitfall).
        parallel_info: source of tp/ep/dp/ffn ranks and sizes.
        config_mode: ``SimpleMode`` stops after the lightweight updates.
        gang_info: when given, supplies ``num_nodes``; otherwise queried
            best-effort with a fallback of 1.
    """
    # --- parallelism layout -------------------------------------------------
    self.tp_size = parallel_info.tp_size
    self.tp_rank = parallel_info.tp_rank
    self.ep_size = parallel_info.ep_size
    self.ep_rank = parallel_info.ep_rank
    self.dp_size = parallel_info.dp_size
    self.dp_rank = parallel_info.dp_rank
    self.ffn_tp_rank = parallel_info.ffn_tp_rank
    self.ffn_tp_size = parallel_info.ffn_tp_size
    self.enable_sp = parallel_info.ffn_sp_size > 1
    self.local_rank = parallel_info.local_rank

    # --- expert-parallel load balancing (EPLB) ------------------------------
    self.eplb_update_time = int(os.environ.get("EPLB_UPDATE_TIME", 5000))
    self.eplb_mode = EplbMode.__members__[os.environ.get('EPLB_MODE', 'NONE')]
    # physical experts = logical experts + redundant replicas
    self.phy_exp_num = int(os.environ.get("REDUNDANT_EXPERT", 0)) + self.expert_num
    self.enable_merge_w13 = os.getenv('ENABLE_MERGE_W13', '0').lower() == '1'
    logging.info(f"phy_exp_num: {self.phy_exp_num}, use merge w13: {self.enable_merge_w13}")

    if gang_info is not None:
        self.num_nodes = gang_info.num_nodes
    else:
        # Best-effort: gang info may be unavailable in single-node setups.
        try:
            self.num_nodes = get_gang_info().num_nodes
        except Exception:
            self.num_nodes = 1

    # --- paths, quantization, dtypes ----------------------------------------
    self.ckpt_path = ckpt_path
    self.lora_infos = lora_infos
    self.tokenizer_path = tokenizer_path
    # int8_mode only applies when no quantization algo was configured upstream.
    if not self.quant_algo.isQuant() and int8_mode:
        self.quant_algo.setQuantAlgo("weight_only_per_col", 8, 0)
    self.data_type = data_type.to_str()
    self.gen_num_per_circle = gen_num_per_circle
    self.ptuning_path = ptuning_path
    self.ref_module = ref_module
    self.ref_dict = ref_dict if ref_dict is not None else {}

    if max_seq_len != 0:
        self.max_seq_len = max_seq_len
    if self.max_seq_len < 1:
        self.max_seq_len = 1024
    logging.info(f'max_seq_len: {self.max_seq_len}')

    self.update_task_type_use_kvcache()

    logging.info(f"config_mode = {config_mode}")
    if config_mode == ConfigMode.SimpleMode:
        # SimpleMode skips the heavyweight env-driven configuration below.
        return

    self.update_worker_addrs()
    self.update_config_with_sparse_config(ckpt_path)
    self.update_inter_padding_size(self.tp_size, self.ep_size, self.dp_size)
    self.update_task_prompt_config()
    self.update_weight_style(ckpt_path)
    load_cutlass_gemm_config(self.quant_algo)

    # Debug hook: forcibly shrink the model depth.
    hack_layer_num = int(os.environ.get('HACK_LAYER_NUM', 0))
    if (hack_layer_num):
        logging.info(f"hack layernum to {hack_layer_num}")
        self.layer_num = hack_layer_num

    # KV-cache block size must be a power of two; env override wins.
    self.seq_size_per_block = closest_power_of_2(int(max(seq_size_per_block, self.max_seq_len // 128))) # must be 2^n
    self.seq_size_per_block = int(os.environ.get('SEQ_SIZE_PER_BLOCK', self.seq_size_per_block))
    logging.info(f'seq_size_per_block: {self.seq_size_per_block}')

    # --- scheduling / batching knobs ----------------------------------------
    self.max_generate_batch_size = int(os.environ.get('CONCURRENCY_LIMIT', 128))
    logging.info(f'max_generate_batch_size: {self.max_generate_batch_size}')
    self.max_context_batch_size = int(os.environ.get('MAX_CONTEXT_BATCH_SIZE', 1))
    logging.info(f'max_context_batch_size: {self.max_context_batch_size}')
    self.reserve_runtime_mem_mb = int(os.environ.get('RESERVER_RUNTIME_MEM_MB', 128))
    logging.info(f'reserve_runtime_mem_mb: {self.reserve_runtime_mem_mb}')
    self.kv_cache_mem_mb = int(os.environ.get('KV_CACHE_MEM_MB', -1))
    logging.info(f'kv_cache_mem_mb: {self.kv_cache_mem_mb}')
    self.block_nums = int(os.environ.get('TEST_BLOCK_NUM', 0))
    logging.info(f'block_nums: {self.block_nums}')
    if os.environ.get('TEST_LAYER_NUM'):
        logging.info(f'replace model layer with TEST_LAYER_NUM: {os.environ.get("TEST_LAYER_NUM")}')
        self.layer_num = int(os.environ.get('TEST_LAYER_NUM', self.layer_num))
    self.enable_partial_fallback = bool(int(os.environ.get('ENABLE_PARTIAL_FALLBACK', 0)))
    logging.info(f'enable_partial_fallback: {self.enable_partial_fallback}')
    self.enable_fast_gen = bool(int(os.environ.get('ENABLE_FAST_GEN', 0)))
    logging.info(f'enable_fast_gen: {self.enable_fast_gen}')
    self.warm_up = bool(int(os.environ.get('WARM_UP', 1)))
    logging.info(f'warm_up: {self.warm_up}')
    self.warm_up_with_loss = bool(int(os.environ.get('WARM_UP_WITH_LOSS', 0)))
    logging.info(f'warm_up_with_loss: {self.warm_up_with_loss}')
    self.vit_separation = int(os.environ.get('VIT_SEPARATION', 0))
    logging.info(f'vit_separation: {self.vit_separation}')
    self.fast_gen_max_context_len = int(os.environ.get('FAST_GEN_MAX_CONTEXT_LEN', 1024))
    logging.info(f'fast_gen_max_context_len: {self.fast_gen_max_context_len}')
    self.max_rpc_timeout_ms = int(os.environ.get('MAX_RPC_TIMEOUT_MS', 0))
    logging.info(f'max_rpc_timeout_ms: {self.max_rpc_timeout_ms}')

    # --- prefill/decode (PD) separation -------------------------------------
    self.pd_separation = bool(int(os.environ.get('PD_SEPARATION', 0)))
    logging.info(f'pd_separation: {self.pd_separation}')
    if self.pd_separation:
        self.prefill_retry_times = int(os.environ.get('PREFILL_RETRY_TIMES', 0))
        logging.info(f'prefill_retry_times: {self.prefill_retry_times}')
        self.prefill_retry_timeout_ms = int(os.environ.get('PREFILL_RETRY_TIMEOUT_MS', 0))
        logging.info(f'prefill_retry_timeout_ms: {self.prefill_retry_timeout_ms}')
        # NOTE(review): env var is suffixed _US but the attribute is _ms and the
        # default (600 * 1000 * 1000) reads as microseconds — confirm the unit.
        self.prefill_max_wait_timeout_ms = int(os.environ.get('PREFILL_MAX_WAIT_TIMEOUT_US', 600 * 1000 * 1000))
        logging.info(f'prefill_max_wait_timeout_ms: {self.prefill_max_wait_timeout_ms}')
        self.pd_sep_enable_fallback = bool(int(os.environ.get('PD_SEP_ENABLE_FALLBACK', 0)))
        logging.info(f'pd_sep_enable_fallback: {self.pd_sep_enable_fallback}')
        self.load_balance_policy_name = os.environ.get('LOAD_BALANCE_POLICY_NAME', "RR")
        logging.info(f'load_balance_policy_name: {self.load_balance_policy_name}')
        policy_list = ["RR", "WRR"]
        if not self.load_balance_policy_name in policy_list:
            raise Exception(f"load_balance_policy_name {self.load_balance_policy_name} " \
                f"is not right, it must in {policy_list}")
        self.sync_status_interval_ms = int(os.environ.get('SYNC_STATUS_INTERVAL_MS', 50))
        logging.info(f'sync_status_interval_ms: {self.sync_status_interval_ms}')

    # --- cache store ---------------------------------------------------------
    self.use_cache_store = bool(int(os.environ.get('USE_CACHE_STORE', 0)))
    logging.info(f'use_cache_store: {self.use_cache_store}')
    if self.use_cache_store:
        self.cache_store_rdma_mode = bool(int(os.environ.get('CACHE_STORE_RDMA_MODE', 1)))
        logging.info(f'cache_store_rdma_mode: {self.cache_store_rdma_mode}')
        self.load_cache_timeout_ms = int(os.environ.get('LOAD_CACHE_TIMEOUT_MS', 0))
        logging.info(f'load_cache_timeout_ms: {self.load_cache_timeout_ms}')
        self.decode_retry_times = int(os.environ.get('DECODE_RETRY_TIMES', 0))
        # Fixed: previously logged self.prefill_retry_times here, which is both
        # the wrong value and unset when pd_separation is disabled.
        logging.info(f'decode_retry_times: {self.decode_retry_times}')
        self.decode_retry_timeout_ms = int(os.environ.get('DECODE_RETRY_TIMEOUT_MS', 0))
        logging.info(f'decode_retry_timeout_ms: {self.decode_retry_timeout_ms}')
        self.rdma_connect_retry_times = int(os.environ.get('RDMA_CONNECT_RETRY_TIMES', 0))
        logging.info(f'rdma_connect_retry_times: {self.rdma_connect_retry_times}')
        self.decode_polling_kv_cache_step_ms = int(os.environ.get('DECODE_POLLING_KV_CACHE_STEP_MS', 30))
        logging.info(f'decode_polling_kv_cache_step_ms: {self.decode_polling_kv_cache_step_ms}')
        self.decode_use_async_load_cache = bool(int(os.environ.get('DECODE_USE_ASYNC_LOAD_CACHE', 1)))
        logging.info(f'decode_use_async_load_cache: {self.decode_use_async_load_cache}')

    # NOTE(review): env var name contains a lowercase 'l' ("SCHEDUlER") — kept
    # as-is so existing deployments keep working; consider also accepting the
    # correctly-spelled name in a follow-up.
    self.scheduler_reserve_resource_ratio = int(os.environ.get('SCHEDUlER_RESERVE_RESOURCE_RATIO', 5))
    logging.info(f'scheduler_reserve_resource_ratio: {self.scheduler_reserve_resource_ratio}')
    # REUSE_CACHE is the current switch; USE_BLOCK_CACHE kept for compatibility.
    self.reuse_cache = os.environ.get('REUSE_CACHE', None) == '1' or os.environ.get('USE_BLOCK_CACHE', None) == '1'
    logging.info(f'reuse_cache: {self.reuse_cache}')
    self.pre_allocate_op_mem = bool(int(os.environ.get('PRE_ALLOCATE_OP_MEM', 1)))
    logging.info(f'pre_allocate_op_mem: {self.pre_allocate_op_mem}')

    # KV-cache dtype: explicit INT8 override > FP8 (non-groupwise quant) > weight dtype.
    if bool(int(os.environ.get('INT8_KV_CACHE', 0))):
        self.kv_cache_data_type = WEIGHT_TYPE.INT8.to_str()
    elif self.quant_algo.isFp8() and not self.quant_algo.isGroupwise():
        self.kv_cache_data_type = WEIGHT_TYPE.FP8.to_str()
    else:
        self.kv_cache_data_type = self.data_type
    logging.info(f'kv_cache_data_type: {self.kv_cache_data_type}')
    logging.info(f'tp_split_emb_and_lm_head: {self.tp_split_emb_and_lm_head}')

    # use environment variables to update stop_words_str and stop_words_id
    env_stop_words_str = os.environ.get('STOP_WORDS_STR', None)
    env_stop_words_id = os.environ.get('STOP_WORDS_LIST', None)
    env_stop_words_str_list = json.loads(env_stop_words_str) if env_stop_words_str else []
    env_stop_words_id_list = json.loads(env_stop_words_id) if env_stop_words_id else []
    env_force_stop = os.environ.get('FORCE_STOP_WORDS', None)
    if env_force_stop and str_to_bool(env_force_stop):
        # FORCE_STOP_WORDS replaces the configured stop words entirely.
        self.special_tokens.stop_words_str_list = env_stop_words_str_list
        self.special_tokens.stop_words_id_list = env_stop_words_id_list
    else:
        # Otherwise env-provided stop words are appended to the configured ones.
        self.special_tokens.stop_words_str_list = self.special_tokens.stop_words_str_list + env_stop_words_str_list
        self.special_tokens.stop_words_id_list = self.special_tokens.stop_words_id_list + env_stop_words_id_list
    logging.info(f"use stop_words_str_list [{self.special_tokens.stop_words_str_list }]," \
        f" stop_words_id_list [{self.special_tokens.stop_words_id_list}]")