def split_files()

in fastchat/model/apply_delta.py [0:0]


def _save_shard(shard, tmp_path, part):
    """Write ``shard`` to ``tmp_path`` as ``pytorch_model-{part}.bin``."""
    torch.save(shard, os.path.join(tmp_path, f"pytorch_model-{part}.bin"))


def split_files(model_path, tmp_path, split_size):
    """Re-shard a model's ``pytorch_model-*.bin`` checkpoints into smaller files.

    Every checkpoint matching ``pytorch_model-*.bin`` under ``model_path`` is
    loaded in turn and its tensors are regrouped, in their stored order, into
    shards whose summed tensor size stays within ``split_size``. Shards are
    written to ``tmp_path`` as ``pytorch_model-{part}.bin`` with ``part``
    counting up from 0 across all input files. A shard never mixes tensors
    from two different input files.

    Args:
        model_path: Directory containing the checkpoints. If the path does
            not exist it is passed to ``snapshot_download`` as a ``repo_id``
            and the returned path is used instead.
        tmp_path: Output directory; created if missing.
        split_size: Soft upper bound, in bytes, on the tensor payload of one
            shard (``numel() * element_size()`` summed over its tensors). A
            single tensor larger than this is written alone in its own shard.

    Raises:
        Exception: Any error raised while loading or saving is re-raised
            unchanged after ``tmp_path`` (the whole directory, including
            anything that was already in it) has been removed.
    """
    if not os.path.exists(model_path):
        model_path = snapshot_download(repo_id=model_path)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(tmp_path, exist_ok=True)

    file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
    # glob order is arbitrary; sort so shard numbering is reproducible.
    files = sorted(glob.glob(file_pattern))

    part = 0
    try:
        for file_path in tqdm(files):
            # NOTE(review): torch.load unpickles the file, and model_path may
            # have just been downloaded. Only use with trusted checkpoints, or
            # consider weights_only=True if the supported torch version
            # allows it.
            state_dict = torch.load(file_path)
            shard = {}
            shard_size = 0

            for name, param in state_dict.items():
                param_size = param.numel() * param.element_size()

                # Flush only a non-empty shard: an oversized tensor arriving
                # at an empty shard must not produce an empty file.
                if shard and shard_size + param_size > split_size:
                    _save_shard(shard, tmp_path, part)
                    part += 1
                    shard = {}
                    shard_size = 0

                shard[name] = param
                shard_size += param_size

            # Write whatever is left over from this input file (nothing if
            # the checkpoint held no tensors).
            if shard:
                _save_shard(shard, tmp_path, part)
                part += 1

            # Drop every reference to this file's tensors *before* the next
            # torch.load so two input checkpoints are never resident at once.
            state_dict = shard = param = None
            gc.collect()
    except Exception as e:
        print(f"An error occurred during split_files: {e}")
        # ignore_errors so a failed cleanup cannot replace the original error.
        shutil.rmtree(tmp_path, ignore_errors=True)
        raise