in src/diffusers/pipelines/pipeline_loading_utils.py [0:0]
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
    """
    Check whether the weight files listed in `filenames` are available in safetensors format.

    Checking for safetensors compatibility:
    - The model is safetensors compatible only if there is a safetensors file for each model component present in
      filenames.
    - A component is a `<folder>/<file>` entry (exactly one "/" in the path). If no such entries remain after
      filtering, the remaining filenames are checked as a single group instead.
    - When `variant` is given, variant weight files are preferred; non-variant weight files are only considered
      for a group that has no variant weight files at all.
    - A component none of whose files is recognized as a weight file counts as NOT compatible.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"

    Args:
        filenames: Iterable of repository-relative file paths using "/" as separator.
        passed_components: Optional collection of component (folder) names to skip entirely.
        folder_names: Optional collection of folder names; when truthy, only files whose parent directory is one of
            them are considered.
        variant: Optional weight variant name (for example "fp16"). It is matched literally.

    Returns:
        `True` if every considered component (or the fallback group of files) has at least one safetensors weight
        file, `False` otherwise.
    """
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]
    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # Part of each weight name before the first "." (e.g. `diffusion_pytorch_model`, `model`).
    # Escaped so the names are always matched literally.
    prefix_alternation = "|".join(re.escape(w.split(".")[0]) for w in weight_names)
    # Part of each weight name after the last "." (e.g. `bin`, `safetensors`).
    suffix_alternation = "|".join(re.escape(w.split(".")[-1]) for w in weight_names)
    # -00001-of-00002
    transformers_index_format = r"\d{5}-of-\d{5}"

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(
        rf"({prefix_alternation})(-{transformers_index_format})?\.({suffix_alternation})$"
    )

    # Only build the variant pattern when a variant was requested, and escape the variant so regex
    # metacharacters in it cannot change what is matched.
    variant_file_re = None
    if variant is not None:
        escaped_variant = re.escape(str(variant))
        variant_file_re = re.compile(
            rf"({prefix_alternation})"
            rf"\.({escaped_variant}|{escaped_variant}-{transformers_index_format})"
            rf"\.({suffix_alternation})$"
        )

    def _select_weight_files(candidates):
        """Return the variant weight files among `candidates`, or the non-variant ones if there are none."""
        selected = set()
        if variant_file_re is not None:
            selected = filter_with_regex(candidates, variant_file_re)
        # If no variant filenames exist check if non-variant files are available
        if not selected:
            selected = filter_with_regex(candidates, non_variant_file_re)
        return selected

    def _has_safetensors(names):
        """Return True if any name in `names` has the `.safetensors` extension."""
        return any(os.path.splitext(name)[1] == ".safetensors" for name in names)

    passed_components = passed_components or []
    if folder_names:
        filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

    # extract all components of the pipeline and their associated files
    components = {}
    for filename in filenames:
        parts = filename.split("/")
        # only `<component>/<file>` entries describe a pipeline component
        if len(parts) != 2:
            continue

        component, component_filename = parts
        if component in passed_components:
            continue

        components.setdefault(component, []).append(component_filename)

    # If there are no component folders check the main directory for safetensors files
    if not components:
        return _has_safetensors(_select_weight_files(filenames))

    # Every component must provide at least one safetensors weight file. A component without any
    # recognized weight file yields an empty selection and is therefore treated as incompatible.
    for component_filenames in components.values():
        if not _has_safetensors(_select_weight_files(component_filenames)):
            return False

    return True