def hf_qwen_converter()

in maga_transformer/utils/smooth_quant_convert/qwen/hf_qwen_convert.py


def hf_qwen_converter(args: ProgArgs, ret):
    infer_tp = args.tensor_parallelism
    multi_query_mode = args.model in ["santacoder", "starcoder"]
    saved_dir = Path(args.out_dir)
    saved_dir.mkdir(parents=True, exist_ok=True)

    # load the full HF model (conversion is driven from a single rank)
    model = AutoModelForCausalLM.from_pretrained(
        args.in_file,
        device_map="auto",  # if your GPU memory is not enough, set device_map="cpu"
        trust_remote_code=True,
        torch_dtype=str_dtype_to_torch(args.storage_type),
    ).half()  # if your GPU memory is not enough, change .half() to .float()
    model.generation_config = GenerationConfig.from_pretrained(
        args.in_file, trust_remote_code=True)
    act_range = {}
    qwen_smoother = {}
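    # activation ranges and smoothing scales are only collected when SmoothQuant
    # or INT8 KV-cache calibration is requested; otherwise they stay empty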
    if args.smoothquant is not None or args.calibrate_kv_cache:
        os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
            "TOKENIZERS_PARALLELISM", "false")

        dataset = datasets.load_from_disk(args.dataset_cache_dir)
        tokenizer = AutoTokenizer.from_pretrained(
            args.in_file,
            legacy=False,
            padding_side='left',
            trust_remote_code=True,
        )
        gen_config_path = os.path.join(args.in_file, 'generation_config.json')
        with open(gen_config_path, 'r') as f:
            gen_config = json.load(f)
        chat_format = gen_config['chat_format']
        tokenizer.pad_token_id = tokenizer.im_end_id
        # use this prompt to make the chat model produce summaries
        system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
        act_range = capture_activation_range(
            model,
            tokenizer,
            dataset,
            system_prompt=system_prompt,
            chat_format=chat_format,
            max_input_len=args.max_input_len,
        )
        if args.smoothquant is not None:
            smooth_qwen_model(model, act_range, args.smoothquant, qwen_smoother)

    config = configparser.ConfigParser()
    config["qwen"] = {}
    for key, value in vars(args).items():
        config["qwen"][key] = f"{value}"
    for k, v in vars(model.config).items():
        config["qwen"][k] = f"{v}"
    config["qwen"]["storage_dtype"] = args.storage_type
    config["qwen"]["multi_query_mode"] = str(multi_query_mode)
    with open(saved_dir / "smoothquant.ini", 'w') as configfile:
        config.write(configfile)

    storage_type = str_dtype_to_torch(args.storage_type)

    global_weights = ["transformer.wte.weight", "transformer.ln_f.weight", "lm_head.weight"]

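    # int8_outputs selects which INT8 artifacts split_and_save_weight emits:
    # only KV-cache scales when calibrating the KV cache, everything for SmoothQuant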
    int8_outputs = None
    if args.calibrate_kv_cache:
        int8_outputs = "kv_cache_only"
    if args.smoothquant is not None:
        int8_outputs = "all"

    starmap_args = []
    for name, param in tqdm(
            model.named_parameters(),
            desc="convert and save",
            total=len(list(model.parameters())),
            ncols=80,
    ):
        if "weight" not in name and "bias" not in name:
            continue
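        # map the HF parameter name onto the naming scheme used for the converted weights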
        converted_name = convert_qwen_name(name)
        if name.replace(".weight", "") in qwen_smoother:
            smoother = qwen_smoother[name.replace(".weight", "")]
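            # arguments handed to split_and_save_weight: rank (0), output dir,
            # tensor-parallel size, tensor name, the smoother tensor, storage dtype,
            # activation range (None here) and int8/multi-query options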
            starmap_arg = (
                0,
                saved_dir,
                infer_tp,
                f"{converted_name}.smoother".replace(".weight", ""),
                smoother,
                storage_type,
                None,
                {
                    "int8_outputs": int8_outputs,
                    "multi_query_mode": multi_query_mode,
                    "local_dim": None,
                },
            )
            if args.processes > 1:
                starmap_args.append(starmap_arg)
            else:
                result = split_and_save_weight(*starmap_arg)
                ret.update(result)

        param = transpose_weights(name, param)
        if converted_name in global_weights:
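            # global weights (embedding, final layernorm, lm_head) are returned whole
            # instead of being split per tensor-parallel rank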
            ret[converted_name] = param.to(storage_type).cpu()
        else:
            if 'q_attn' in name:
                param = concat_qkv_weight_bias(param, name, model)
                converted_name = converted_name.replace("query",
                                                        "query_key_value")
            # Needed by QKV projection weight split. With multi_query_mode one does not simply take
            # out_dim and divide it by 3 to get local_dim because out_dim = local_dim + 2 * head_size
            local_dim = model.transformer.h[0].attn.embed_dim if multi_query_mode else None
            starmap_arg = (0, saved_dir, infer_tp, converted_name,
                           param.to(storage_type), storage_type,
                           act_range.get(name.replace(".weight", "")), {
                               "int8_outputs": int8_outputs,
                               "multi_query_mode": multi_query_mode,
                               "local_dim": local_dim
                           })
            if args.processes > 1:
                starmap_args.append(starmap_arg)
            else:
                result = split_and_save_weight(*starmap_arg)
                ret.update(result)

    if args.processes > 1:
        starmap_args = tqdm(starmap_args, desc="saving weights")
        with multiprocessing.Pool(args.processes) as pool:
            results = pool.starmap(split_and_save_weight, starmap_args)
        for res in results:
            ret.update(res)
    return ret
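
Usage sketch (not part of the source file): a minimal driver for hf_qwen_converter, assuming that only the fields read above are needed on the args object and that a SimpleNamespace can stand in for ProgArgs. Paths, the SmoothQuant alpha, and the storage_type string are placeholders; use whatever values str_dtype_to_torch and your calibration dataset actually expect.

from types import SimpleNamespace

# hypothetical stand-in for ProgArgs; only the fields hf_qwen_converter reads are set
args = SimpleNamespace(
    in_file="/path/to/Qwen-7B-Chat",          # HF checkpoint directory
    out_dir="/tmp/qwen_smoothquant",          # where smoothquant.ini is written
    tensor_parallelism=1,
    model="qwen",
    storage_type="fp16",                      # must be accepted by str_dtype_to_torch
    smoothquant=0.5,                          # SmoothQuant alpha; None disables smoothing
    calibrate_kv_cache=False,
    dataset_cache_dir="/path/to/calib_set",   # directory readable by datasets.load_from_disk
    max_input_len=2048,
    processes=1,
)

weights = {}
hf_qwen_converter(args, weights)
# 'weights' now maps converted tensor names to torch tensors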