in maga_transformer/utils/smooth_quant_convert/llama/hf_llama_convert.py [0:0]
def hf_gpt_converter(args, ret):
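    """Convert a HuggingFace LLaMA checkpoint for SmoothQuant / int8 export.

    Loads the model, optionally runs a calibration pass to collect activation
    ranges (for SmoothQuant and/or an int8 KV cache), writes `smoothquant.ini`
    into `args.out_dir`, and hands each tensor to `split_and_save_weight` for
    `args.tensor_parallelism` ranks. Returns `ret` populated with the
    converted weights.
    """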
infer_tp = args.tensor_parallelism
saved_dir = Path(args.out_dir)
saved_dir.mkdir(parents=True, exist_ok=True)
torch_dtype = torch.float16 if args.storage_type == 'fp16' else torch.float32
model = LlamaForCausalLM.from_pretrained(args.in_file,
torch_dtype=torch_dtype,
device_map="auto",
trust_remote_code=True)
if args.load_model_on_cpu:
model = model.float()
model = model.cpu()
torch.cuda.empty_cache()
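    # calibration state, filled in below only when SmoothQuant / int8 KV cache is enabled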
act_range = {}
llama_qkv_para = {}
    # SmoothQuant smoothers for the inputs of self_attn.o_proj and mlp.down_proj
llama_smoother = {}
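    # Run a calibration pass over the dataset to record activation ranges
    # whenever int8 activations or an int8 KV cache are requested.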
if args.smoothquant is not None or args.calibrate_kv_cache:
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
"TOKENIZERS_PARALLELISM", "false")
# dataset = load_dataset("ccdv/cnn_dailymail",
# '3.0.0',
# cache_dir=args.dataset_cache_dir)
dataset = datasets.load_from_disk(args.dataset_cache_dir)
act_range = capture_activation_range(
model,
LlamaTokenizer.from_pretrained(args.in_file, padding_side='left'),
dataset)
if args.smoothquant is not None:
smooth_llama_model(model, act_range, args.smoothquant,
llama_qkv_para, llama_smoother)
args.multi_query_mode = model.config.num_attention_heads != model.config.num_key_value_heads
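    # Record the conversion settings (CLI args plus the HF model config) in
    # smoothquant.ini next to the exported weights.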
config = configparser.ConfigParser()
config["llama"] = {}
for key in vars(args):
config["llama"][key] = f"{vars(args)[key]}"
for k, v in vars(model.config).items():
config["llama"][k] = f"{v}"
config["llama"]["weight_data_type"] = args.storage_type
config["llama"]["multi_query_mode"] = str(args.multi_query_mode)
with open(saved_dir / "smoothquant.ini", 'w') as configfile:
config.write(configfile)
storage_type = str_to_np_dtype(args.storage_type)
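    # These tensors are kept whole (not split across tensor-parallel ranks)
    # and returned in `ret` instead of being queued for split_and_save_weight.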
global_bin_weights = [
'model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight'
]
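    # Which int8 scaling factors split_and_save_weight should produce:
    # None (no int8), "kv_cache_only", or "all".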
int8_outputs = None
if args.calibrate_kv_cache:
int8_outputs = "kv_cache_only"
if args.smoothquant is not None:
int8_outputs = "all"
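    # Build one task tuple per exported tensor; each tuple matches the
    # argument list of split_and_save_weight.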
starmap_args = []
for name, param in model.named_parameters():
if "weight" not in name and "bias" not in name:
continue
bin_name = gpt_to_bin_name(name)
if args.convert_model_on_cpu:
param = param.cpu()
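        # Export the SmoothQuant smoothing scales recorded for this module's input.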
if name.replace(".weight", "") in llama_smoother.keys():
smoother = llama_smoother[name.replace(".weight", "")]
smoother = smoother.detach().cpu().numpy()
starmap_args.append((0, saved_dir, infer_tp,
f"{bin_name}.smoother".replace(".weight", ""),
smoother, None, {
"int8_outputs": int8_outputs,
"multi_query_mode": args.multi_query_mode,
"local_dim": None,
}))
param = transpose_weights(name, param)
param = param.detach().cpu().numpy().astype(storage_type)
if bin_name in global_bin_weights:
ret[bin_name] = torch.from_numpy(np.asarray(param))
# param.tofile(saved_dir / f"{bin_name}.bin")
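        # q_proj maps to the fused query_key_value tensor: export the merged
        # QKV weight together with its activation ranges.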
elif bin_name.split('.')[-2] == 'query_key_value':
            # Are there other ways to get local_dim? local_dim == hidden_size for LLaMA-2.
local_dim = model.config.hidden_size if args.multi_query_mode else None
if args.smoothquant is None:
merge_qkv_scales(name, model, act_range, llama_qkv_para)
qkv = (0, saved_dir, infer_tp, bin_name,
llama_qkv_para.get(
name.replace(".weight", "").replace(
".q_proj",
".qkv_proj")).cpu().numpy().astype(storage_type),
act_range.get(
name.replace(".weight",
"").replace(".q_proj", ".qkv_proj")), {
"int8_outputs": int8_outputs,
"multi_query_mode":
args.multi_query_mode,
"local_dim": local_dim,
})
starmap_args.append(qkv)
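        # k_proj / v_proj are already covered by the fused QKV export above, so skip them.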
elif bin_name.split('.')[-2] == 'kv':
continue
else:
starmap_args.append((0, saved_dir, infer_tp, bin_name, param,
act_range.get(name.replace(".weight", "")), {
"int8_outputs": int8_outputs,
"multi_query_mode": args.multi_query_mode,
"local_dim": None,
}))
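    # Write everything out, either in a process pool or sequentially.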
starmap_args = tqdm(starmap_args, desc="saving weights")
if args.processes > 1:
with multiprocessing.Pool(args.processes) as pool:
results = pool.starmap(split_and_save_weight, starmap_args)
for res in results:
ret.update(res)
else:
        # sequential path: simpler to step through when debugging
for starmap_arg in starmap_args:
result = split_and_save_weight(*starmap_arg)
ret.update(result)
return ret