in optimum/habana/checkpoint_utils.py [0:0]
def get_repo_root(model_name_or_path, local_rank=-1, token=None):
    """
    Download the specified model checkpoint and return the directory it lives in.

    Args:
        model_name_or_path (str): a local checkpoint directory or a Hugging Face Hub repo id.
        local_rank (int, optional): distributed rank of the caller; -1 means single-process.
            Only ranks -1 and 0 perform the initial download; other ranks wait on a barrier
            and then read the checkpoint directly from the shared cache. Defaults to -1.
        token (str, optional): Hugging Face Hub authentication token. Defaults to None.

    Returns:
        str: path to the local directory containing the checkpoint.

    Raises:
        TypeError: if the repo contains neither ``*.safetensors`` nor ``*.bin`` weight files.
    """
    if Path(model_name_or_path).is_dir():
        # Local model: nothing to download
        return model_name_or_path

    # Checks if online or not
    if is_offline_mode() and local_rank == 0:
        print("Offline mode: forcing local_files_only=True")

    # List the repo files once and reuse (each list_repo_files call is a Hub API request;
    # the original code issued it twice)
    repo_files = list_repo_files(model_name_or_path, token=token)

    # Only download PyTorch weights; prefer safetensors when available
    # (some models like Falcon-180b are only available in safetensors format)
    if any(".safetensors" in filename for filename in repo_files):
        allow_patterns = ["*.safetensors"]
    elif any(".bin" in filename for filename in repo_files):
        allow_patterns = ["*.bin"]
    else:
        raise TypeError("Only PyTorch models are supported")

    # Download only on first process
    if local_rank in [-1, 0]:
        cache_dir = snapshot_download(
            model_name_or_path,
            local_files_only=is_offline_mode(),
            cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
            allow_patterns=allow_patterns,
            max_workers=16,
            token=token,
        )
        if local_rank == -1:
            # If there is only one process, then the method is finished
            return cache_dir

    # Make all processes wait so that other processes can get the checkpoint directly from cache
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Files are already cached by the rank-0 download above, so this resolves locally
    return snapshot_download(
        model_name_or_path,
        local_files_only=is_offline_mode(),
        cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
        allow_patterns=allow_patterns,
        token=token,
    )