optimum/intel/neural_compressor/modeling_diffusion.py (37 lines of code) (raw):
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from diffusers import StableDiffusionPipeline
from neural_compressor.utils.pytorch import load
from ..utils.constant import DIFFUSION_WEIGHTS_NAME, WEIGHTS_NAME
from ..utils.import_utils import _torch_version, is_torch_version
from .configuration import INCConfig
class INCStableDiffusionPipeline(StableDiffusionPipeline):
@classmethod
def from_pretrained(cls, *args, **kwargs):
model = super(INCStableDiffusionPipeline, cls).from_pretrained(*args, low_cpu_mem_usage=False, **kwargs)
components = set(model.config.keys()).intersection({"vae", "text_encoder", "unet"})
for name in components:
component = getattr(model, name, None)
name_or_path = ""
if hasattr(component, "_internal_dict"):
name_or_path = component._internal_dict["_name_or_path"]
elif hasattr(component, "name_or_path"):
name_or_path = component.name_or_path
if os.path.isdir(name_or_path):
folder_contents = os.listdir(name_or_path)
file_name = DIFFUSION_WEIGHTS_NAME if DIFFUSION_WEIGHTS_NAME in folder_contents else WEIGHTS_NAME
state_dict_path = os.path.join(name_or_path, file_name)
if os.path.exists(state_dict_path) and INCConfig.CONFIG_NAME in folder_contents:
msg = None
inc_config = INCConfig.from_pretrained(name_or_path)
if not is_torch_version("==", inc_config.torch_version):
msg = f"Quantized model was obtained with torch version {inc_config.torch_version} but {_torch_version} was found."
state_dict = torch.load(state_dict_path, map_location="cpu")
if "best_configure" in state_dict and state_dict["best_configure"] is not None:
try:
load(state_dict_path, component)
except Exception as e:
if msg is not None:
e.args += (msg,)
raise
return model