# maga_transformer/config/gpt_init_model_parameters.py (641 lines of code)

from typing import Dict, Any, List, Optional, Set
import os
import json
import torch
import logging
import typing  # make sure so init

from dataclasses import dataclass, field, fields
from enum import Enum
from maga_transformer.utils.util import str_to_bool, closest_power_of_2
from maga_transformer.utils.weight_type import WEIGHT_TYPE
from maga_transformer.config.task_type import TaskType, check_task_type
from maga_transformer.distribute.worker_info import ParallelInfo, g_parallel_info, g_master_info, g_worker_info, WORKER_INFO_PORT_NUM
from maga_transformer.distribute.gang_info import get_gang_info, GangInfo
from maga_transformer.ops import GptInitParameter, QuantAlgo, SpecialTokens, MlaOpsType, EplbMode
from maga_transformer.utils.gemm_utils.cutlass_config import load_cutlass_gemm_config

updated_params: Set[str] = set()


def get_pad_size(size: int, align_size: int):
    return (align_size - (size % align_size)) % align_size


class DataClassBase:
    @classmethod
    def from_dict(cls, kvs: Dict[str, Any]):
        n_kvs = {k: v for k, v in kvs.items() if k in {f.name for f in fields(cls)}}
        # compatibility with keys used by old sparse configs, which lack the "layer_" prefix
        for k, v in kvs.items():
            if k in ["head_num", "inter_size"] and isinstance(v, list):
                n_kvs.update({"layer_" + k: v})
        data_class = cls(**n_kvs)
        return data_class


mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5],
                [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9],
                [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5],
                [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9],
                [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0],
                [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]


@dataclass
class SparseConfig(DataClassBase):
    layer_num: int = 0
    layer_head_num: List[int] = field(default_factory=lambda: [])
    layer_inter_size: List[int] = field(default_factory=lambda: [])

    def check(self) -> bool:
        if self.layer_num == 0:
            logging.info("sparse config layer_num must not be empty")
            return False
        if len(self.layer_head_num) != self.layer_num:
            logging.info(f"sparse config layer_num and head_num must match, layer_num: {self.layer_num}, head_num: {self.layer_head_num}")
            return False
        if len(self.layer_inter_size) != self.layer_num:
            logging.info(f"sparse config layer_num and inter_size must match, layer_num: {self.layer_num}, inter_size: {self.layer_inter_size}")
            return False
        return True


class VitParameters:
    # config holds the original vit config from ckpt/config.json
    config: Dict[str, Any] = {}
    special_token_ids: Dict[str, Any] = {}
    special_tokens: Dict[str, Any] = {}
    vit_weights: Any = None


class TemplateType(Enum):
    chat = "chat"
    vqa = "vqa"
    base = "image"


class ConfigMode(Enum):
    SimpleMode = 1
    ComplexMode = 2
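

# Illustrative notes (values below are invented for demonstration):
# get_pad_size rounds a dimension up to the next multiple of align_size, e.g.
#     get_pad_size(100, 64) == 28   # 100 + 28 == 128 == 2 * 64
#     get_pad_size(128, 64) == 0
# DataClassBase.from_dict also accepts the legacy sparse-config keys, so
#     SparseConfig.from_dict({"layer_num": 2, "head_num": [8, 8], "inter_size": [1024, 1024]})
# yields the same result as a config that already uses layer_head_num / layer_inter_size.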


class GptInitModelParameters:
    __slots__ = {
        "gpt_init_params",
        "_model_related_types",
        "has_lm_head_bias",
        "src_quantization_bit",
        "ptuning_path",
        "tp_split_emb_and_lm_head",
        "mm_related_params",
        "lora_infos",
        "multi_task_prompt",
        "normalize_lm_head_weight",
        "ref_module",
        "ref_dict",
        "tie_word_embeddings",
        "need_ffn_act_scale",
        "task_type",
        "add_special_tokens",
        "template_type",
        "build_position_ids",
        "routed_scaling_factor",
        "is_ft_style_weight",
        "vit_run_batch",
        "phy2log",
        "is_mtp",
        "num_nodes",
        "use_qk_norm",
        "enable_merge_w13"
    }

    # copied from maga_transformer/ops/libth_transformer.pyi for Python IntelliSense
    activation_type: str
    add_bias_linear: bool
    block_nums: int
    cache_store_connect_port: int
    cache_store_listen_port: int
    cache_store_rdma_connect_port: int
    cache_store_rdma_listen_port: int
    cache_store_rdma_mode: bool
    ckpt_path: str
    cross_attn_input_len: int
    data_type: str
    decode_polling_kv_cache_step_ms: int
    decode_retry_timeout_ms: int
    decode_retry_times: int
    decode_use_async_load_cache: bool
    deepseek_mscale_all_dim: float
    deepseek_rope_mscale: float
    dp_rank: int
    dp_size: int
    dp_tp_nccl_port: int
    embedding_size: int
    enable_eplb: bool
    enable_fast_gen: bool
    enable_partial_fallback: bool
    enable_sp: bool
    enable_speculative_decoding: bool
    ep_rank: int
    ep_size: int
    eplb_mode: EplbMode
    eplb_update_time: int
    expert_num: int
    fast_gen_max_context_len: int
    ffn_tp_nccl_port: int
    ffn_tp_rank: int
    ffn_tp_size: int
    gen_num_per_circle: int
    has_lm_head: bool
    has_moe_norm: bool
    has_positional_encoding: bool
    has_post_decoder_layernorm: bool
    has_pre_decoder_layernorm: bool
    head_num: int
    head_num_kv: int
    hidden_size: int
    http_port: int
    include_sep_tokens: bool
    input_embedding_scalar: float
    input_vocab_size: int
    inter_padding_size: int
    inter_size: int
    is_causal: bool
    is_multimodal: bool
    is_sparse_head: bool
    kv_cache_data_type: str
    kv_cache_mem_mb: int
    kv_lora_rank: int
    layer_head_num: list[int]
    layer_head_num_kv: list[int]
    layer_inter_padding_size: list[int]
    layer_inter_size: list[int]
    layer_num: int
    layernorm_eps: float
    layernorm_type: str
    load_balance_policy_name: str
    load_cache_timeout_ms: int
    local_rank: int
    logit_scale: float
    max_context_batch_size: int
    max_generate_batch_size: int
    max_rpc_timeout_ms: int
    max_seq_len: int
    mla_ops_type: MlaOpsType
    mm_position_ids_style: int
    mm_sep_tokens: list[list[int]]
    model_name: str
    model_rpc_port: int
    moe_inter_padding_size: int
    moe_k: int
    moe_layer_index: list[int]
    moe_n_group: int
    moe_normalize_expert_scale: bool
    moe_style: int
    moe_topk_group: int
    mrope_section: list[int]
    nccl_ip: str
    nope_head_dim: int
    norm_type: str
    num_layers: int
    num_valid_layer: int
    org_embedding_max_pos: int
    pd_sep_enable_fallback: bool
    pd_separation: bool
    phy_exp_num: int
    position_id_len_factor: int
    position_ids_style: int
    pre_allocate_op_mem: bool
    pre_seq_len: int
    prefill_max_wait_timeout_ms: int
    prefill_retry_timeout_ms: int
    prefill_retry_times: int
    prefix_projection: bool
    py_eplb: typing.Any
    q_lora_rank: int
    q_scaling: float
    qk_norm: bool
    quant_algo: QuantAlgo
    rdma_connect_retry_times: int
    remote_rpc_server_port: int
    reserve_runtime_mem_mb: int
    residual_scalar: float
    reuse_cache: bool
    reverse_e_h_norm: bool
    rope_head_dim: int
    rotary_embedding_base: float
    rotary_embedding_dim: int
    rotary_embedding_mscale: float
    rotary_embedding_offset: int
    rotary_embedding_scale: float
    rotary_embedding_style: int
    rotary_factor1: float
    rotary_factor2: float
    scheduler_reserve_resource_ratio: int
    scoring_func: int
    seq_size_per_block: int
    size_per_head: int
    softmax_extra_scale: float
    special_tokens: SpecialTokens
    tokenizer_path: str
    tp_nccl_port: int
    num_nodes: int
    tp_rank: int
    tp_size: int
    type_vocab_size: int
    use_attention_linear_bias: bool
    use_cache_store: bool
    use_cross_attn: bool
    use_expert_attention: bool
    use_fp32_to_compute_logit: bool
    use_kvcache: bool
    use_logn_attn: bool
    use_mla: bool
    use_norm_attn_out_residual: bool
    use_norm_input_residual: bool
    using_hf_sampling: bool
    v_head_dim: int
    vit_separation: int
    vocab_size: int
    warm_up: bool
    warm_up_with_loss: bool
    worker_addrs: list[str]
    worker_grpc_addrs: list[str]
    worker_port_offset: int
    world_size: int

    def __init__(self, head_num: int, size_per_head: int, layer_num: int, max_seq_len: int,
                 vocab_size: int, **kwargs: Any):
        hidden_size = head_num * size_per_head
        self.gpt_init_params = GptInitParameter(
            head_num, size_per_head, layer_num, max_seq_len, vocab_size, hidden_size
        )
        self._model_related_types: Dict[str, str] = {
            "layernorm_type": "setLayerNormType",
            "norm_type": "setNormType",
            "activation_type": "setActivationType",
            "kv_cache_data_type": "setKvCacheDataType"
        }
        self.has_lm_head_bias = False
        self.normalize_lm_head_weight = False
        self.src_quantization_bit = 0
        self.tp_split_emb_and_lm_head = True
        self.ptuning_path = None
        self.multi_task_prompt = None
        self.pre_seq_len = 0
        self.prefix_projection = False
        self.mm_related_params: VitParameters = VitParameters()
        self.ref_module: Optional[torch.nn.Module] = None
        self.ref_dict: Dict[str, torch.Tensor] = {}
        self.task_type = TaskType.LANGUAGE_MODEL
        self.tie_word_embeddings = False
        self.need_ffn_act_scale = False
        self.nccl_ip = g_master_info.ip
        self.tp_nccl_port = g_master_info.tp_nccl_port
        self.dp_tp_nccl_port = g_master_info.dp_tp_nccl_port
        self.ffn_tp_nccl_port = g_master_info.ffn_tp_nccl_port
        self.model_rpc_port = g_worker_info.rpc_server_port
        self.http_port = g_worker_info.http_port
        self.cache_store_listen_port = g_worker_info.cache_store_listen_port
        self.cache_store_connect_port = g_worker_info.cache_store_connect_port
        self.cache_store_rdma_listen_port = g_worker_info.cache_store_rdma_listen_port
        self.cache_store_rdma_connect_port = g_worker_info.cache_store_rdma_connect_port
        self.remote_rpc_server_port = g_worker_info.remote_rpc_server_port
        self.worker_port_offset = WORKER_INFO_PORT_NUM
        self.add_special_tokens = True
        self.template_type = TemplateType.chat
        self.build_position_ids = False
        self.routed_scaling_factor = 1.0
        self.vit_run_batch = False
        self.is_ft_style_weight = False
        self.is_multimodal = False
        self.model_name = ""
        self.world_size = g_parallel_info.world_size
        self.phy2log: List[List[int]] = []
        self.enable_eplb = self.eplb_mode != EplbMode.NONE
        self.is_mtp = False
        self.use_qk_norm = False
        self.enable_merge_w13 = False

        for k, v in kwargs.items():
            setattr(self, k, v)

    # read and write directly through GptInitModelParameters.k
    def __getattr__(self, k: str):
        return getattr(self.gpt_init_params, k)

    def __setattr__(self, k: str, v: Any):
        updated_params.add(k)
        if k in self.__slots__:
            object.__setattr__(self, k, v)
        elif v is not None:
            self.gpt_init_params.__setattr__(k, v)
            if k in self._model_related_types:
                getattr(self.gpt_init_params, self._model_related_types[k])()

    def update(self, update_params: Dict[str, Any]):
        for k, v in update_params.items():
            setattr(self, k, v)
        return self

    def update_worker_addrs(self):
        worker_addrs = []
        worker_grpc_addrs = []
        for member in get_gang_info().members:
            logging.info(f"member world rank: {member.world_rank}, member local rank: {member.local_rank}, local rank: {self.local_rank}, "
                         f"tp_size: {self.tp_size}, dp_size: {self.dp_size}, dp_rank: {self.dp_rank}")
            if int((member.world_rank / self.tp_size) % self.dp_size) == self.dp_rank:
                worker_addrs.append(f'{member.ip}:{member.cache_store_listen_port}:{member.cache_store_rdma_listen_port}')
                worker_grpc_addrs.append(f'{member.ip}:{member.rpc_server_port}')
                logging.info(f"append member for pd sep "
                             f"{member.ip}:{member.rpc_server_port}, {member.cache_store_listen_port}, "
                             f"{member.cache_store_rdma_listen_port} to local rank {self.local_rank}, world rank {member.world_rank}")
        self.worker_grpc_addrs = worker_grpc_addrs
        self.worker_addrs = worker_addrs
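
    # Note (illustrative sketch, values invented): attributes outside __slots__
    # are proxied to the underlying GptInitParameter, so e.g.
    #     params = GptInitModelParameters(head_num=32, size_per_head=128,
    #                                     layer_num=40, max_seq_len=2048, vocab_size=32000)
    #     params.norm_type = "rmsnorm"   # stored on gpt_init_params and also triggers setNormType()
    # while slot attributes such as params.task_type stay on the Python wrapper.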

    def update_config_with_sparse_config(self, ckpt_path: str):
        sparse_config_file = None
        sparse_config = None
        if os.path.exists(os.path.join(ckpt_path, "config.json")):
            sparse_config_file = os.path.join(ckpt_path, "config.json")
        if os.environ.get('SPARSE_CONFIG_FILE', None) is not None:
            sparse_config_file = os.environ['SPARSE_CONFIG_FILE']

        if sparse_config_file is not None:
            logging.info(f"read sparse config from: {sparse_config_file}")
            with open(sparse_config_file, 'r') as reader:
                sparse_config_json = json.loads(reader.read())
                sparse_config = SparseConfig.from_dict(sparse_config_json)

        if sparse_config and sparse_config.check():
            self.layer_num = sparse_config.layer_num
            self.layer_head_num = sparse_config.layer_head_num
            self.layer_head_num_kv = sparse_config.layer_head_num
            self.layer_inter_size = sparse_config.layer_inter_size
            self.is_sparse_head = True

    def update_inter_padding_size(self, tp_size: int, ep_size: int, dp_size: int):
        if tp_size * dp_size != ep_size:
            raise ValueError(f"tp_size:{tp_size} * dp_size:{dp_size} != ep_size:{ep_size}")
        # the new tp_size is only used for moe
        if self.quant_algo.isGroupwise():
            align_size = tp_size * self.quant_algo.getGroupSize()
            moe_align_size = self.quant_algo.getGroupSize()
        else:
            align_size = tp_size * 64
            moe_align_size = 64
        if self.layer_inter_size:
            layer_inter_padding_size = []
            for idx in range(len(self.layer_inter_size)):
                inter_size = self.layer_inter_size[idx]
                layer_inter_padding_size.append(inter_size + (get_pad_size(inter_size, align_size) if self.quant_algo.isQuant() else 0))
            self.layer_inter_padding_size = layer_inter_padding_size
        self.inter_padding_size = \
            self.inter_size + (get_pad_size(self.inter_size, align_size) if self.quant_algo.isQuant() else 0)
        if self.head_num_kv <= 0:
            self.head_num_kv = self.head_num
        if self.inter_padding_size <= 0:
            self.inter_padding_size = self.inter_size
        if self.moe_inter_padding_size <= 0:
            self.moe_inter_padding_size = self.inter_size
        if self.moe_inter_padding_size > 0:
            moe_align_size = moe_align_size if self.quant_algo.isQuant() else 8
            self.moe_inter_padding_size = self.moe_inter_padding_size + get_pad_size(self.moe_inter_padding_size, moe_align_size)
        logging.info(f"update_inter_padding_size: {self.inter_padding_size}, moe_inter_padding_size: {self.moe_inter_padding_size}, layer_inter_size: {self.layer_inter_size}")

    def update_task_prompt_tokens_id(self, tokenizer):
        if self.multi_task_prompt:
            for info in self.multi_task_prompt:
                task_id: str = str(info['task_id'])
                prompt: str = info['prompt']
                tokens_id = tokenizer.encode(prompt)
                self.insertMultiTaskPromptTokens(task_id, tokens_id)

    def update_task_prompt_config(self):
        prompt_file_path = os.environ.get('MULTI_TASK_PROMPT', None)
        if not prompt_file_path:
            self.multi_task_prompt = None
        else:
            with open(prompt_file_path, 'r') as reader:
                multi_task_prompt = json.loads(reader.read(), strict=False)
            self.multi_task_prompt = multi_task_prompt
            return
        prompt_str = os.environ.get('MULTI_TASK_PROMPT_STR', None)
        if not prompt_str:
            self.multi_task_prompt = None
        else:
            self.multi_task_prompt = json.loads(prompt_str, strict=False)
            return

    def update_task_type_use_kvcache(self):
        self.task_type = check_task_type(self.ckpt_path)
        self.setTaskType(self.task_type.value)
        self.use_kvcache = (self.task_type == TaskType.LANGUAGE_MODEL)
        logging.info(f"model task type: {self.task_type}, use_kvcache: {self.use_kvcache}")
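
    # Worked example for update_inter_padding_size (assumed values): with a
    # groupwise quant algo whose group size is 128 and tp_size = 2, align_size
    # is 2 * 128 = 256, so an inter_size of 13696 is padded by
    # get_pad_size(13696, 256) == 128 to inter_padding_size == 13824; without
    # quantization no padding is applied.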

    def update_weight_style(self, ckpt_path: str):
        if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
            meta_file = os.path.join(ckpt_path, "model.safetensors.index.json")
            logging.info(f"read weight style from: {meta_file}")
            with open(meta_file, 'r') as reader:
                meta_json = json.loads(reader.read())
                self.is_ft_style_weight = meta_json.get("is_ft_style_weight", False)

    def update_common(self,
                      ckpt_path: str,
                      lora_infos: Optional[Dict[str, str]],
                      ptuning_path: Optional[str],
                      tokenizer_path: str,
                      int8_mode: bool,
                      data_type: WEIGHT_TYPE,
                      max_seq_len: int,
                      seq_size_per_block: int,
                      gen_num_per_circle: int,
                      ref_module: Optional[torch.nn.Module] = None,
                      ref_dict: Dict[str, torch.Tensor] = {},
                      parallel_info: ParallelInfo = g_parallel_info,
                      config_mode: ConfigMode = ConfigMode.ComplexMode,
                      gang_info: Optional[GangInfo] = None):
        self.tp_size = parallel_info.tp_size
        self.tp_rank = parallel_info.tp_rank
        self.ep_size = parallel_info.ep_size
        self.ep_rank = parallel_info.ep_rank
        self.dp_size = parallel_info.dp_size
        self.dp_rank = parallel_info.dp_rank
        self.ffn_tp_rank = parallel_info.ffn_tp_rank
        self.ffn_tp_size = parallel_info.ffn_tp_size
        self.enable_sp = parallel_info.ffn_sp_size > 1
        self.local_rank = parallel_info.local_rank
        self.eplb_update_time = int(os.environ.get("EPLB_UPDATE_TIME", 5000))
        self.eplb_mode = EplbMode.__members__[os.environ.get('EPLB_MODE', 'NONE')]
        self.phy_exp_num = int(os.environ.get("REDUNDANT_EXPERT", 0)) + self.expert_num
        self.enable_merge_w13 = os.getenv('ENABLE_MERGE_W13', '0').lower() == '1'
        logging.info(f"phy_exp_num: {self.phy_exp_num}, use merge w13: {self.enable_merge_w13}")
        if gang_info is not None:
            self.num_nodes = gang_info.num_nodes
        else:
            try:
                self.num_nodes = get_gang_info().num_nodes
            except:
                self.num_nodes = 1
        self.ckpt_path = ckpt_path
        self.lora_infos = lora_infos
        self.tokenizer_path = tokenizer_path
        if not self.quant_algo.isQuant() and int8_mode:
            self.quant_algo.setQuantAlgo("weight_only_per_col", 8, 0)
        self.data_type = data_type.to_str()
        self.gen_num_per_circle = gen_num_per_circle
        self.ptuning_path = ptuning_path
        self.ref_module = ref_module
        self.ref_dict = ref_dict
        if max_seq_len != 0:
            self.max_seq_len = max_seq_len
        if self.max_seq_len < 1:
            self.max_seq_len = 1024
        logging.info(f'max_seq_len: {self.max_seq_len}')

        self.update_task_type_use_kvcache()

        logging.info(f"config_mode = {config_mode}")
        if config_mode == ConfigMode.SimpleMode:
            return

        self.update_worker_addrs()
        self.update_config_with_sparse_config(ckpt_path)
        self.update_inter_padding_size(self.tp_size, self.ep_size, self.dp_size)
        self.update_task_prompt_config()
        self.update_weight_style(ckpt_path)
        load_cutlass_gemm_config(self.quant_algo)

        hack_layer_num = int(os.environ.get('HACK_LAYER_NUM', 0))
        if hack_layer_num:
            logging.info(f"hack layernum to {hack_layer_num}")
            self.layer_num = hack_layer_num

        self.seq_size_per_block = closest_power_of_2(int(max(seq_size_per_block, self.max_seq_len // 128)))  # must be 2^n
        self.seq_size_per_block = int(os.environ.get('SEQ_SIZE_PER_BLOCK', self.seq_size_per_block))
        logging.info(f'seq_size_per_block: {self.seq_size_per_block}')
        self.max_generate_batch_size = int(os.environ.get('CONCURRENCY_LIMIT', 128))
        logging.info(f'max_generate_batch_size: {self.max_generate_batch_size}')
        self.max_context_batch_size = int(os.environ.get('MAX_CONTEXT_BATCH_SIZE', 1))
        logging.info(f'max_context_batch_size: {self.max_context_batch_size}')
        self.reserve_runtime_mem_mb = int(os.environ.get('RESERVER_RUNTIME_MEM_MB', 128))
        logging.info(f'reserve_runtime_mem_mb: {self.reserve_runtime_mem_mb}')
        self.kv_cache_mem_mb = int(os.environ.get('KV_CACHE_MEM_MB', -1))
        logging.info(f'kv_cache_mem_mb: {self.kv_cache_mem_mb}')
        self.block_nums = int(os.environ.get('TEST_BLOCK_NUM', 0))
        logging.info(f'block_nums: {self.block_nums}')
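
        # Example (assumed deployment values): exporting
        #     CONCURRENCY_LIMIT=64 MAX_CONTEXT_BATCH_SIZE=2 KV_CACHE_MEM_MB=8192
        # before start-up yields max_generate_batch_size == 64,
        # max_context_batch_size == 2 and kv_cache_mem_mb == 8192; unset
        # variables fall back to the defaults read above.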
        if os.environ.get('TEST_LAYER_NUM'):
            logging.info(f'replace model layer with TEST_LAYER_NUM: {os.environ.get("TEST_LAYER_NUM")}')
            self.layer_num = int(os.environ.get('TEST_LAYER_NUM', self.layer_num))
        self.enable_partial_fallback = bool(int(os.environ.get('ENABLE_PARTIAL_FALLBACK', 0)))
        logging.info(f'enable_partial_fallback: {self.enable_partial_fallback}')
        self.enable_fast_gen = bool(int(os.environ.get('ENABLE_FAST_GEN', 0)))
        logging.info(f'enable_fast_gen: {self.enable_fast_gen}')
        self.warm_up = bool(int(os.environ.get('WARM_UP', 1)))
        logging.info(f'warm_up: {self.warm_up}')
        self.warm_up_with_loss = bool(int(os.environ.get('WARM_UP_WITH_LOSS', 0)))
        logging.info(f'warm_up_with_loss: {self.warm_up_with_loss}')
        self.vit_separation = int(os.environ.get('VIT_SEPARATION', 0))
        logging.info(f'vit_separation: {self.vit_separation}')
        self.fast_gen_max_context_len = int(os.environ.get('FAST_GEN_MAX_CONTEXT_LEN', 1024))
        logging.info(f'fast_gen_max_context_len: {self.fast_gen_max_context_len}')
        self.max_rpc_timeout_ms = int(os.environ.get('MAX_RPC_TIMEOUT_MS', 0))
        logging.info(f'max_rpc_timeout_ms: {self.max_rpc_timeout_ms}')

        self.pd_separation = bool(int(os.environ.get('PD_SEPARATION', 0)))
        logging.info(f'pd_separation: {self.pd_separation}')
        if self.pd_separation:
            self.prefill_retry_times = int(os.environ.get('PREFILL_RETRY_TIMES', 0))
            logging.info(f'prefill_retry_times: {self.prefill_retry_times}')
            self.prefill_retry_timeout_ms = int(os.environ.get('PREFILL_RETRY_TIMEOUT_MS', 0))
            logging.info(f'prefill_retry_timeout_ms: {self.prefill_retry_timeout_ms}')
            self.prefill_max_wait_timeout_ms = int(os.environ.get('PREFILL_MAX_WAIT_TIMEOUT_US', 600 * 1000 * 1000))
            logging.info(f'prefill_max_wait_timeout_ms: {self.prefill_max_wait_timeout_ms}')
            self.pd_sep_enable_fallback = bool(int(os.environ.get('PD_SEP_ENABLE_FALLBACK', 0)))
            logging.info(f'pd_sep_enable_fallback: {self.pd_sep_enable_fallback}')
            self.load_balance_policy_name = os.environ.get('LOAD_BALANCE_POLICY_NAME', "RR")
            logging.info(f'load_balance_policy_name: {self.load_balance_policy_name}')
            policy_list = ["RR", "WRR"]
            if self.load_balance_policy_name not in policy_list:
                raise Exception(f"load_balance_policy_name {self.load_balance_policy_name} "
                                f"is invalid, it must be one of {policy_list}")
            self.sync_status_interval_ms = int(os.environ.get('SYNC_STATUS_INTERVAL_MS', 50))
            logging.info(f'sync_status_interval_ms: {self.sync_status_interval_ms}')
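
        # Example (hypothetical deployment): a prefill/decode separated setup
        # would typically export PD_SEPARATION=1 together with USE_CACHE_STORE=1
        # and, if desired, LOAD_BALANCE_POLICY_NAME=WRR; the retry/timeout knobs
        # above and the cache-store settings below then take effect.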

        self.use_cache_store = bool(int(os.environ.get('USE_CACHE_STORE', 0)))
        logging.info(f'use_cache_store: {self.use_cache_store}')
        if self.use_cache_store:
            self.cache_store_rdma_mode = bool(int(os.environ.get('CACHE_STORE_RDMA_MODE', 1)))
            logging.info(f'cache_store_rdma_mode: {self.cache_store_rdma_mode}')
            self.load_cache_timeout_ms = int(os.environ.get('LOAD_CACHE_TIMEOUT_MS', 0))
            logging.info(f'load_cache_timeout_ms: {self.load_cache_timeout_ms}')
            self.decode_retry_times = int(os.environ.get('DECODE_RETRY_TIMES', 0))
            logging.info(f'decode_retry_times: {self.decode_retry_times}')
            self.decode_retry_timeout_ms = int(os.environ.get('DECODE_RETRY_TIMEOUT_MS', 0))
            logging.info(f'decode_retry_timeout_ms: {self.decode_retry_timeout_ms}')
            self.rdma_connect_retry_times = int(os.environ.get('RDMA_CONNECT_RETRY_TIMES', 0))
            logging.info(f'rdma_connect_retry_times: {self.rdma_connect_retry_times}')
            self.decode_polling_kv_cache_step_ms = int(os.environ.get('DECODE_POLLING_KV_CACHE_STEP_MS', 30))
            logging.info(f'decode_polling_kv_cache_step_ms: {self.decode_polling_kv_cache_step_ms}')
            self.decode_use_async_load_cache = bool(int(os.environ.get('DECODE_USE_ASYNC_LOAD_CACHE', 1)))
            logging.info(f'decode_use_async_load_cache: {self.decode_use_async_load_cache}')

        self.scheduler_reserve_resource_ratio = int(os.environ.get('SCHEDUlER_RESERVE_RESOURCE_RATIO', 5))
        logging.info(f'scheduler_reserve_resource_ratio: {self.scheduler_reserve_resource_ratio}')
        self.reuse_cache = os.environ.get('REUSE_CACHE', None) == '1' or os.environ.get('USE_BLOCK_CACHE', None) == '1'
        logging.info(f'reuse_cache: {self.reuse_cache}')
        self.pre_allocate_op_mem = bool(int(os.environ.get('PRE_ALLOCATE_OP_MEM', 1)))
        logging.info(f'pre_allocate_op_mem: {self.pre_allocate_op_mem}')

        if bool(int(os.environ.get('INT8_KV_CACHE', 0))):
            self.kv_cache_data_type = WEIGHT_TYPE.INT8.to_str()
        elif self.quant_algo.isFp8() and not self.quant_algo.isGroupwise():
            self.kv_cache_data_type = WEIGHT_TYPE.FP8.to_str()
        else:
            self.kv_cache_data_type = self.data_type
        logging.info(f'kv_cache_data_type: {self.kv_cache_data_type}')
        logging.info(f'tp_split_emb_and_lm_head: {self.tp_split_emb_and_lm_head}')

        # use environment variables to update stop_words_str and stop_words_id
        env_stop_words_str = os.environ.get('STOP_WORDS_STR', None)
        env_stop_words_id = os.environ.get('STOP_WORDS_LIST', None)
        env_stop_words_str_list = json.loads(env_stop_words_str) if env_stop_words_str else []
        env_stop_words_id_list = json.loads(env_stop_words_id) if env_stop_words_id else []
        env_force_stop = os.environ.get('FORCE_STOP_WORDS', None)
        if env_force_stop and str_to_bool(env_force_stop):
            self.special_tokens.stop_words_str_list = env_stop_words_str_list
            self.special_tokens.stop_words_id_list = env_stop_words_id_list
        else:
            self.special_tokens.stop_words_str_list = self.special_tokens.stop_words_str_list + env_stop_words_str_list
            self.special_tokens.stop_words_id_list = self.special_tokens.stop_words_id_list + env_stop_words_id_list
        logging.info(f"use stop_words_str_list [{self.special_tokens.stop_words_str_list}], "
                     f"stop_words_id_list [{self.special_tokens.stop_words_id_list}]")
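
    # Example (hypothetical values): with FORCE_STOP_WORDS unset, exporting
    #     STOP_WORDS_STR='["</s>"]' STOP_WORDS_LIST='[[2]]'
    # appends "</s>" / [2] to the checkpoint's stop word lists; setting
    # FORCE_STOP_WORDS=1 instead replaces the lists with the env-provided ones.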

    def get_params_dict(self):
        res: Dict[str, Any] = {}
        for name in updated_params:
            res[name] = getattr(self, name)
        return res

    def eval_model_size(self):
        layer_param_bytes = 2
        if self.quant_algo.getWeightBits() == 8:
            layer_param_bytes = 1
        elif self.quant_algo.getWeightBits() == 4:
            layer_param_bytes = 0.54

        model_size = self.word_emb_param_count * 2 + \
            self.layer_weight_param_count * layer_param_bytes + \
            self.gpt_init_params.hidden_size * layer_param_bytes + \
            self.word_emb_param_count * 2  # some models may not have an lm_head

        kv_cache_mem_size = self._eval_kv_cache_mem_size()
        runtime_buffer = self._eval_runtime_buffer_mem_size()
        total_size = model_size + kv_cache_mem_size + runtime_buffer
        logging.info(f"total_size(Bytes): {total_size}, model_size:{model_size}, kv_cache_mem_size:{kv_cache_mem_size}, runtime_buffer:{runtime_buffer}")
        return total_size

    def _eval_kv_cache_mem_size(self):
        if self.task_type != TaskType.LANGUAGE_MODEL:
            return 0
        kv_cache_bytes = 1 if self.kv_cache_data_type in [WEIGHT_TYPE.FP8.to_str(), WEIGHT_TYPE.INT8.to_str()] else 2
        kv_cache_size = 2 * self.layer_num * self.head_num_kv * self.size_per_head * kv_cache_bytes * self.max_seq_len
        return kv_cache_size

    def _eval_runtime_buffer_mem_size(self):
        input_buffer = self.max_seq_len * self.gpt_init_params.hidden_size
        qkv_gemm_buffer_size = self.max_seq_len * (self.head_num_kv * 2 + self.head_num_kv) * self.size_per_head
        attn_buffer_size = self.max_seq_len * self.gpt_init_params.hidden_size
        ffn_export_num = self.expert_num if self.gpt_init_params.moe_k else 1
        ffn_w_count = 1 if self.activation_type == 'gelu' else 2
        ffn_buffer = (self.max_seq_len * self.gpt_init_params.hidden_size + ffn_w_count * self.max_seq_len * self.inter_size) * ffn_export_num
        return input_buffer + qkv_gemm_buffer_size + attn_buffer_size + ffn_buffer

    @property
    def model_param_count(self):
        return self.word_emb_param_count * 2 + self.layer_weight_param_count + self.gpt_init_params.hidden_size

    @property
    def word_emb_param_count(self):
        return self.vocab_size * self.gpt_init_params.hidden_size
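
    # Worked example for _eval_kv_cache_mem_size (illustrative shapes): with
    # layer_num=32, head_num_kv=32, size_per_head=128, max_seq_len=2048 and a
    # 2-byte kv cache dtype, the estimate is 2 * 32 * 32 * 128 * 2 * 2048
    # bytes, i.e. about 1 GiB for one full-length sequence.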

    @property
    def layer_weight_param_count(self):
        hidden_size = self.gpt_init_params.hidden_size
        layer_weight_param_count = 0
        # qkv
        if self.layer_head_num and isinstance(self.layer_head_num, list):
            for head_num in self.layer_head_num:
                layer_weight_param_count = layer_weight_param_count + head_num * self.size_per_head * hidden_size * 3
        elif self.head_num_kv != self.head_num:
            layer_weight_param_count = layer_weight_param_count + self.layer_num * hidden_size * hidden_size + \
                self.layer_num * (self.head_num_kv * self.size_per_head) * 2
        else:
            layer_weight_param_count = layer_weight_param_count + self.layer_num * hidden_size * hidden_size * 3

        # attn_o_w
        if self.layer_head_num and isinstance(self.layer_head_num, list):
            for head_num in self.layer_head_num:
                layer_weight_param_count = layer_weight_param_count + head_num * self.size_per_head * hidden_size
        else:
            layer_weight_param_count = layer_weight_param_count + self.layer_num * hidden_size * hidden_size

        # ffn w1, w2, w3
        ffn_export_num = self.expert_num if self.expert_num > 0 else 1
        ffn_w_count = 2 if self.activation_type == 'gelu' else 3
        if self.layer_inter_size and isinstance(self.layer_inter_size, list):
            for layer_inter_size in self.layer_inter_size:
                if self.moe_style == 1:
                    layer_weight_param_count = layer_weight_param_count + layer_inter_size * hidden_size * ffn_w_count * ffn_export_num
                else:
                    layer_weight_param_count = layer_weight_param_count + layer_inter_size * hidden_size * ffn_w_count
                    if self.moe_style == 2:
                        layer_weight_param_count = layer_weight_param_count + self.moe_inter_padding_size * hidden_size * ffn_w_count * ffn_export_num
        else:
            if self.moe_style == 1:
                layer_weight_param_count = layer_weight_param_count + self.layer_num * self.inter_size * hidden_size * ffn_w_count * ffn_export_num
            else:
                layer_weight_param_count = layer_weight_param_count + self.layer_num * self.inter_size * hidden_size * ffn_w_count
                if self.moe_style == 2:
                    layer_weight_param_count = layer_weight_param_count + len(self.moe_layer_index) * self.moe_inter_padding_size * hidden_size * ffn_w_count * ffn_export_num

        if ffn_export_num > 1:
            layer_weight_param_count = layer_weight_param_count + len(self.moe_layer_index) * hidden_size * ffn_export_num
        # other small tensors
        layer_weight_param_count = layer_weight_param_count + self.layer_num * hidden_size * 11
        return layer_weight_param_count
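

# A minimal sanity-check sketch (illustrative LLaMA-7B-like shapes, not tied to
# any real checkpoint): it mirrors the dense branch of layer_weight_param_count
# above without constructing GptInitModelParameters, assuming a non-gelu
# activation (ffn_w_count == 3) and no MoE layers.
if __name__ == "__main__":
    _hidden, _layers, _inter = 4096, 32, 11008
    _per_layer = _hidden * _hidden * 3       # qkv
    _per_layer += _hidden * _hidden          # attn_o_w
    _per_layer += _inter * _hidden * 3       # ffn w1, w2, w3
    _per_layer += _hidden * 11               # other small tensors
    logging.info(f"rough dense layer params: {_layers * _per_layer}")  # ~6.5e9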