def try_collect_weight_map()

in optimum/fx/parallelization/utils.py [0:0]


def try_collect_weight_map(model_name_or_path: str, cache_dir: Optional[str], folder_path: str) -> Dict[str, str]:
    """Try collecting weight mapping information from the model folder.

    The lookup proceeds in three steps:

    1. If ``folder_path`` has an index file matching the detected checkpoint
       format (safetensors is preferred over ``.bin`` when both are present),
       its ``"weight_map"`` entry is loaded.
    2. If no ``*.safetensors`` file exists, the ``*.bin`` files are handed to
       ``convert_bin_to_safetensors`` together with the map collected so far.
    3. If the map is still empty, every ``*.safetensors`` file in
       ``folder_path`` is opened and its keys are recorded.

    Args:
        model_name_or_path: Forwarded unchanged to ``convert_bin_to_safetensors``.
        cache_dir: Forwarded unchanged to ``convert_bin_to_safetensors``.
        folder_path: Directory searched for weight files and index files.

    Returns:
        A dict mapping each weight name to the path of the file holding it.
        Empty when no weight information could be found.
    """
    from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

    weight_map: Dict[str, str] = {}

    # Derive the format flag directly from the glob result. Comparing pattern
    # strings is fragile: a mismatch between the searched pattern and the
    # compared literal silently leaves the flag False for every checkpoint.
    use_safetensors = bool(glob.glob(os.path.join(folder_path, "*.safetensors")))

    index_path = os.path.join(folder_path, SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME)
    if os.path.isfile(index_path):
        with open(index_path, encoding="utf-8") as f:
            index_dict = json.load(f)
        # Index entries are file names relative to the checkpoint folder.
        weight_map = {k: os.path.join(folder_path, v) for k, v in index_dict["weight_map"].items()}

    # convert bin files to safetensors, modify `weight_map` meanwhile
    # (the helper is defined elsewhere and receives `weight_map` to update it).
    if not use_safetensors:
        bin_files = glob.glob(os.path.join(folder_path, "*.bin"))
        convert_bin_to_safetensors(model_name_or_path, cache_dir, bin_files, weight_map)

    # last resort: try directly construct weight_map from weight files
    if not weight_map:
        from safetensors import safe_open

        # should have safetensors on disk in any case; glob again because the
        # conversion step above may have run since the first check.
        for weight_file in glob.glob(os.path.join(folder_path, "*.safetensors")):
            with safe_open(filename=weight_file, framework="pt") as f:
                for key in f.keys():
                    weight_map[key] = weight_file
    return weight_map