in src/diffusers/pipelines/pipeline_utils.py
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
r"""
Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights.

Parameters:
pretrained_model_name (`str` or `os.PathLike`):
A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
hosted on the Hub.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory where the downloaded pipeline files are cached if the standard cache is not used.
custom_pipeline (`str`, *optional*):
Can be either:
- A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained
pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines
the custom pipeline.
- A string, the *file name* of a community pipeline hosted on GitHub under
[Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid names
are the file name without the `.py` extension (`clip_guided_stable_diffusion` instead of
`clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
current `main` branch of GitHub.
- A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
must contain a file called `pipeline.py` that defines the custom pipeline.
<Tip warning={true}>
🧪 This is an experimental feature and may change in the future.
</Tip>
For more information on how to load and create custom pipelines, take a look at [How to contribute a
community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
custom_revision (`str`, *optional*, defaults to `"main"`):
The specific version of the custom pipeline to use. Like `revision`, it can be a branch name, a tag
name, or a commit id when loading a custom pipeline from the Hub; when loading a custom pipeline
from GitHub it can also be a 🤗 Diffusers version. Otherwise it defaults to `"main"`.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
dduf_file (`str`, *optional*):
Load weights from the specified DDUF file.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. If not set, `use_onnx` defaults to the `_is_onnx` class attribute, which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as
it will execute code present on the Hub on your local machine.

Returns:
`str` or `os.PathLike`:
A path to the downloaded pipeline.

<Tip>
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
`huggingface-cli login`.
</Tip>
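
Examples:

A minimal usage sketch; the repository id is the illustrative one from above, and the `variant`
call assumes the repository actually ships fp16 safetensors weights:

```py
>>> from diffusers import DiffusionPipeline

>>> # download the full pipeline snapshot and get back the local folder it was cached to
>>> cached_folder = DiffusionPipeline.download("CompVis/ldm-text2im-large-256")

>>> # restrict the download to fp16 safetensors weights
>>> cached_folder = DiffusionPipeline.download(
...     "CompVis/ldm-text2im-large-256", variant="fp16", use_safetensors=True
... )
```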
"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None)
custom_revision = kwargs.pop("custom_revision", None)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)
dduf_file: Optional[str] = kwargs.pop("dduf_file", None)
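# a DDUF archive bundles the entire pipeline in a single file, so it short-circuits the per-file download logic below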
if dduf_file:
if custom_pipeline:
raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
if load_connected_pipeline:
raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
return _download_dduf_file(
pretrained_model_name=pretrained_model_name,
dduf_file=dduf_file,
pipeline_class_name=cls.__name__,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
)
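# fall back to pickle (.bin) weights only when the caller did not explicitly request safetensors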
allow_pickle = use_safetensors is None or use_safetensors is False
use_safetensors = use_safetensors if use_safetensors is not None else True
allow_patterns = None
ignore_patterns = None
model_info_call_error: Optional[Exception] = None
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
if not local_files_only:
config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
token=token,
)
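# parse the downloaded pipeline config to learn which components make up the pipeline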
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])
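# list every file in the repo from the Hub metadata fetched above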
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)
filenames = set(filenames) - set(ignore_filenames)
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, filenames)
custom_components, folder_names = _get_custom_components_and_folders(
pretrained_model_name, config_dict, filenames, variant
)
custom_class_name = None
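# a list-valued "_class_name" points at remote code: the first entry is the custom pipeline module, the second the class to load from it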
if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
custom_pipeline = config_dict["_class_name"][0]
custom_class_name = config_dict["_class_name"][1]
load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
load_components_from_hub = len(custom_components) > 0
if load_pipe_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
if load_components_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
# resolve the pipeline class, loading a custom pipeline definition from the Hub when needed
pipeline_class = _get_pipeline_class(
cls,
config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
repo_id=pretrained_model_name if load_pipe_from_hub else None,
hub_revision=revision,
class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
)
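# components the caller passed in as kwargs are already instantiated and should not be downloaded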
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
# retrieve the names of the folders containing model weights
model_folder_names = {
os.path.split(f)[0] for f in filter_model_files(filenames) if os.path.split(f)[0] in folder_names
}
# retrieve all patterns that should not be downloaded and error out when needed
ignore_patterns = _get_ignore_patterns(
passed_components,
model_folder_names,
filenames,
use_safetensors,
from_flax,
allow_pickle,
use_onnx,
pipeline_class._is_onnx,
variant,
)
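# pick the weight files matching the requested variant, falling back to the default files for components without one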
model_filenames, variant_filenames = variant_compatible_siblings(
filenames, variant=variant, ignore_patterns=ignore_patterns
)
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
# add custom component files
allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
# add custom pipeline file
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
cls.config_name,
CUSTOM_PIPELINE_FILE_NAME,
]
# Don't download any objects that are passed
allow_patterns = [
p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
]
if pipeline_class._load_connected_pipes:
allow_patterns.append("README.md")
# Don't download index files of forbidden patterns either
ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
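# replay the allow/ignore filtering locally to predict exactly which files a snapshot download would fetch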
expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]
snapshot_folder = Path(config_file).parent
pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)
if pipeline_is_cached and not force_download:
# if the pipeline is cached, we can directly return it
# else call snapshot_download
return snapshot_folder
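# tag the request so the Hub can attribute the download to this pipeline class (and custom pipeline, if any)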
user_agent = {"pipeline_class": cls.__name__}
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
user_agent["custom_pipeline"] = custom_pipeline
# download all allow_patterns - ignore_patterns
try:
cached_folder = snapshot_download(
pretrained_model_name,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
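# re-read the class name from the downloaded config, stripping a "Flax" prefix so Flax checkpoints resolve to the PyTorch pipeline class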
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
diffusers_module = importlib.import_module(__name__.split(".")[0])
pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
if pipeline_class is not None and pipeline_class._load_connected_pipes:
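# connected pipelines are declared in the model card metadata; download each of them with the same options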
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
for connected_pipe_repo_id in connected_pipes:
download_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"variant": variant,
"use_safetensors": use_safetensors,
}
DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
return cached_folder
except FileNotFoundError:
# Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
# This can happen in two cases:
# 1. If the user passed `local_files_only=True` => we raise the error directly
# 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
if model_info_call_error is None:
# 1. user passed `local_files_only=True`
raise
else:
# 2. we forced `local_files_only=True` when `model_info` failed
raise EnvironmentError(
f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
" while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
" above."
) from model_info_call_error