threestudio/models/prompt_processors/clip_prompt_processor.py (36 lines of code) (raw):
import json
import os
from dataclasses import dataclass
import clip
import torch
import torch
import torch.nn as nn
import threestudio
from threestudio.models.prompt_processors.base import PromptProcessor, hash_prompt
from threestudio.utils.misc import cleanup
from threestudio.utils.typing import *
@threestudio.register("clip-prompt-processor")
class ClipPromptProcessor(PromptProcessor):
    """Prompt processor that encodes prompts with OpenAI CLIP and caches the
    resulting text embeddings to disk (one ``.pt`` file per prompt, keyed by
    ``hash_prompt(model, prompt)``)."""

    @dataclass
    class Config(PromptProcessor.Config):
        # No CLIP-specific options; inherits everything from the base config.
        pass

    cfg: Config

    @staticmethod
    def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device):
        """Encode ``prompts`` with CLIP and save each embedding to ``cache_dir``.

        Args:
            pretrained_model_name_or_path: CLIP model name passed to ``clip.load``
                (e.g. ``"ViT-B/16"``).
            prompts: iterable of prompt strings to encode.
            cache_dir: directory where per-prompt embedding files are written.
            device: torch device the model and tokens are placed on.
        """
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        # Load the model directly onto the target device. Without an explicit
        # device, clip.load defaults to CUDA-if-available, which can mismatch
        # the device the tokens are moved to below and crash encode_text.
        clip_model, _ = clip.load(
            pretrained_model_name_or_path, device=device, jit=False
        )
        with torch.no_grad():
            tokens = clip.tokenize(prompts).to(device)
            text_embeddings = clip_model.encode_text(tokens)
            # L2-normalize so downstream similarity losses reduce to dot products.
            text_embeddings = text_embeddings / text_embeddings.norm(
                dim=-1, keepdim=True
            )

        for prompt, embedding in zip(prompts, text_embeddings):
            torch.save(
                embedding,
                os.path.join(
                    cache_dir,
                    f"{hash_prompt(pretrained_model_name_or_path, prompt)}.pt",
                ),
            )

        del clip_model
        cleanup()  # free GPU memory held by the CLIP model in this spawned process