threestudio/models/prompt_processors/stable_diffusion_prompt_processor.py
import json
import os
from dataclasses import dataclass

import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt
from threestudio.utils.misc import cleanup
from threestudio.utils.typing import *

@threestudio.register("stable-diffusion-prompt-processor")
class StableDiffusionPromptProcessor(PromptProcessor):
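    """Prompt processor for Stable Diffusion models.

    Encodes prompts and negative prompts with the CLIP text encoder of the
    configured pretrained model; `spawn_func` writes one embedding file per
    prompt into `cache_dir`, keyed by `hash_prompt`.
    """
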
    @dataclass
    class Config(PromptProcessor.Config):
        pass

    cfg: Config

    ### these functions are unused, kept for debugging ###
    def configure_text_encoder(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.cfg.pretrained_model_name_or_path, subfolder="tokenizer"
        )
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.text_encoder = CLIPTextModel.from_pretrained(
            self.cfg.pretrained_model_name_or_path, subfolder="text_encoder"
        ).to(self.device)

        for p in self.text_encoder.parameters():
            p.requires_grad_(False)

    def destroy_text_encoder(self) -> None:
        del self.tokenizer
        del self.text_encoder
        cleanup()

    def get_text_embeddings(
        self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]]
    ) -> Tuple[Float[Tensor, "B 77 768"], Float[Tensor, "B 77 768"]]:
        if isinstance(prompt, str):
            prompt = [prompt]
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        # Tokenize text and get embeddings
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        uncond_tokens = self.tokenizer(
            negative_prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

        with torch.no_grad():
            text_embeddings = self.text_encoder(tokens.input_ids.to(self.device))[0]
            uncond_text_embeddings = self.text_encoder(
                uncond_tokens.input_ids.to(self.device)
            )[0]

        return text_embeddings, uncond_text_embeddings
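
    # Illustrative usage sketch (not from the original file; names are assumed):
    # with a processor whose tokenizer/text encoder were set up via
    # `configure_text_encoder()`, the call below would return two tensors shaped
    # [B, 77, 768] for a Stable Diffusion 1.x text encoder:
    #
    #   text_emb, uncond_emb = processor.get_text_embeddings(
    #       prompt="a DSLR photo of a hamburger",
    #       negative_prompt="low quality, blurry",
    #   )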
    ###

    @staticmethod
    def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device):
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            device_map="auto",
        )

        with torch.no_grad():
            tokens = tokenizer(
                prompts,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(tokens.input_ids.to(text_encoder.device))[0]

        for prompt, embedding in zip(prompts, text_embeddings):
            torch.save(
                embedding,
                os.path.join(
                    cache_dir,
                    f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt",
                ),
            )

        del text_encoder
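

# Illustrative sketch (assumption, not from this file): `spawn_func` keys each
# cached embedding by `hash_prompt(pretrained_model_name_or_path, prompt)`, so a
# consumer would presumably read one back with something like:
#
#   embedding = torch.load(
#       os.path.join(cache_dir, f"{hash_prompt(model_name, prompt)}.pt"),
#       map_location="cpu",
#   )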


from transformers.models.clip import CLIPTextModel, CLIPTokenizer


def add_tokens_to_model(
    learned_embeds_path,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    override_token: Optional[Union[str, dict]] = None,
) -> None:
    r"""Adds tokens to the tokenizer and text encoder of a model."""

    learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # Loop over learned embeddings
    new_tokens = []
    for token, embedding in learned_embeds.items():
        embedding = embedding.to(text_encoder.get_input_embeddings().weight.dtype)
        if override_token is not None:
            token = (
                override_token
                if isinstance(override_token, str)
                else override_token[token]
            )

        # Add the token to the tokenizer
        num_added_tokens = tokenizer.add_tokens(token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {token}. Please pass a "
                "different `token` that is not already in the tokenizer."
            )

        # Resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # Get the id for the token and assign the embeds
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embedding
        new_tokens.append(token)

    print(f"Added {len(new_tokens)} tokens to tokenizer and text embedding: {new_tokens}")
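

# Illustrative usage sketch (assumption, not part of the original file): injecting
# a textual-inversion checkpoint that stores {token: embedding} pairs into a
# Stable Diffusion tokenizer/text encoder pair. The checkpoint path, model name,
# and placeholder token below are hypothetical.
#
#   tokenizer = CLIPTokenizer.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
#   )
#   text_encoder = CLIPTextModel.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
#   )
#   add_tokens_to_model(
#       "learned_embeds.bin", text_encoder, tokenizer, override_token="<my-concept>"
#   )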