docker_images/diffusers/app/lora.py (186 lines of code) (raw):

import logging import torch.nn as nn from app import offline from safetensors.torch import load_file logger = logging.getLogger(__name__) class LoRAPipelineMixin(offline.OfflineBestEffortMixin): @staticmethod def _get_lora_weight_name(model_data): weight_name_candidate = LoRAPipelineMixin._lora_weights_candidates(model_data) if weight_name_candidate: return weight_name_candidate file_to_load = next( ( file.rfilename for file in model_data.siblings if file.rfilename.endswith(".safetensors") ), None, ) if not file_to_load and not weight_name_candidate: raise ValueError("No *.safetensors file found for your LoRA") return file_to_load @staticmethod def _is_lora(model_data): return LoRAPipelineMixin._lora_weights_candidates(model_data) or ( model_data.cardData.get("tags") and "lora" in model_data.cardData.get("tags", []) ) @staticmethod def _lora_weights_candidates(model_data): candidate = None for file in model_data.siblings: rfilename = str(file.rfilename) if rfilename.endswith("pytorch_lora_weights.bin"): candidate = rfilename elif rfilename.endswith("pytorch_lora_weights.safetensors"): candidate = rfilename break return candidate @staticmethod def _is_safetensors_pivotal(model_data): embeddings_safetensors_exists = any( sibling.rfilename == "embeddings.safetensors" for sibling in model_data.siblings ) return embeddings_safetensors_exists @staticmethod def _is_pivotal_tuning_lora(model_data): return LoRAPipelineMixin._is_safetensors_pivotal(model_data) or any( sibling.rfilename == "embeddings.pti" for sibling in model_data.siblings ) def _fuse_or_raise(self): try: self.ldm.fuse_lora(safe_fusing=True) except Exception as e: logger.exception(e) logger.warning("Unable to fuse LoRA adapter") self.ldm.unload_lora_weights() self.current_lora_adapter = None raise @staticmethod def _reset_tokenizer_and_encoder(tokenizer, text_encoder, token_to_remove): token_id = tokenizer(token_to_remove)["input_ids"][1] del tokenizer._added_tokens_decoder[token_id] del tokenizer._added_tokens_encoder[token_to_remove] tokenizer._update_trie() tokenizer_size = len(tokenizer) text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim text_embedding_weights = text_encoder.get_input_embeddings().weight[ :tokenizer_size ] text_embeddings_filtered = nn.Embedding(tokenizer_size, text_embedding_dim) text_embeddings_filtered.weight.data = text_embedding_weights text_encoder.set_input_embeddings(text_embeddings_filtered) def _unload_textual_embeddings(self): if self.current_tokens_loaded > 0: for i in range(self.current_tokens_loaded): token_to_remove = f"<s{i}>" self._reset_tokenizer_and_encoder( self.ldm.tokenizer, self.ldm.text_encoder, token_to_remove ) self._reset_tokenizer_and_encoder( self.ldm.tokenizer_2, self.ldm.text_encoder_2, token_to_remove ) self.current_tokens_loaded = 0 def _load_textual_embeddings(self, adapter, model_data): if self._is_pivotal_tuning_lora(model_data): embedding_path = self._hub_repo_file( repo_id=adapter, filename="embeddings.safetensors" if self._is_safetensors_pivotal(model_data) else "embeddings.pti", repo_type="model", ) embeddings = load_file(embedding_path) state_dict_clip_l = ( embeddings.get("text_encoders_0") if "text_encoders_0" in embeddings else embeddings.get("clip_l", None) ) state_dict_clip_g = ( embeddings.get("text_encoders_1") if "text_encoders_1" in embeddings else embeddings.get("clip_g", None) ) tokens_to_add = 0 if state_dict_clip_l is None else len(state_dict_clip_l) tokens_to_add_2 = 0 if state_dict_clip_g is None else len(state_dict_clip_g) if tokens_to_add == tokens_to_add_2 and tokens_to_add > 0: if state_dict_clip_l is not None and len(state_dict_clip_l) > 0: token_list = [f"<s{i}>" for i in range(tokens_to_add)] self.ldm.load_textual_inversion( state_dict_clip_l, token=token_list, text_encoder=self.ldm.text_encoder, tokenizer=self.ldm.tokenizer, ) if state_dict_clip_g is not None and len(state_dict_clip_g) > 0: token_list = [f"<s{i}>" for i in range(tokens_to_add_2)] self.ldm.load_textual_inversion( state_dict_clip_g, token=token_list, text_encoder=self.ldm.text_encoder_2, tokenizer=self.ldm.tokenizer_2, ) logger.info("Text embeddings loaded for adapter %s", adapter) else: logger.info( "No text embeddings were loaded due to invalid embeddings or a mismatch of token sizes " "for adapter %s", adapter, ) self.current_tokens_loaded = tokens_to_add def _load_lora_adapter(self, kwargs): adapter = kwargs.pop("lora_adapter", None) if adapter is not None: logger.info("LoRA adapter %s requested", adapter) if adapter != self.current_lora_adapter: model_data = self._hub_model_info(adapter) if not self._is_lora(model_data): msg = f"Requested adapter {adapter:s} is not a LoRA adapter" logger.error(msg) raise ValueError(msg) base_model = model_data.cardData["base_model"] is_list = isinstance(base_model, list) if (is_list and (self.model_id not in base_model)) or ( not is_list and self.model_id != base_model ): msg = f"Requested adapter {adapter:s} is not a LoRA adapter for base model {self.model_id:s}" logger.error(msg) raise ValueError(msg) logger.info( "LoRA adapter %s needs to be replaced with compatible adapter %s", self.current_lora_adapter, adapter, ) if self.current_lora_adapter is not None: self.ldm.unfuse_lora() self.ldm.unload_lora_weights() self._unload_textual_embeddings() self.current_lora_adapter = None logger.info("LoRA weights unloaded, loading new weights") weight_name = self._get_lora_weight_name(model_data=model_data) self.ldm.load_lora_weights( adapter, weight_name=weight_name, use_auth_token=self.use_auth_token ) self.current_lora_adapter = adapter self._fuse_or_raise() logger.info("LoRA weights loaded for adapter %s", adapter) self._load_textual_embeddings(adapter, model_data) else: logger.info("LoRA adapter %s already loaded", adapter) # Needed while a LoRA is loaded w/ model model_data = self._hub_model_info(adapter) if ( self._is_pivotal_tuning_lora(model_data) and self.current_tokens_loaded == 0 ): self._load_textual_embeddings(adapter, model_data) elif self.current_lora_adapter is not None: logger.info( "No LoRA adapter requested, unloading weights and using base model %s", self.model_id, ) self.ldm.unfuse_lora() self.ldm.unload_lora_weights() self._unload_textual_embeddings() self.current_lora_adapter = None