threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py (103 lines of code) (raw):

import json
import os
from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt
from threestudio.utils.misc import cleanup
from threestudio.utils.typing import *


@threestudio.register("stable-diffusion-prompt-processor")
class StableDiffusionPromptProcessor(PromptProcessor):
    """Prompt processor backed by a Stable Diffusion CLIP text encoder.

    The tokenizer and text encoder are loaded from the ``tokenizer`` and
    ``text_encoder`` subfolders of ``cfg.pretrained_model_name_or_path``.
    Prompt embeddings are written to a cache directory by ``spawn_func``.
    """

    @dataclass
    class Config(PromptProcessor.Config):
        pass

    cfg: Config

    ### these functions are unused, kept for debugging ###
    def configure_text_encoder(self) -> None:
        """Load the tokenizer and a frozen CLIP text encoder onto ``self.device``."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.cfg.pretrained_model_name_or_path, subfolder="tokenizer"
        )
        # Presumably silences the HF `tokenizers` warning about parallelism
        # after a process fork.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.cfg.pretrained_model_name_or_path, subfolder="text_encoder"
        ).to(self.device)

        # The encoder is only used for inference; freeze all of its weights.
        for p in self.text_encoder.parameters():
            p.requires_grad_(False)

    def destroy_text_encoder(self) -> None:
        """Drop the tokenizer and text encoder created by ``configure_text_encoder``."""
        del self.tokenizer
        del self.text_encoder
        cleanup()

    def get_text_embeddings(
        self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]]
    ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]:
        """Encode prompts and negative prompts with the CLIP text encoder.

        Args:
            prompt: A single prompt or a list of prompts.
            negative_prompt: A single negative prompt or a list of them.

        Returns:
            ``(text_embeddings, uncond_text_embeddings)``. Each is the first
            element of the text encoder's output for the corresponding
            prompts.

        Requires ``configure_text_encoder`` to have been called first.
        """
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        # Tokenize text and get embeddings.
        # `truncation=True` caps every sequence at `model_max_length`. Without
        # it, an over-long prompt yields more ids than the padded length, and
        # those ids are then fed to the text encoder.
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        uncond_tokens = self.tokenizer(
            negative_prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

        with torch.no_grad():
            text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0]
            uncond_text_embeddings = self.text_encoder(
                uncond_tokens.input_ids.to(self.device)
            )[0]

        return text_embeddings, uncond_text_embeddings

    ###

    @staticmethod
    def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device):
        """Encode ``prompts`` and save each embedding to ``cache_dir``.

        Each embedding is written with ``torch.save`` to
        ``<cache_dir>/<hash_prompt(pretrained_model_name_or_path, prompt)>.pt``.

        Args:
            pretrained_model_name_or_path: Model id or path containing
                ``tokenizer`` and ``text_encoder`` subfolders.
            prompts: List of prompt strings to encode in one batch.
            cache_dir: Directory the ``.pt`` files are written to.
            device: Not read by this function. Placement is left to
                ``device_map="auto"``. Kept for interface compatibility.
        """
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            device_map="auto",
        )

        with torch.no_grad():
            # Truncate so over-long prompts cannot exceed the padded length.
            tokens = tokenizer(
                prompts,
                padding="max_length",
                truncation=True,
                max_length=tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0]

        # One file per prompt, keyed by the (model, prompt) hash.
        for prompt, embedding in zip(prompts, text_embeddings):
            torch.save(
                embedding,
                os.path.join(
                    cache_dir,
                    f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt",
                ),
            )

        del text_encoder


def add_tokens_to_model(
    learned_embeds_path,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    override_token: Optional[Union[str, dict]] = None,
) -> None:
    r"""Adds tokens to the tokenizer and text encoder of a model.

    Loads a mapping ``{token: embedding}`` from ``learned_embeds_path``.
    For each entry it registers the token with ``tokenizer``, grows the
    encoder's input embedding table to match, and writes the embedding into
    the new row.

    Args:
        learned_embeds_path: Path to a ``torch.save``-d dict of token ->
            embedding tensor.
        text_encoder: Text encoder whose input embeddings are extended in place.
        tokenizer: Tokenizer that the new tokens are added to in place.
        override_token: Optional renaming of the loaded tokens.
            - A ``str`` replaces every loaded token name with that string.
            - A ``dict`` maps each loaded token name to a new name.

    Raises:
        ValueError: If a token (after any override) is already in the tokenizer.
        KeyError: If ``override_token`` is a dict missing a loaded token name.
    """
    # NOTE(review): torch.load unpickles the file, so only pass trusted
    # embedding files here.
    learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # Loop over learned embeddings
    new_tokens = []
    for token, embedding in learned_embeds.items():
        # Match the encoder's embedding dtype, e.g. when the encoder is fp16.
        embedding = embedding.to(text_encoder.get_input_embeddings().weight.dtype)
        if override_token is not None:
            token = (
                override_token
                if isinstance(override_token, str)
                else override_token[token]
            )

        # Add the token to the tokenizer
        num_added_tokens = tokenizer.add_tokens(token)
        if num_added_tokens == 0:
            raise ValueError(
                (
                    f"The tokenizer already contains the token {token}. Please pass a "
                    "different `token` that is not already in the tokenizer."
                )
            )

        # Resize the token embeddings so the new token id has a row
        text_encoder.resize_token_embeddings(len(tokenizer))

        # Get the id for the token and assign the embeds
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embedding
        new_tokens.append(token)

    print(f'Added {len(new_tokens)} tokens to tokenizer and text embedding: {new_tokens}')