in docker_images/diffusers/app/pipelines/text_to_image.py
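# Imports this excerpt relies on (defined at the top of the file; the
# app-local module paths below are assumptions inferred from the names used):
#
#   import json
#   import os
#   import torch
#   from app import idle, validation
#   from diffusers import AutoPipelineForText2Image, EulerAncestralDiscreteScheduler
#   from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers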
def __init__(self, model_id: str):
    self.current_lora_adapter = None
    self.model_id = None
    self.current_tokens_loaded = 0
    self.use_auth_token = os.getenv("HF_API_TOKEN")
    # This should allow the image to work with private models when no token is
    # provided, as long as the model is already in the local cache.
    self.offline_preferred = validation.str_to_bool(os.getenv("OFFLINE_PREFERRED"))
    model_data = self._hub_model_info(model_id)
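    # hf-internal-testing/* repos are lightweight test fixtures; the safety
    # checker is disabled for them so they load without that component.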
    kwargs = (
        {"safety_checker": None}
        if model_id.startswith("hf-internal-testing/")
        else {}
    )
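    # Resolve the torch dtype: an explicit TORCH_DTYPE env var (e.g. "bfloat16")
    # wins; otherwise default to float16 on CUDA and full precision on CPU.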
    env_dtype = os.getenv("TORCH_DTYPE")
    if env_dtype:
        kwargs["torch_dtype"] = getattr(torch, env_dtype)
    elif torch.cuda.is_available():
        kwargs["torch_dtype"] = torch.float16
    has_model_index = any(
        file.rfilename == "model_index.json" for file in model_data.siblings
    )
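    # Determine the model type: LoRA repos are detected from their hub
    # metadata; full pipelines report their class via `_class_name` in
    # model_index.json.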
    if self._is_lora(model_data):
        model_type = "LoraModel"
    elif has_model_index:
        config_file = self._hub_repo_file(model_id, "model_index.json")
        with open(config_file, "r") as f:
            config_dict = json.load(f)
        model_type = config_dict.get("_class_name", None)
    else:
        raise ValueError(
            f"Could not determine the model type for {model_id}: "
            "no LoRA weights and no model_index.json found"
        )
    if model_type == "LoraModel":
        # Use .get() and validate before assigning, so a missing key raises
        # the friendly error below rather than a bare KeyError.
        model_to_load = model_data.cardData.get("base_model")
        if not model_to_load:
            raise ValueError(
                "No `base_model` found. Please include a `base_model` in your README.md tags"
            )
        self.model_id = model_to_load
        self._load_sd_with_sdxl_fix(model_to_load, **kwargs)
        # The LoRA itself is lazily loaded on the fly, per request
        self.current_lora_adapter = None
    else:
        if model_id == "stabilityai/stable-diffusion-xl-base-1.0":
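            # SDXL base is routed through the same dedicated loading path
            # used for LoRA base models above.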
            self._load_sd_with_sdxl_fix(model_id, **kwargs)
        else:
            self.ldm = AutoPipelineForText2Image.from_pretrained(
                model_id, use_auth_token=self.use_auth_token, **kwargs
            )
        self.model_id = model_id
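    # A pipeline is Karras-compatible when its `scheduler` argument is
    # annotated with the KarrasDiffusionSchedulers enum; for those, swap in
    # EulerAncestralDiscreteScheduler as the default.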
    self.is_karras_compatible = (
        self.ldm.__class__.__init__.__annotations__.get("scheduler", None)
        == KarrasDiffusionSchedulers
    )
    if self.is_karras_compatible:
        self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(
            self.ldm.scheduler.config
        )
    self.default_scheduler = self.ldm.scheduler
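    # Move the model to GPU eagerly unless idle unloading is enabled, in which
    # case GPU placement is presumably deferred until a request needs it.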
    if not idle.UNLOAD_IDLE:
        self._model_to_gpu()
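
# Minimal usage sketch (assumptions: this __init__ lives on a Pipeline
# subclass, called TextToImagePipeline here purely for illustration):
#
#   os.environ["TORCH_DTYPE"] = "bfloat16"  # optional dtype override
#   pipe = TextToImagePipeline("stabilityai/stable-diffusion-xl-base-1.0")
#   # pipe.ldm now holds the loaded diffusers pipeline, running Euler
#   # Ancestral if the underlying pipeline class is Karras-compatible.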