# maga_transformer/models/base_model.py
import os
import torch
import json
import logging
import math
import torch.nn.functional as F
from pydantic import BaseModel as PyBaseModel
from typing import Any, Dict, List, Optional, Union, NamedTuple
from transformers import PreTrainedTokenizerBase, AutoTokenizer
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters, ConfigMode
from maga_transformer.config.generate_config import GenerateConfig
from maga_transformer.config.task_type import TaskType
from maga_transformer.distribute.gang_info import get_gang_info
from maga_transformer.distribute.worker_info import ParallelInfo, g_parallel_info
from maga_transformer.models.downstream_modules.custom_module import CustomModule
from maga_transformer.models.downstream_modules.utils import create_custom_module
from maga_transformer.models.multimodal.multimodal_mixin import MultiModalMixin
from maga_transformer.utils.fuser import fetch_remote_file_to_local
from maga_transformer.utils.util import to_torch_dtype
from maga_transformer.utils.model_weight import W
from maga_transformer.model_loader.model_weight_info import ModelDeployWeightInfo, ModelWeights
from maga_transformer.model_loader.load_config import LoadConfig
from maga_transformer.model_loader.loader import ModelLoader, get_model_loader
from maga_transformer.utils.weight_type import WEIGHT_TYPE
from maga_transformer.utils.multimodal_util import MultimodalInput
from maga_transformer.utils.database import CkptDatabase
from maga_transformer.utils.time_util import Timer
from maga_transformer.eplb.ep_balancer import ExpertBalancer
FT_DEFAULT_MAX_NEW_TOKENS = 2048
class EmbeddingOutput:
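    """Container for one request's text embedding and optional extra input tensors."""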
text_embedding: torch.Tensor
extra_input: Optional[torch.Tensor]
def __init__(self, text_embedding: torch.Tensor, extra_input: Optional[List[torch.Tensor]]):
self.text_embedding = text_embedding
if extra_input:
try:
self.extra_input = torch.concat(extra_input)
self.extra_input = torch.Tensor(self.extra_input.shape[1:])
            except Exception as e:
                raise ValueError("extra_input tensors must have the same shape except dim 0") from e
else:
self.extra_input = None
# single batch prompt input
class GenerateInput(PyBaseModel):
request_id: int
token_ids: torch.Tensor
mm_inputs: List[MultimodalInput]
generate_config: GenerateConfig
tokenizer: Any = None # TODO: remove this
prefix_length: int = 0
token_type_ids: List[int] = []
class Config:
arbitrary_types_allowed = True
@property
def input_length(self):
return self.token_ids.shape[-1]
@property
def prompt_length(self):
return self.token_ids.shape[-1] - self.prefix_length
def update_prefix(self, prefix_tokens: torch.Tensor):
self.token_ids = torch.concat([prefix_tokens, self.token_ids], dim=0)
self.prefix_length = prefix_tokens.nelement()
class AuxInfo(PyBaseModel):
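    """Auxiliary per-request statistics (timings, token counts, log probs) returned with generation output."""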
cost_time: float = 0
iter_count: int = 0
prefix_len: int = 0
input_len: int = 0
reuse_len: int = 0
output_len: int = 0
step_output_len: int = 0
fallback_tokens: int = 0
fallback_times: int = 0
first_token_cost_time: float = 0
wait_time: float = 0
pd_sep: bool = False
cum_log_probs: List[float] = []
beam_responses: List[str] = []
softmax_probs: List[float] = []
class GenerateOutput(PyBaseModel):
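    """Generation result for one request: output token ids plus optional hidden states, logits, loss and aux info."""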
hidden_states: Optional[torch.Tensor] = None
output_ids: Optional[torch.Tensor] = None
input_ids: Optional[torch.Tensor] = None
finished: bool = False
aux_info: AuxInfo = AuxInfo()
loss: Optional[torch.Tensor] = None
logits: Optional[torch.Tensor] = None
all_probs: Optional[torch.Tensor] = None
class Config:
arbitrary_types_allowed = True
class GenerateOutputs(PyBaseModel):
generate_outputs: List[GenerateOutput] = []
class GenerateResponse(PyBaseModel):
generate_outputs: GenerateOutputs
generate_texts: List[str]
class GenerateContext(NamedTuple):
inputs: Any
input_embeds: Any
attention_mask: Any
pad_lengths: Any
input_lengths: Any
memory_length: Any
sampler: Any
batch_size: Any
beam_width: Any
max_input_length: Any
finished: Any
sequence_lengths: Any
gen_length: Any
cum_log_probs: Any
extra_args: Any
all_start_time: Any
cache_indirection: Any
output_token_ids: Any
class ModelConfig:
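    """User-facing model construction options: checkpoint/tokenizer paths, weight and activation types, LoRA, ptuning, etc."""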
def __init__(
self,
model_type: str = "",
ckpt_path: str = "",
tokenizer_path: str = "",
weight_type: WEIGHT_TYPE = WEIGHT_TYPE.FP16,
act_type: WEIGHT_TYPE = WEIGHT_TYPE.FP16,
max_seq_len: int = 0,
seq_size_per_block: int = 8,
gen_num_per_circle: int = 1,
ptuning_path: Optional[str] = None,
lora_infos: Optional[Dict[str, str]] = None,
ref_module: Optional[torch.nn.Module] = None,
ref_dict: Dict[str, torch.Tensor] = {},
sp_type: str = "",
):
self.model_type: str = model_type
self.ckpt_path: str = ckpt_path
self.tokenizer_path: str = tokenizer_path
self.weight_type: WEIGHT_TYPE = weight_type
self.act_type: WEIGHT_TYPE = act_type
self.max_seq_len: int = max_seq_len
self.seq_size_per_block: int = seq_size_per_block
self.gen_num_per_circle: int = gen_num_per_circle
self.ptuning_path: Optional[str] = ptuning_path
self.lora_infos: Optional[Dict[str, str]] = lora_infos
self.ref_module: Optional[torch.nn.Module] = ref_module
self.ref_dict: Dict[str, torch.Tensor] = ref_dict
self.sp_type: str = sp_type
@property
def int8_mode(self):
        return self.weight_type == WEIGHT_TYPE.INT8
def add_ref_module(self, ref_module: Optional[torch.nn.Module]):
self.ref_module = ref_module
def add_ref_dict(self, ref_dict: Dict[str, torch.Tensor]):
self.ref_dict = ref_dict
def _replace(self, **kwargs: Any):
for k, v in kwargs.items():
if k in self.__dict__:
self.__dict__[k] = v
return self
def get_slopes(n: int) -> List[float]:
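    """Compute ALiBi attention-bias slopes for n heads.

    For a power-of-two n the slopes form the geometric sequence 2**(-8/n), 2**(-16/n), ...;
    otherwise the slopes of the closest lower power of two are extended with interleaved
    slopes taken from 2 * closest_power_of_2 (e.g. get_slopes(4) == [2**-2, 2**-4, 2**-6, 2**-8]).
    """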
def get_slopes_power_of_2(n: int) -> List[float]:
start = (2 ** (-2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + \
get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
class BaseModel(object):
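    """Base class for all served models: creates the GptInitModelParameters config, loads the
    tokenizer and weights, and prepares derived tensors such as lm_head, positional embeddings
    and ALiBi slopes."""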
config: GptInitModelParameters
vocab_size_padded: int
device: str
def __init__(self, config: GptInitModelParameters) -> None:
self.config = config
self.weight = None
self.linear_bias_slopes: Optional[torch.Tensor] = None
self.prefix_tokens: Optional[torch.Tensor] = None
self.tokenizer: Optional[PreTrainedTokenizerBase] = None
self.max_input_buffer_len: int = 0
self.task_type: TaskType = TaskType.LANGUAGE_MODEL
self.custom_module: Optional[CustomModule] = None
self.default_generate_config: GenerateConfig = GenerateConfig()
self.load_tokenizer()
def _load_to_device(self, parallel_info: ParallelInfo=g_parallel_info):
self.parallel_info = parallel_info
self.device = self.parallel_info.device
self.may_init_multimodal()
self.init_misc()
self.load(self.device)
@classmethod
def create_config(cls, model_config: ModelConfig,
parallel_info:ParallelInfo=g_parallel_info,
config_mode: ConfigMode = ConfigMode.ComplexMode) -> GptInitModelParameters:
config: GptInitModelParameters = cls._create_config(model_config.ckpt_path)
cls._load_quant_config(model_config.ckpt_path, config)
if config.hidden_size == 0:
config.hidden_size = config.size_per_head * config.head_num
config.update_common(
ckpt_path=model_config.ckpt_path,
tokenizer_path=model_config.tokenizer_path,
int8_mode=model_config.int8_mode,
data_type=model_config.act_type,
max_seq_len=model_config.max_seq_len,
seq_size_per_block=model_config.seq_size_per_block,
gen_num_per_circle=model_config.gen_num_per_circle,
lora_infos=model_config.lora_infos,
ptuning_path=model_config.ptuning_path,
ref_module=model_config.ref_module,
ref_dict=model_config.ref_dict,
parallel_info=parallel_info,
gang_info=get_gang_info(),
config_mode=config_mode
)
cls._update_config(config)
return config
@classmethod
def _create_config(cls, ckpt_path: str) -> GptInitModelParameters:
raise NotImplementedError()
@classmethod
def _update_config(cls, config: GptInitModelParameters):
pass
@staticmethod
def _load_quant_config(ckpt_path: str, config: GptInitModelParameters):
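        """Detect quantization settings from smoothquant.ini / pertensorquant.ini / config.json and set them on config.quant_algo."""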
quant_config_path = os.path.join(ckpt_path, 'smoothquant.ini')
if os.path.exists(quant_config_path):
config.quant_algo.setQuantAlgo('smooth_quant', 0, 0)
per_tensor_config_path = os.path.join(ckpt_path, "pertensorquant.ini")
if os.path.exists(per_tensor_config_path):
config.quant_algo.setQuantAlgo('pertensor_quant', 0, 0)
config_path = os.path.join(ckpt_path, "config.json")
if not os.path.exists(config_path):
return
        with open(config_path) as f:
            config_json = json.load(f)
quant_config = None
quant_method = None
if config_json.get("quantization_config", None):
quant_config = config_json["quantization_config"]
quant_method = quant_config['quant_method'].lower()
if config_json.get("quantization", None):
quant_config = config_json["quantization"]
quant_method = quant_config['quant_algo'].lower()
if quant_config is None:
return
        group_size = quant_config.get('group_size', 0)
        bits = quant_config.get('bits', 0)
if quant_method == 'fp8':
bits = 8
if 'weight_block_size' in quant_config:
weight_block = quant_config.get("weight_block_size")
assert isinstance(weight_block, list) and all(element == weight_block[0] for element in weight_block), f"weight_block_size: {weight_block} must be same"
group_size = weight_block[0]
config.quant_algo.setQuantAlgo(quant_method, bits, group_size)
@classmethod
def from_config(cls, config: Any, parallel_info:ParallelInfo=g_parallel_info) -> 'BaseModel':
model = cls(config)
model._load_to_device(parallel_info)
return model
@staticmethod
def get_weight_cls() -> ModelDeployWeightInfo:
raise NotImplementedError
@property
def dtype(self) -> Union[str, torch.dtype]:
assert self.weight is not None
return self.weight.dtype
def may_init_multimodal(self):
if self.is_multimodal():
assert isinstance(self, MultiModalMixin) # for syntax check
self.config.is_multimodal = True
if self.parallel_info.tp_rank == 0:
self.init_multimodal(self.config)
def init_misc(self):
self.task_type = self.config.task_type
self.custom_module = self.load_custom_module()
self.compute_dtype: torch.dtype = to_torch_dtype(self.config.data_type)
def split_slopes_tp(self, slopes: torch.Tensor):
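        """Slice the per-head ALiBi slopes so each tensor-parallel rank keeps only its own heads."""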
local_head_num = 1 if self.config.head_num == 1 else self.config.head_num // self.parallel_info.tp_size
start_pos = local_head_num * self.parallel_info.tp_rank
return slopes[start_pos: start_pos + local_head_num]
@classmethod
def get_tokenizer(cls, config: GptInitModelParameters):
assert config.tokenizer_path
return AutoTokenizer.from_pretrained(config.tokenizer_path, trust_remote_code=True)
def load_tokenizer(self):
if not self.config.tokenizer_path:
self.tokenizer = None
return
def error_handler(func: Any):
def wrapper(*args: Any, **kwargs: Any):
try:
return func(*args, **kwargs)
except Exception as e:
method_name = func.__name__
raise RuntimeError(f"{method_name} failed, with input args: {args}, kwargs: {kwargs}")
return wrapper
self.tokenizer = self.get_tokenizer(self.config)
if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id:
self.config.special_tokens.eos_token_id = self.tokenizer.eos_token_id
self.config.update_task_prompt_tokens_id(self.tokenizer)
if getattr(self.tokenizer, 'encode', None):
self.tokenizer.encode = error_handler(self.tokenizer.encode)
if getattr(self.tokenizer, 'decode', None):
self.tokenizer.decode = error_handler(self.tokenizer.decode)
def is_multimodal(self) -> bool:
return isinstance(self, MultiModalMixin)
def init_database(self):
self.database = CkptDatabase(self.config.ckpt_path, self.config.ptuning_path)
def load_static_lora(self):
# static lora load
self.static_lora: bool = self.config.lora_infos is not None and len(self.config.lora_infos) == 1
if self.static_lora:
for name, path in self.config.lora_infos.items():
self.database.load_lora(name, path)
self.database.dump_lora_info()
def load_model_weight(self):
weights_info = self.get_weight_cls()(self.config, self.parallel_info.tp_size, self.parallel_info.tp_rank)
self.model_weights_loader = ModelLoader(weights_info, self.compute_dtype, self.database)
self.weight: ModelWeights = self.model_weights_loader.load_weights(device=self.device)
self._load_custom_module_weights(self.model_weights_loader)
def _load_weights(self,
ref_dict: Dict[str, torch.Tensor] = {}):
with Timer() as timer:
self.init_database()
self.load_static_lora()
self.load_model_weight()
logging.info(f'load weights time: {timer.cost_ms() / 1000 :.2f} s')
def load_custom_module(self) -> Optional[CustomModule]:
return create_custom_module(self.task_type, self.config, self.tokenizer)
def _load_custom_module_weights(self, model_weights_loader: ModelLoader):
if self.custom_module is not None:
tensor_names = self.custom_module.handler.tensor_info()
tensor_map: Dict[str, torch.Tensor] = {}
for name in tensor_names:
loaded_tensor = model_weights_loader.load_raw_tensor(name, device=self.device)
tensor_map[name] = loaded_tensor
self.weight.set_global_weight(name, loaded_tensor)
self.custom_module.handler.init(tensor_map)
def _initialize_weights(self):
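        """Post-process loaded weights: derive embedding size, fall back to tied embeddings for lm_head,
        normalize/scale lm_head if configured, truncate positional embeddings to max_seq_len, build ALiBi
        slopes and validate per-tensor quantization tensors."""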
assert (self.weight is not None)
embedding_weight = self.weight.global_weights.get(W.embedding, None)
        if embedding_weight is not None:
self.config.embedding_size = embedding_weight.shape[0]
logging.info(f"embedding_size is {self.config.embedding_size}, vocab size is {self.config.vocab_size}")
if self.config.vit_separation != 2 and self.is_multimodal():
self.load_mm_weight(self.compute_dtype, self.device)
if self.config.vit_separation != 1:
if self.task_type == TaskType.LANGUAGE_MODEL:
lm_head_w = self.weight.steal_global_weight(W.lm_head)
                if lm_head_w is None:
lm_head_w = self.weight.global_weights[W.embedding]
if self.config.normalize_lm_head_weight:
lm_head_w = F.normalize(lm_head_w)
if self.config.logit_scale != 1.0:
lm_head_w = self.config.scale_logit * lm_head_w
self.weight.set_global_weight(W.lm_head, lm_head_w)
else:
# Some LLM can be used for other tasks, e.g. classification, in which case lm_head is not needed
self.weight.steal_global_weight(W.lm_head)
pos_weight = self.weight.global_weights.get(W.positional_embedding, None)
        if pos_weight is not None:
            if pos_weight.shape[0] < self.config.max_seq_len:
                raise Exception(f"positional_embedding has shape {pos_weight.shape}, but max_seq_len {self.config.max_seq_len} exceeds {pos_weight.shape[0]}")
pos_weight = pos_weight[:self.config.max_seq_len].to(self.device)
self.weight.set_global_weight(W.positional_embedding, pos_weight)
if self.config.use_attention_linear_bias:
slopes = torch.Tensor(get_slopes(self.config.head_num))
slopes = self.split_slopes_tp(slopes)
self.linear_bias_slopes = slopes.to(torch.float).to(self.device)
self.weight.set_global_weight(W.linear_bias_slopes, self.linear_bias_slopes)
        if self.config.quant_algo.isPerTensorQuant() and \
            (self.weight.global_weights.get(W.pre_decoder_ln_static_quant, None) is None or \
             self.weight.global_weights.get(W.pre_decoder_ln_static_quant_reciprocal, None) is None):
            raise Exception("pre_decoder_ln_static_quant and pre_decoder_ln_static_quant_reciprocal "
                            "are required for per tensor quantization")
ModelLoader.force_clean_cuda_memory()
def _initialize_rope(self):
pass
def init_redundant_expert(self):
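        """Create the physical-to-logical expert mapping (phy2log) when redundant experts are configured for MoE load balancing."""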
if self.config.expert_num == 0:
return
expert_num = self.config.expert_num
ep_size = self.parallel_info.ep_size
layer_num = self.config.layer_num
phy_exp_num = self.config.phy_exp_num
phy2log = LoadConfig.create_redundant_expert(layer_num=layer_num,
expert_num=expert_num,
phy_exp_num=phy_exp_num,
ep_size=ep_size,
num_nodes=self.config.num_nodes)
self.config.phy2log = phy2log
def init_eplb_weight(self, weight: ModelWeights):
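        """Build per-layer log2phy and logic_expert_cnt tensors from phy2log and attach them to the model weights."""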
expert_num = self.config.expert_num
redundant_expert = self.config.phy_exp_num - expert_num
layer_num = self.config.layer_num
phy2log = self.config.phy2log
if expert_num == 0 or redundant_expert == 0:
return
# init logic_expert_cnt and log2phy
for layer_id in range(layer_num):
logic_expert_cnt = torch.zeros((expert_num,), dtype=torch.int32)
            log2phy = torch.full((expert_num, redundant_expert + 1), -1, dtype=torch.int32)
layer_phy2log = phy2log[layer_id]
for phy_exp_id, expert_id in enumerate(layer_phy2log):
cnt = logic_expert_cnt[expert_id]
log2phy[expert_id, cnt] = phy_exp_id
logic_expert_cnt[expert_id] += 1
weight.weights[layer_id][W.logic_expert_cnt] = logic_expert_cnt.contiguous().to(self.device)
weight.weights[layer_id][W.log2phy] = log2phy.contiguous().to(self.device)
def init_eplb_config(self, compute_dtype: torch.dtype):
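        """Initialize redundant experts and, when EPLB is enabled, create the ExpertBalancer over the original checkpoint."""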
self.init_redundant_expert()
if self.config.enable_eplb:
model_path = None
if self.config.is_mtp:
model_path = self.config.ckpt_path
else:
model_path = fetch_remote_file_to_local(
os.environ.get(
"ORIGINAL_CHECKPOINT_PATH", self.config.ckpt_path
)
)
weights_info: ModelDeployWeightInfo = self.get_weight_cls()(self.config, self.parallel_info.tp_size, self.parallel_info.tp_rank)
ep_lb_database = CkptDatabase(model_path)
self.ep_balancer = ExpertBalancer(
weights_info=weights_info,
compute_dtype=compute_dtype,
database=ep_lb_database
)
self.config.py_eplb = self.ep_balancer
def load(self, device: str):
self.init_eplb_config(compute_dtype=self.compute_dtype)
self._load_weights()
self.init_eplb_weight(self.weight)
self._initialize_rope()
self._initialize_weights()
def dup_dim0_for_beam_search(self, t: torch.Tensor, beam_width: int) -> torch.Tensor:
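        """Repeat a tensor beam_width times along dim 0, interleaving copies per batch entry for beam search."""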
shape = list(t.shape)
return t.unsqueeze(1).repeat([1, beam_width] + [1] * len(shape[1:])).reshape([-1] + shape[1:]).contiguous()
def extend_context_combo_token_types(self, token_types: List[int]) -> List[int]:
return []
def extend_generate_combo_token_types(self, combo_tokens: List[int]) -> List[int]:
return []
def create_context_position_ids(self, input_lengths: Union[List[int], torch.Tensor]):
return torch.concat([torch.arange(int(input_length), dtype=torch.int32) for input_length in input_lengths], dim=0)
def create_context_decoder_mask(self, input_lengths: List[int]):
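        """Build a (batch, max_len, max_len) attention mask: lower-triangular when causal, with padded positions masked out."""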
batch_size = len(input_lengths)
max_input_length = max(input_lengths)
attention_mask = torch.ones(
(max_input_length, max_input_length), dtype=torch.bool, device=self.device)
if self.config.is_causal:
attention_mask = attention_mask.tril()
attention_mask = attention_mask.unsqueeze_(0).tile(batch_size, 1, 1).to(self.dtype)
for b, input_length in enumerate(input_lengths):
attention_mask[b, input_length:, ...] = 0
if not self.config.is_causal:
                attention_mask[b, :, input_length:] = 0
return attention_mask
def create_model_loader(self, parallel_info: ParallelInfo) -> ModelLoader:
self.parallel_info = parallel_info
self.device = self.parallel_info.device
self.may_init_multimodal()
self.init_misc()
self.init_database()
self.load_static_lora()
self.init_eplb_config(compute_dtype=self.compute_dtype)
tp_rank = self.parallel_info.tp_rank
tp_size = self.parallel_info.tp_size
weights_info: ModelDeployWeightInfo = self.get_weight_cls()(self.config, tp_size, tp_rank)
return get_model_loader(weights_info, self.compute_dtype, self.database)
@staticmethod
def eval_model_size(config: GptInitModelParameters):
return config.eval_model_size()
@staticmethod
def eval_model_param_count(config: GptInitModelParameters):
return config.model_param_count