in optimum/habana/diffusers/pipelines/pipeline_utils.py [0:0]
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
    """
    Load a pipeline through the parent `from_pretrained`, after preparing the Gaudi-specific environment.

    More information [here](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained).

    Steps performed, in order:
      1. Configure logging from the optional `log_level` kwarg (removed from `kwargs`,
         defaults to `logging.INFO`).
      2. Point `LOADABLE_CLASSES` and `ALL_IMPORTABLE_CLASSES` of
         `diffusers.pipelines.pipeline_utils` at their Gaudi counterparts.
      3. Add a `bf16_full_eval` kwarg that is `True` when `torch_dtype` is `torch.bfloat16`.
      4. If a `gaudi_config` kwarg is given, load it when it is a string, declare its
         autocast bf16/fp32 ops, and store the resulting object back into `kwargs`.
      5. Import `habana_frameworks.torch.core`.
      6. Delegate to the parent `from_pretrained` and, when `bf16_full_eval` is set,
         call `.to(torch.bfloat16)` on every public attribute of the result exposing a `dtype`.

    Args:
        pretrained_model_name_or_path: forwarded unchanged to the parent `from_pretrained`.
        **kwargs: forwarded to the parent `from_pretrained` after the adjustments above.

    Returns:
        The object returned by the parent `from_pretrained`.
    """
    # The node-dependent log level is already set by the parent class, but it has to be
    # set again here because this module relies on optimum.utils.logging rather than
    # diffusers.utils.logging.
    verbosity = kwargs.pop("log_level", logging.INFO)
    logging.set_verbosity(verbosity)
    logging.enable_default_handler()
    logging.enable_explicit_format()

    # Imported locally so that LOADABLE_CLASSES and ALL_IMPORTABLE_CLASSES can be overridden
    # with the Gaudi versions before the parent loader reads them.
    import diffusers.pipelines.pipeline_utils as diffusers_pipeline_utils

    diffusers_pipeline_utils.LOADABLE_CLASSES = GAUDI_LOADABLE_CLASSES
    diffusers_pipeline_utils.ALL_IMPORTABLE_CLASSES = GAUDI_ALL_IMPORTABLE_CLASSES

    # Extra kwarg letting __init__ know whether full bf16 precision is requested.
    bf16_full_eval = kwargs.get("torch_dtype") == torch.bfloat16
    kwargs["bf16_full_eval"] = bf16_full_eval

    # The custom ops lists must be declared before htcore is imported below.
    gaudi_config = kwargs.get("gaudi_config")
    if gaudi_config is not None:
        if isinstance(gaudi_config, str):
            gaudi_config = GaudiConfig.from_pretrained(gaudi_config)
        gaudi_config.declare_autocast_bf16_fp32_ops()
        kwargs["gaudi_config"] = gaudi_config

    # htcore is imported at this point to support model quantization.
    import habana_frameworks.torch.core as htcore  # noqa: F401

    # Returning the parent's result directly would normally be enough. The recast below is a
    # workaround for a Transformers 4.49.0 issue (sub-model torch_dtype option ignored),
    # which is already fixed on the 4.50.0dev working branch.
    pipeline = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
    if not bf16_full_eval:
        return pipeline

    # Snapshot the public attribute names first, then recast every attribute carrying a dtype.
    public_attrs = [attr for attr in vars(pipeline) if not attr.startswith("_")]
    for attr in public_attrs:
        part = getattr(pipeline, attr, None)
        if part is None or not hasattr(part, "dtype"):
            continue
        part.to(torch.bfloat16)

    return pipeline