in maga_transformer/utils/smooth_quant_convert/qwen/hf_qwen_convert.py [0:0]
def hf_qwen_converter(args: ProgArgs, ret):
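    """Convert a HuggingFace Qwen checkpoint into SmoothQuant-ready weights.

    Loads the HF model, optionally calibrates activation ranges for SmoothQuant
    and/or INT8 KV cache, then splits each parameter across `infer_tp` ranks via
    split_and_save_weight, accumulating the converted tensors into `ret`.
    """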
infer_tp = args.tensor_parallelism
    multi_query_mode = args.model in ["santacoder", "starcoder"]
saved_dir = Path(args.out_dir)
saved_dir.mkdir(parents=True, exist_ok=True)
    # load the full HF model; its weights are split per tensor-parallel rank below
    model = AutoModelForCausalLM.from_pretrained(
        args.in_file,
        device_map="auto",  # if your GPU memory is not enough, set device_map="cpu"
        trust_remote_code=True,
        torch_dtype=str_dtype_to_torch(args.storage_type),
    ).half()  # if your GPU memory is not enough, change .half() to .float()
model.generation_config = GenerationConfig.from_pretrained(
args.in_file, trust_remote_code=True)
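    # act_range: per-tensor activation statistics gathered during calibration;
    # qwen_smoother: per-layer smoothing scales filled in by smooth_qwen_model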
act_range = {}
qwen_smoother = {}
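    # both SmoothQuant and INT8 KV-cache calibration need activation statistics,
    # so run a calibration pass over the dataset in either case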
if args.smoothquant is not None or args.calibrate_kv_cache:
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
"TOKENIZERS_PARALLELISM", "false")
dataset = datasets.load_from_disk(args.dataset_cache_dir)
tokenizer = AutoTokenizer.from_pretrained(
args.in_file,
legacy=False,
padding_side='left',
trust_remote_code=True,
)
gen_config_path = os.path.join(args.in_file, 'generation_config.json')
with open(gen_config_path, 'r') as f:
gen_config = json.load(f)
chat_format = gen_config['chat_format']
tokenizer.pad_token_id = tokenizer.im_end_id
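        # Qwen's tokenizer has no dedicated pad token; reuse <|im_end|> for the
        # left-padded calibration batches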
        # use this prompt to make the chat model produce summaries
        system_prompt = "You are a helpful assistant. Please directly output a summary of the article entered by the user."
act_range = capture_activation_range(
model,
tokenizer,
dataset,
system_prompt=system_prompt,
chat_format=chat_format,
max_input_len=args.max_input_len,
)
if args.smoothquant is not None:
smooth_qwen_model(model, act_range, args.smoothquant, qwen_smoother)
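    # persist both the conversion arguments and the HF model config so the
    # downstream loader knows how the exported weights were produced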
config = configparser.ConfigParser()
config["qwen"] = {}
for key in vars(args):
config["qwen"][key] = f"{vars(args)[key]}"
for k, v in vars(model.config).items():
config["qwen"][k] = f"{v}"
config["qwen"]["storage_dtype"] = args.storage_type
config["qwen"]["multi_query_mode"] = str(multi_query_mode)
with open(saved_dir / "smoothquant.ini", 'w') as configfile:
config.write(configfile)
storage_type = str_dtype_to_torch(args.storage_type)
global_weights = ["transformer.wte.weight", "transformer.ln_f.weight", "lm_head.weight"]
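    # int8_outputs selects which INT8 scales split_and_save_weight emits:
    # None (no INT8), "kv_cache_only", or "all" when SmoothQuant is active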
int8_outputs = None
if args.calibrate_kv_cache:
int8_outputs = "kv_cache_only"
if args.smoothquant is not None:
int8_outputs = "all"
starmap_args = []
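    # single-process mode converts and saves inline; multi-process mode collects
    # the starmap arguments here and fans them out to a pool after the loop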
for name, param in tqdm(
model.named_parameters(),
desc="convert and save",
total=len(list(model.parameters())),
ncols=80,
):
if "weight" not in name and "bias" not in name:
continue
converted_name = convert_qwen_name(name)
if name.replace(".weight", "") in qwen_smoother.keys():
smoother = qwen_smoother[name.replace(".weight", "")]
starmap_arg = (
0,
saved_dir,
infer_tp,
f"{converted_name}.smoother".replace(".weight", ""),
smoother,
storage_type,
None,
{
"int8_outputs": int8_outputs,
"multi_query_mode": multi_query_mode,
"local_dim": None,
},
)
if args.processes > 1:
starmap_args.append(starmap_arg)
else:
result = split_and_save_weight(*starmap_arg)
ret.update(result)
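        # some HF weights are stored transposed; transpose_weights selects the
        # affected tensors by name and returns them in the converter's layout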
param = transpose_weights(name, param)
if converted_name in global_weights:
ret[converted_name] = param.to(storage_type).cpu()
else:
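            # checkpoints that keep the query projection separate (q_attn) get it
            # fused with key/value into a single query_key_value tensor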
if 'q_attn' in name:
param = concat_qkv_weight_bias(param, name, model)
converted_name = converted_name.replace("query",
"query_key_value")
# Needed by QKV projection weight split. With multi_query_mode one does not simply take
# out_dim and divide it by 3 to get local_dim because out_dim = local_dim + 2 * head_size
local_dim = model.transformer.h[
0].attn.embed_dim if multi_query_mode else None
            starmap_arg = (
                0,
                saved_dir,
                infer_tp,
                converted_name,
                param.to(storage_type),
                storage_type,
                act_range.get(name.replace(".weight", "")),
                {
                    "int8_outputs": int8_outputs,
                    "multi_query_mode": multi_query_mode,
                    "local_dim": local_dim,
                },
            )
if args.processes > 1:
starmap_args.append(starmap_arg)
else:
result = split_and_save_weight(*starmap_arg)
ret.update(result)
if args.processes > 1:
        # pool.starmap materializes the argument list up front, so wrapping it in
        # tqdm would only track dispatch; the call blocks until all weights are saved
        with multiprocessing.Pool(args.processes) as pool:
            results = pool.starmap(split_and_save_weight, starmap_args)
for res in results:
ret.update(res)
return ret
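
# A minimal usage sketch (a hedged example: it assumes ProgArgs can be built with
# keyword arguments matching the fields this function reads; the values below are
# illustrative, not defaults):
#
#   args = ProgArgs(
#       model="qwen",
#       in_file="Qwen/Qwen-7B-Chat",
#       out_dir="./qwen-sq-out",
#       tensor_parallelism=2,
#       storage_type="fp16",
#       smoothquant=0.5,
#       calibrate_kv_cache=False,
#       dataset_cache_dir="./calibration-dataset",
#       max_input_len=2048,
#       processes=4,
#   )
#   weights = hf_qwen_converter(args, {})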