def is_safetensors_compatible()

in src/diffusers/pipelines/pipeline_loading_utils.py [0:0]


def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
    """
    Checking for safetensors compatibility:
    - The model is safetensors compatible only if there is a safetensors file for each model component present in
      filenames.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"

    Args:
        filenames:
            Iterable of file paths using "/" as separator. Paths made of exactly two segments
            ("<component>/<file>") are grouped per component.
        passed_components:
            Component names to skip; files located in these component folders are ignored.
        folder_names:
            If given (and non-empty), only files whose directory part is contained in `folder_names` are considered.
        variant:
            Optional weight variant. When given, variant weight files are preferred; if none are found for a
            component (or for the main directory), the non-variant weight files are used instead. The variant is
            matched literally.

    Returns:
        `True` if every remaining component has at least one ".safetensors" file among its recognized weight files.
        A component without any recognized weight file makes the result `False`. If no component folder is found,
        the result is whether any recognized weight file in `filenames` contains ".safetensors".
    """
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]

    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # model_pytorch, diffusion_model_pytorch, ...
    prefix_alternatives = "|".join(w.split(".")[0] for w in weight_names)
    # .bin, .safetensors, ...
    suffix_alternatives = "|".join(w.split(".")[-1] for w in weight_names)
    # -00001-of-00002
    transformers_index_format = r"\d{5}-of-\d{5}"

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(
        rf"({prefix_alternatives})(-{transformers_index_format})?\.({suffix_alternatives})$"
    )

    # Only build the variant pattern when a variant was requested. The variant is escaped so that characters
    # with a special meaning in regular expressions (".", "+", ...) are matched literally.
    variant_file_re = None
    if variant is not None:
        escaped_variant = re.escape(str(variant))
        variant_file_re = re.compile(
            rf"({prefix_alternatives})\.({escaped_variant}|{escaped_variant}-{transformers_index_format})"
            rf"\.({suffix_alternatives})$"
        )

    def select_weight_files(candidates):
        # Prefer variant weight files; fall back to non-variant ones when no variant file matches.
        selected = set()
        if variant_file_re is not None:
            selected = filter_with_regex(candidates, variant_file_re)
        if not selected:
            selected = filter_with_regex(candidates, non_variant_file_re)
        return selected

    passed_components = passed_components or []
    if folder_names:
        filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

    # extract all components of the pipeline and their associated files
    components = {}
    for filename in filenames:
        parts = filename.split("/")
        # only "<component>/<file>" paths belong to a component folder
        if len(parts) != 2:
            continue

        component, component_filename = parts
        if component in passed_components:
            continue

        components.setdefault(component, []).append(component_filename)

    # If there are no component folders check the main directory for safetensors files
    if not components:
        return any(".safetensors" in filename for filename in select_weight_files(filenames))

    # Every component needs at least one safetensors file among its selected weight files.
    # Note: `any` over an empty selection is False, so a component without weight files is incompatible.
    for component_filenames in components.values():
        selected = select_weight_files(component_filenames)
        if not any(os.path.splitext(name)[1] == ".safetensors" for name in selected):
            return False

    return True