# docker_images/diffusers/app/lora.py
def _load_lora_adapter(self, kwargs):
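    """Handle the ``lora_adapter`` request kwarg.

    Pops ``lora_adapter`` from ``kwargs`` and reconciles the pipeline with it:
    validates that the repo is a LoRA adapter built for this base model, swaps
    it in for any previously fused adapter (including textual embeddings), or
    unloads everything when no adapter is requested.
    """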
    adapter = kwargs.pop("lora_adapter", None)
    if adapter is not None:
        logger.info("LoRA adapter %s requested", adapter)
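        # Reload only when the requested adapter differs from the one currently fused.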
        if adapter != self.current_lora_adapter:
            model_data = self._hub_model_info(adapter)
            if not self._is_lora(model_data):
                msg = f"Requested adapter {adapter:s} is not a LoRA adapter"
                logger.error(msg)
                raise ValueError(msg)
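            # The model card's base_model field may be a single id or a list.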
            base_model = model_data.cardData["base_model"]
            is_list = isinstance(base_model, list)
            if (is_list and (self.model_id not in base_model)) or (
                not is_list and self.model_id != base_model
            ):
                msg = (
                    f"Requested adapter {adapter:s} is not a LoRA adapter "
                    f"for base model {self.model_id:s}"
                )
                logger.error(msg)
                raise ValueError(msg)
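            # Validation passed: swap out whatever adapter is currently fused.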
            logger.info(
                "LoRA adapter %s needs to be replaced with compatible adapter %s",
                self.current_lora_adapter,
                adapter,
            )
            if self.current_lora_adapter is not None:
                self.ldm.unfuse_lora()
                self.ldm.unload_lora_weights()
                self._unload_textual_embeddings()
                self.current_lora_adapter = None
                logger.info("LoRA weights unloaded, loading new weights")
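            # Pick the weight file advertised by the adapter repo, then load and fuse it.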
            weight_name = self._get_lora_weight_name(model_data=model_data)
            self.ldm.load_lora_weights(
                adapter, weight_name=weight_name, use_auth_token=self.use_auth_token
            )
            self.current_lora_adapter = adapter
            self._fuse_or_raise()
            logger.info("LoRA weights loaded for adapter %s", adapter)
            self._load_textual_embeddings(adapter, model_data)
        else:
            logger.info("LoRA adapter %s already loaded", adapter)
            # Even when the LoRA itself is already fused, a pivotal-tuning
            # adapter's textual embeddings may still need to be (re)loaded
            # alongside the model.
            model_data = self._hub_model_info(adapter)
            if (
                self._is_pivotal_tuning_lora(model_data)
                and self.current_tokens_loaded == 0
            ):
                self._load_textual_embeddings(adapter, model_data)
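    # No adapter requested: drop any fused LoRA and fall back to the bare base model.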
    elif self.current_lora_adapter is not None:
        logger.info(
            "No LoRA adapter requested, unloading weights and using base model %s",
            self.model_id,
        )
        self.ldm.unfuse_lora()
        self.ldm.unload_lora_weights()
        self._unload_textual_embeddings()
        self.current_lora_adapter = None
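

# A hedged, standalone sketch of the kind of Hub metadata helpers the method
# above relies on; the real _hub_model_info/_is_lora/_get_lora_weight_name are
# defined elsewhere in this file. It assumes only huggingface_hub.model_info()
# and the documented ModelInfo fields (tags, siblings, rfilename). The
# sketch_* names, the tag heuristic, and the repo id in the usage note are
# illustrative, not the repo's actual implementation.
from huggingface_hub import model_info


def sketch_is_lora(model_data):
    # LoRA adapter repos are commonly tagged "lora" on the Hub.
    return "lora" in (model_data.tags or [])


def sketch_lora_weight_name(model_data):
    # Prefer a safetensors weight file from the repo manifest, if any.
    names = [s.rfilename for s in (model_data.siblings or [])]
    safetensors = [n for n in names if n.endswith(".safetensors")]
    return safetensors[0] if safetensors else None


# Usage (requires network access; hypothetical repo id):
#     info = model_info("some-user/some-lora-adapter")
#     if sketch_is_lora(info):
#         print(sketch_lora_weight_name(info))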