def hf_gpt_converter(args, ret)

in maga_transformer/utils/smooth_quant_convert/llama/hf_llama_convert.py

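This excerpt shows only the function body. The import list below is an assumption reconstructed from the names used in the code, not a verbatim copy of the file header; the helper functions it mentions are defined elsewhere in the smooth_quant_convert utilities.

import configparser
import multiprocessing
import os
from pathlib import Path

import datasets
import numpy as np
import torch
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer

# Helpers such as capture_activation_range, smooth_llama_model, gpt_to_bin_name,
# transpose_weights, merge_qkv_scales, str_to_np_dtype and split_and_save_weight
# come from the surrounding conversion utilities.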

def hf_gpt_converter(args, ret):
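    # Convert a HuggingFace LLaMA checkpoint into per-tensor-parallel-rank
    # binary weight files: optionally calibrate activation ranges for
    # SmoothQuant / INT8 KV cache, dump a smoothquant.ini config, then split
    # and save every parameter. Global (non-split) tensors are returned
    # through `ret`.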
    infer_tp = args.tensor_parallelism
    saved_dir = Path(args.out_dir)
    saved_dir.mkdir(parents=True, exist_ok=True)

    torch_dtype = torch.float16 if args.storage_type == 'fp16' else torch.float32
    model = LlamaForCausalLM.from_pretrained(args.in_file,
                                             torch_dtype=torch_dtype,
                                             device_map="auto",
                                             trust_remote_code=True)

    if args.load_model_on_cpu:
        model = model.float()
        model = model.cpu()
        torch.cuda.empty_cache()

    act_range = {}
    llama_qkv_para = {}
    # smoother for inputs of self_attn.o_proj and mlp.down_proj
    llama_smoother = {}

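    # When SmoothQuant or INT8 KV-cache quantization is requested, run a
    # calibration pass over the dataset to capture per-tensor activation
    # ranges, then (for SmoothQuant) fold the smoothing scales into the model.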
    if args.smoothquant is not None or args.calibrate_kv_cache:
        os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
            "TOKENIZERS_PARALLELISM", "false")
        # dataset = load_dataset("ccdv/cnn_dailymail",
        #                        '3.0.0',
        #                        cache_dir=args.dataset_cache_dir)
        dataset = datasets.load_from_disk(args.dataset_cache_dir)
        act_range = capture_activation_range(
            model,
            LlamaTokenizer.from_pretrained(args.in_file, padding_side='left'),
            dataset)
        if args.smoothquant is not None:
            smooth_llama_model(model, act_range, args.smoothquant,
                               llama_qkv_para, llama_smoother)

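    # Record the conversion arguments and the HuggingFace model config in
    # smoothquant.ini next to the converted weights.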
    args.multi_query_mode = model.config.num_attention_heads != model.config.num_key_value_heads
    config = configparser.ConfigParser()
    config["llama"] = {}
    for key in vars(args):
        config["llama"][key] = f"{vars(args)[key]}"
    for k, v in vars(model.config).items():
        config["llama"][k] = f"{v}"
    config["llama"]["weight_data_type"] = args.storage_type
    config["llama"]["multi_query_mode"] = str(args.multi_query_mode)
    with open(saved_dir / "smoothquant.ini", 'w') as configfile:
        config.write(configfile)

    storage_type = str_to_np_dtype(args.storage_type)

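    # These tensors are not split across tensor-parallel ranks; they are
    # returned through `ret` instead of being written by split_and_save_weight.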
    global_bin_weights = [
        'model.embed_tokens.weight', 'model.norm.weight', 'lm_head.weight'
    ]

    int8_outputs = None
    if args.calibrate_kv_cache:
        int8_outputs = "kv_cache_only"
    if args.smoothquant is not None:
        int8_outputs = "all"

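    # Build one task per parameter; each task is later executed by
    # split_and_save_weight, either in a process pool or serially.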
    starmap_args = []
    for name, param in model.named_parameters():
        if "weight" not in name and "bias" not in name:
            continue
        bin_name = gpt_to_bin_name(name)

        if args.convert_model_on_cpu:
            param = param.cpu()
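        # Emit the per-channel smoother recorded for this layer's input
        # (self_attn.o_proj / mlp.down_proj) as its own ".smoother" tensor.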
        if name.replace(".weight", "") in llama_smoother.keys():
            smoother = llama_smoother[name.replace(".weight", "")]
            smoother = smoother.detach().cpu().numpy()
            starmap_args.append((0, saved_dir, infer_tp,
                                 f"{bin_name}.smoother".replace(".weight", ""),
                                 smoother, None, {
                                     "int8_outputs": int8_outputs,
                                     "multi_query_mode": args.multi_query_mode,
                                     "local_dim": None,
                                 }))

        param = transpose_weights(name, param)

        param = param.detach().cpu().numpy().astype(storage_type)

        if bin_name in global_bin_weights:
            ret[bin_name] = torch.from_numpy(np.asarray(param))
            # param.tofile(saved_dir / f"{bin_name}.bin")
        elif bin_name.split('.')[-2] == 'query_key_value':
            # Is there another way to get local_dim? For LLaMA-2, local_dim == hidden_size.
            local_dim = model.config.hidden_size if args.multi_query_mode else None
            if args.smoothquant is None:
                merge_qkv_scales(name, model, act_range, llama_qkv_para)
            qkv_key = name.replace(".weight", "").replace(".q_proj", ".qkv_proj")
            qkv = (0, saved_dir, infer_tp, bin_name,
                   llama_qkv_para.get(qkv_key).cpu().numpy().astype(storage_type),
                   act_range.get(qkv_key), {
                       "int8_outputs": int8_outputs,
                       "multi_query_mode": args.multi_query_mode,
                       "local_dim": local_dim,
                   })
            starmap_args.append(qkv)
        elif bin_name.split('.')[-2] == 'kv':
            continue
        else:
            starmap_args.append((0, saved_dir, infer_tp, bin_name, param,
                                 act_range.get(name.replace(".weight", "")), {
                                     "int8_outputs": int8_outputs,
                                     "multi_query_mode": args.multi_query_mode,
                                     "local_dim": None,
                                 }))

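    # Execute the conversion tasks, in parallel when more than one process is
    # requested; every task returns a dict of tensors that is merged into ret.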
    starmap_args = tqdm(starmap_args, desc="saving weights")

    if args.processes > 1:
        with multiprocessing.Pool(args.processes) as pool:
            results = pool.starmap(split_and_save_weight, starmap_args)
        for res in results:
            ret.update(res)
    else:
        # run tasks serially; simpler to debug
        for starmap_arg in starmap_args:
            result = split_and_save_weight(*starmap_arg)
            ret.update(result)
    return ret
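
For reference, a minimal invocation sketch is shown below. The attribute names on `args` mirror the accesses inside the function; the actual command-line parsing in hf_llama_convert.py is not part of this excerpt, so treat the namespace, paths and values as hypothetical.

from types import SimpleNamespace

# Hypothetical invocation sketch: attribute names are reconstructed from the
# accesses inside hf_gpt_converter; the real CLI parser is not shown here.
args = SimpleNamespace(
    in_file="path/to/hf_llama_checkpoint",
    out_dir="path/to/converted_weights",
    tensor_parallelism=2,
    storage_type="fp16",        # 'fp16' selects torch.float16; other values fall back to float32
    load_model_on_cpu=False,
    convert_model_on_cpu=False,
    smoothquant=0.5,            # SmoothQuant alpha, or None to skip smoothing
    calibrate_kv_cache=False,
    dataset_cache_dir="path/to/calibration_dataset",
    processes=4,
)
weights = hf_gpt_converter(args, ret={})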