in optimum/fx/parallelization/utils.py [0:0]
def try_collect_weight_map(model_name_or_path: str, cache_dir: Optional[str], folder_path: str) -> Dict[str, str]:
    """Try collecting weight mapping information from the model folder.

    Resolution order:

    1. If ``folder_path`` contains ``*.safetensors`` files, the safetensors
       index file is used when present; otherwise the ``*.bin`` index file is
       used when present.
    2. If no safetensors files were found, ``convert_bin_to_safetensors`` is
       called with the ``*.bin`` files found in ``folder_path`` and the
       current ``weight_map``.
    3. If the mapping is still empty, every ``*.safetensors`` file in
       ``folder_path`` is opened and its keys are mapped to that file.

    Args:
        model_name_or_path: Forwarded to ``convert_bin_to_safetensors``.
        cache_dir: Forwarded to ``convert_bin_to_safetensors``.
        folder_path: Local directory that is searched for weight files and
            index files.

    Returns:
        A dict mapping weight names to the path of the file containing them.
        Empty if nothing could be collected.
    """
    from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

    weight_map = {}

    # Safetensors take precedence over bin files. The same glob pattern is used
    # for detection and for the flag, so the flag really reflects what is on
    # disk (previously "*safetensors" was compared against "*.safetensors",
    # which could never match and left the flag permanently False).
    safetensors_pattern = os.path.join(folder_path, "*.safetensors")
    use_safetensors = bool(glob.glob(safetensors_pattern))

    index_name = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
    index_path = os.path.join(folder_path, index_name)
    if os.path.isfile(index_path):
        with open(index_path, encoding="utf-8") as f:
            index_dict = json.load(f)
        # index entries are file names relative to the model folder
        weight_map = {k: os.path.join(folder_path, v) for k, v in index_dict["weight_map"].items()}

    # convert bin files to safetensors, modify `weight_map` meanwhile
    if not use_safetensors:
        bin_files = glob.glob(os.path.join(folder_path, "*.bin"))
        convert_bin_to_safetensors(model_name_or_path, cache_dir, bin_files, weight_map)

    # last resort: try directly construct weight_map from weight files
    if not weight_map:
        from safetensors import safe_open

        # should have safetensors on disk in any case; glob again because the
        # conversion step above runs after the initial detection
        for weight_file in glob.glob(safetensors_pattern):
            with safe_open(filename=weight_file, framework="pt") as f:
                for key in f.keys():
                    weight_map[key] = weight_file
    return weight_map