def update_common()

in maga_transformer/config/gpt_init_model_parameters.py [0:0]


    def update_common(self,
                      ckpt_path: str,
                      lora_infos: Optional[Dict[str, str]],
                      ptuning_path: Optional[str],
                      tokenizer_path: str,
                      int8_mode: bool,
                      data_type: WEIGHT_TYPE,
                      max_seq_len: int,
                      seq_size_per_block: int,
                      gen_num_per_circle: int,
                      ref_module: Optional[torch.nn.Module] = None,
                      ref_dict: Dict[str, torch.Tensor] = {},
                      parallel_info: ParallelInfo=g_parallel_info,
                      config_mode: ConfigMode = ConfigMode.ComplexMode,
                      gang_info: Optional[GangInfo] = None):
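        # copy the parallel topology (tensor / expert / data parallel sizes and ranks) from parallel_info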
        self.tp_size = parallel_info.tp_size
        self.tp_rank = parallel_info.tp_rank
        self.ep_size = parallel_info.ep_size
        self.ep_rank = parallel_info.ep_rank
        self.dp_size = parallel_info.dp_size
        self.dp_rank = parallel_info.dp_rank
        self.ffn_tp_rank = parallel_info.ffn_tp_rank
        self.ffn_tp_size = parallel_info.ffn_tp_size
        self.enable_sp = parallel_info.ffn_sp_size > 1
        self.local_rank = parallel_info.local_rank

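        # expert-parallel load balancing (EPLB) and redundant-expert settings, driven by environment variables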
        self.eplb_update_time = int(os.environ.get("EPLB_UPDATE_TIME", 5000))
        self.eplb_mode = EplbMode.__members__[os.environ.get('EPLB_MODE', 'NONE')]

        self.phy_exp_num = int(os.environ.get("REDUNDANT_EXPERT", 0)) + self.expert_num
        self.enable_merge_w13 = os.getenv('ENABLE_MERGE_W13', '0').lower() == '1'
        logging.info(f"phy_exp_num: {self.phy_exp_num}, use merge w13: {self.enable_merge_w13}")
        
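        # node count: prefer the explicitly passed gang_info, otherwise query get_gang_info(),
        # falling back to a single node on failure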
        if gang_info is not None:
            self.num_nodes = gang_info.num_nodes
        else:
            try:
                self.num_nodes = get_gang_info().num_nodes
            except Exception:
                self.num_nodes = 1
            

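        # basic paths, quantization and data type; int8_mode enables weight-only per-column int8
        # quantization only when no quantization algorithm has been configured yet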
        self.ckpt_path = ckpt_path
        self.lora_infos = lora_infos
        self.tokenizer_path = tokenizer_path
        if not self.quant_algo.isQuant() and int8_mode:
            self.quant_algo.setQuantAlgo("weight_only_per_col", 8, 0)
        self.data_type = data_type.to_str()
        self.gen_num_per_circle = gen_num_per_circle
        self.ptuning_path = ptuning_path
        self.ref_module = ref_module
        self.ref_dict = ref_dict
        if max_seq_len != 0:
            self.max_seq_len = max_seq_len
        if self.max_seq_len < 1:
            self.max_seq_len = 1024
        logging.info(f'max_seq_len: {self.max_seq_len}')

        self.update_task_type_use_kvcache()

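        # in SimpleMode the derived/runtime configuration below is skipped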
        logging.info(f"config_mode = {config_mode}")
        if config_mode == ConfigMode.SimpleMode:
            return

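        # derived configuration: worker addresses, checkpoint-based sparse config and weight style,
        # inter padding size, task prompt config, and the cutlass GEMM config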
        self.update_worker_addrs()
        self.update_config_with_sparse_config(ckpt_path)
        self.update_inter_padding_size(self.tp_size, self.ep_size, self.dp_size)
        self.update_task_prompt_config()

        self.update_weight_style(ckpt_path)

        load_cutlass_gemm_config(self.quant_algo)

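        # HACK_LAYER_NUM overrides the model layer count (typically used for debugging with fewer layers)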
        hack_layer_num = int(os.environ.get('HACK_LAYER_NUM', 0))
        if hack_layer_num:
            logging.info(f"hack layernum to {hack_layer_num}")
            self.layer_num = hack_layer_num

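        # KV-cache block size and runtime limits; each value below can be overridden via environment variables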
        self.seq_size_per_block = closest_power_of_2(int(max(seq_size_per_block, self.max_seq_len // 128))) # must be 2^n
        self.seq_size_per_block = int(os.environ.get('SEQ_SIZE_PER_BLOCK', self.seq_size_per_block))
        logging.info(f'seq_size_per_block: {self.seq_size_per_block}')
        self.max_generate_batch_size = int(os.environ.get('CONCURRENCY_LIMIT', 128))
        logging.info(f'max_generate_batch_size: {self.max_generate_batch_size}')
        self.max_context_batch_size = int(os.environ.get('MAX_CONTEXT_BATCH_SIZE', 1))
        logging.info(f'max_context_batch_size: {self.max_context_batch_size}')
        self.reserve_runtime_mem_mb = int(os.environ.get('RESERVER_RUNTIME_MEM_MB', 128))
        logging.info(f'reserve_runtime_mem_mb: {self.reserve_runtime_mem_mb}')
        self.kv_cache_mem_mb = int(os.environ.get('KV_CACHE_MEM_MB', -1))
        logging.info(f'kv_cache_mem_mb: {self.kv_cache_mem_mb}')
        self.block_nums = int(os.environ.get('TEST_BLOCK_NUM', 0))
        logging.info(f'block_nums: {self.block_nums}')
        if os.environ.get('TEST_LAYER_NUM'):
            logging.info(f'replace model layer with TEST_LAYER_NUM: {os.environ.get("TEST_LAYER_NUM")}')
            self.layer_num = int(os.environ.get('TEST_LAYER_NUM', self.layer_num))
        self.enable_partial_fallback = bool(int(os.environ.get('ENABLE_PARTIAL_FALLBACK', 0)))
        logging.info(f'enable_partial_fallback: {self.enable_partial_fallback}')
        self.enable_fast_gen = bool(int(os.environ.get('ENABLE_FAST_GEN', 0)))
        logging.info(f'enable_fast_gen: {self.enable_fast_gen}')
        self.warm_up = bool(int(os.environ.get('WARM_UP', 1)))
        logging.info(f'warm_up: {self.warm_up}')
        self.warm_up_with_loss = bool(int(os.environ.get('WARM_UP_WITH_LOSS', 0)))
        logging.info(f'warm_up_with_loss: {self.warm_up_with_loss}')

        self.vit_separation = int(os.environ.get('VIT_SEPARATION', 0))
        logging.info(f'vit_separation: {self.vit_separation}')

        self.fast_gen_max_context_len = int(os.environ.get('FAST_GEN_MAX_CONTEXT_LEN', 1024))
        logging.info(f'fast_gen_max_context_len: {self.fast_gen_max_context_len}')

        self.max_rpc_timeout_ms = int(os.environ.get('MAX_RPC_TIMEOUT_MS', 0))
        logging.info(f'max_rpc_timeout_ms: {self.max_rpc_timeout_ms}')

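        # prefill/decode (PD) separation: retry, timeout and load-balancing settings are only read
        # when PD_SEPARATION is enabled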
        self.pd_separation = bool(int(os.environ.get('PD_SEPARATION', 0)))
        logging.info(f'pd_separation: {self.pd_separation}')
        if self.pd_separation:
            self.prefill_retry_times = int(os.environ.get('PREFILL_RETRY_TIMES', 0))
            logging.info(f'prefill_retry_times: {self.prefill_retry_times}')
            self.prefill_retry_timeout_ms = int(os.environ.get('PREFILL_RETRY_TIMEOUT_MS', 0))
            logging.info(f'prefill_retry_timeout_ms: {self.prefill_retry_timeout_ms}')
            # note: reads PREFILL_MAX_WAIT_TIMEOUT_US from the environment although the attribute name ends in _ms
            self.prefill_max_wait_timeout_ms = int(os.environ.get('PREFILL_MAX_WAIT_TIMEOUT_US', 600 * 1000 * 1000))
            logging.info(f'prefill_max_wait_timeout_ms: {self.prefill_max_wait_timeout_ms}')
            self.pd_sep_enable_fallback = bool(int(os.environ.get('PD_SEP_ENABLE_FALLBACK', 0)))
            logging.info(f'pd_sep_enable_fallback: {self.pd_sep_enable_fallback}')
            self.load_balance_policy_name = os.environ.get('LOAD_BALANCE_POLICY_NAME', "RR")
            logging.info(f'load_balance_policy_name: {self.load_balance_policy_name}')
            policy_list = ["RR", "WRR"]
            if self.load_balance_policy_name not in policy_list:
                raise Exception(f"load_balance_policy_name {self.load_balance_policy_name} " \
                    f"is invalid, it must be one of {policy_list}")
            self.sync_status_interval_ms = int(os.environ.get('SYNC_STATUS_INTERVAL_MS', 50))
            logging.info(f'sync_status_interval_ms: {self.sync_status_interval_ms}')

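        # cache store (KV-cache transfer) settings, only read when USE_CACHE_STORE is enabled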
        self.use_cache_store = bool(int(os.environ.get('USE_CACHE_STORE', 0)))
        logging.info(f'use_cache_store: {self.use_cache_store}')
        if self.use_cache_store:
            self.cache_store_rdma_mode = bool(int(os.environ.get('CACHE_STORE_RDMA_MODE', 1)))
            logging.info(f'cache_store_rdma_mode: {self.cache_store_rdma_mode}')

            self.load_cache_timeout_ms = int(os.environ.get('LOAD_CACHE_TIMEOUT_MS', 0))
            logging.info(f'load_cache_timeout_ms: {self.load_cache_timeout_ms}')

            self.decode_retry_times = int(os.environ.get('DECODE_RETRY_TIMES', 0))
            logging.info(f'decode_retry_times: {self.decode_retry_times}')
            self.decode_retry_timeout_ms = int(os.environ.get('DECODE_RETRY_TIMEOUT_MS', 0))
            logging.info(f'decode_retry_timeout_ms: {self.decode_retry_timeout_ms}')

            self.rdma_connect_retry_times = int(os.environ.get('RDMA_CONNECT_RETRY_TIMES', 0))
            logging.info(f'rdma_connect_retry_times: {self.rdma_connect_retry_times}')

            self.decode_polling_kv_cache_step_ms = int(os.environ.get('DECODE_POLLING_KV_CACHE_STEP_MS', 30))
            logging.info(f'decode_polling_kv_cache_step_ms: {self.decode_polling_kv_cache_step_ms}')

            self.decode_use_async_load_cache = bool(int(os.environ.get('DECODE_USE_ASYNC_LOAD_CACHE', 1)))
            logging.info(f'decode_use_async_load_cache: {self.decode_use_async_load_cache}')

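        # scheduler reservation, cache reuse and KV-cache data type: int8 when INT8_KV_CACHE is set,
        # fp8 for non-groupwise fp8 quantization, otherwise the model data type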
        self.scheduler_reserve_resource_ratio = int(os.environ.get('SCHEDUlER_RESERVE_RESOURCE_RATIO', 5))
        logging.info(f'scheduler_reserve_resource_ratio: {self.scheduler_reserve_resource_ratio}')
        self.reuse_cache = os.environ.get('REUSE_CACHE', None) == '1' or os.environ.get('USE_BLOCK_CACHE', None) == '1'
        logging.info(f'reuse_cache: {self.reuse_cache}')
        self.pre_allocate_op_mem = bool(int(os.environ.get('PRE_ALLOCATE_OP_MEM', 1)))
        logging.info(f'pre_allocate_op_mem: {self.pre_allocate_op_mem}')
        if bool(int(os.environ.get('INT8_KV_CACHE', 0))):
            self.kv_cache_data_type = WEIGHT_TYPE.INT8.to_str()
        elif self.quant_algo.isFp8() and not self.quant_algo.isGroupwise():
            self.kv_cache_data_type = WEIGHT_TYPE.FP8.to_str()
        else:
            self.kv_cache_data_type = self.data_type

        logging.info(f'kv_cache_data_type: {self.kv_cache_data_type}')
        logging.info(f'tp_split_emb_and_lm_head: {self.tp_split_emb_and_lm_head}')

        # use environment variables to update stop_words_str and stop_words_id
        env_stop_words_str = os.environ.get('STOP_WORDS_STR', None)
        env_stop_words_id = os.environ.get('STOP_WORDS_LIST', None)
        env_stop_words_str_list = json.loads(env_stop_words_str) if env_stop_words_str else []
        env_stop_words_id_list = json.loads(env_stop_words_id) if env_stop_words_id else []
        env_force_stop = os.environ.get('FORCE_STOP_WORDS', None)
        if env_force_stop and str_to_bool(env_force_stop):
            self.special_tokens.stop_words_str_list = env_stop_words_str_list
            self.special_tokens.stop_words_id_list = env_stop_words_id_list
        else:
            self.special_tokens.stop_words_str_list = self.special_tokens.stop_words_str_list + env_stop_words_str_list
            self.special_tokens.stop_words_id_list = self.special_tokens.stop_words_id_list + env_stop_words_id_list

        logging.info(f"use stop_words_str_list [{self.special_tokens.stop_words_str_list}]," \
                        f" stop_words_id_list [{self.special_tokens.stop_words_id_list}]")
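
Example call (a minimal sketch, not taken from the source): it assumes config is an
already-constructed instance of the parameter class defined in this file and that
WEIGHT_TYPE and ConfigMode are imported from the maga_transformer config modules; the
paths and the WEIGHT_TYPE.FP16 member name are illustrative assumptions.

    config.update_common(
        ckpt_path="/path/to/ckpt",              # hypothetical checkpoint directory
        lora_infos=None,                        # no LoRA adapters
        ptuning_path=None,                      # no p-tuning weights
        tokenizer_path="/path/to/tokenizer",    # hypothetical tokenizer directory
        int8_mode=False,                        # do not force weight-only int8 quantization
        data_type=WEIGHT_TYPE.FP16,             # assumed enum member; use the project's actual value
        max_seq_len=4096,                       # non-zero values override the configured max_seq_len
        seq_size_per_block=8,                   # adjusted internally to a power of two, at least max_seq_len // 128
        gen_num_per_circle=1,
        config_mode=ConfigMode.ComplexMode,     # SimpleMode would skip the derived/runtime configuration
    )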