def from_pretrained()

in optimum/habana/diffusers/pipelines/pipeline_utils.py


    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        """
        More information [here](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained).
        """

        # Set the correct log level depending on the node
        # Already done in super().__init__() but we have to do it again
        # because we use optimum.utils.logging here and not
        # diffusers.utils.logging
        log_level = kwargs.pop("log_level", logging.INFO)
        logging.set_verbosity(log_level)
        logging.enable_default_handler()
        logging.enable_explicit_format()

        # Import diffusers.pipelines.pipeline_utils to override the values of LOADABLE_CLASSES and ALL_IMPORTABLE_CLASSES
        import diffusers.pipelines.pipeline_utils

        diffusers.pipelines.pipeline_utils.LOADABLE_CLASSES = GAUDI_LOADABLE_CLASSES
        diffusers.pipelines.pipeline_utils.ALL_IMPORTABLE_CLASSES = GAUDI_ALL_IMPORTABLE_CLASSES

        # Define a new kwarg here to know in the __init__ whether to use full bf16 precision or not
        bf16_full_eval = kwargs.get("torch_dtype", None) == torch.bfloat16
        kwargs["bf16_full_eval"] = bf16_full_eval

        # Need to load custom ops lists before instantiating htcore
        if kwargs.get("gaudi_config", None) is not None:
            if isinstance(kwargs["gaudi_config"], str):
                gaudi_config = GaudiConfig.from_pretrained(kwargs["gaudi_config"])
            else:
                gaudi_config = kwargs["gaudi_config"]
            gaudi_config.declare_autocast_bf16_fp32_ops()
            kwargs["gaudi_config"] = gaudi_config

        # Import htcore here to support model quantization
        import habana_frameworks.torch.core as htcore  # noqa: F401

        # Normally we would just return super().from_pretrained() here. However, this is a
        # workaround for a Transformers 4.49.0 issue (the sub-model torch_dtype option is ignored).
        # Note this issue is already fixed in the 4.50.0.dev working branch.
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            **kwargs,
        )
        if bf16_full_eval:
            # Get the component names
            component_names = [name for name in model.__dict__ if not name.startswith("_")]
            # Iterate through the components and cast them back to bfloat16 (workaround for the issue above)
            for name in component_names:
                component = getattr(model, name, None)
                if component is not None and hasattr(component, "dtype"):
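                    # For torch modules, .to() casts parameters in place, so no reassignment is needed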
                    component.to(torch.bfloat16)

        return model
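
For context, a minimal usage sketch of the method above, assuming the GaudiStableDiffusionPipeline subclass exported by optimum.habana.diffusers and the keyword arguments shown in the library's documented examples (use_habana, use_hpu_graphs, gaudi_config); the model and Gaudi config identifiers are illustrative:

    import torch

    from optimum.habana.diffusers import GaudiStableDiffusionPipeline

    # Passing torch_dtype=torch.bfloat16 sets bf16_full_eval=True in from_pretrained(),
    # and gaudi_config can be a Hub identifier (loaded via GaudiConfig.from_pretrained)
    # or an already-instantiated GaudiConfig.
    pipeline = GaudiStableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.bfloat16,
        use_habana=True,
        use_hpu_graphs=True,
        gaudi_config="Habana/stable-diffusion",
    )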