Function: apply_delta_low_cpu_mem
Source file: fastchat/model/apply_delta.py


def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
    """Apply delta weights to a base model while keeping CPU memory low.

    Both checkpoints are first split into <= 4 GB shards inside temporary
    directories; the delta shards are then added to the base shards one
    shard at a time, so only a small slice of the model is resident in
    memory at any moment.

    Args:
        base_model_path: directory holding the base model checkpoint.
        target_model_path: output directory; deleted first if it exists.
        delta_path: directory holding the delta-weights checkpoint
            (its tokenizer and config are copied to the target).

    Raises:
        KeyError: if a base weight has no counterpart in any delta shard.
    """
    # The target model inherits tokenizer and config from the delta repo.
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
    delta_config = AutoConfig.from_pretrained(delta_path)

    # Start from a clean output directory.
    if os.path.exists(target_model_path):
        shutil.rmtree(target_model_path)
    os.makedirs(target_model_path)

    split_size = 4 * GB

    with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
        print(f"Split files for the base model to {tmp_base_path}")
        split_files(base_model_path, tmp_base_path, split_size)
        print(f"Split files for the delta weights to {tmp_delta_path}")
        split_files(delta_path, tmp_delta_path, split_size)

        base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
        base_files = glob.glob(base_pattern)
        delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
        delta_files = glob.glob(delta_pattern)
        # Fix: map_location="cpu" keeps tensors on the CPU even when the
        # shards were saved from a GPU — required for this low-CPU-mem,
        # CPU-only code path (and avoids a crash on machines without CUDA).
        delta_state_dict = torch.load(delta_files[0], map_location="cpu")

        print("Applying the delta")
        weight_map = {}
        total_size = 0

        # Fix: pass total= so tqdm can render actual progress over shards.
        for i, base_file in tqdm(enumerate(base_files), total=len(base_files)):
            state_dict = torch.load(base_file, map_location="cpu")
            file_name = f"pytorch_model-{i}.bin"
            for name, param in state_dict.items():
                # The matching delta tensor may live in a different shard:
                # scan the delta shards until one containing it is found.
                if name not in delta_state_dict:
                    for delta_file in delta_files:
                        delta_state_dict = torch.load(delta_file, map_location="cpu")
                        gc.collect()
                        if name in delta_state_dict:
                            break
                    else:
                        # Fix: fail with an explicit message instead of the
                        # bare KeyError the unguarded lookup below produced.
                        raise KeyError(
                            f"Weight {name!r} was not found in any delta shard"
                        )

                state_dict[name] += delta_state_dict[name]
                weight_map[name] = file_name
                total_size += param.numel() * param.element_size()
                gc.collect()
            torch.save(state_dict, os.path.join(target_model_path, file_name))
            # Fix: release this shard before loading the next to cap peak RSS.
            del state_dict
            gc.collect()

        # Write the index file so transformers can locate each weight's shard.
        with open(
            os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
        ) as f:
            json.dump(
                {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
            )

    print(f"Saving the target model to {target_model_path}")
    delta_tokenizer.save_pretrained(target_model_path)
    delta_config.save_pretrained(target_model_path)