in threestudio/models/prompt_processors/base.py [0:0]
def get_debiased_prompt(self, prompt: str) -> List[str]:
    """Return one debiased copy of ``prompt`` per view direction.

    A masked language model fills the ``[MASK]`` slot of
    ``"This image is depicting a [MASK] view of {prompt}"``. The
    probabilities it assigns to the view names (normalized over the views)
    are computed once for the full prompt and once for the prompt with a
    single word removed. For each view, the full-prompt probability is
    divided by the midpoint of the two probabilities; when that ratio falls
    below 0.95 the word is dropped from that view's prompt.

    Args:
        prompt: Text prompt; words are delimited by single spaces.

    Returns:
        A list of prompts, one per entry of ``self.directions`` and in the
        same order, each with the words removed for that view left out.

    Side effects:
        Sets ``TOKENIZERS_PARALLELISM=false`` in ``os.environ``, logs
        through ``threestudio.info`` and calls ``cleanup()`` after the
        tokenizer and model have been released.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    model_name = self.cfg.pretrained_model_name_or_path_prompt_debiasing
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = BertForMaskedLM.from_pretrained(model_name)

    views = [d.name for d in self.directions]
    n_views = len(views)
    view_ids = tokenizer(" ".join(views), return_tensors="pt").input_ids[0]
    # Index 0 is skipped (presumably the tokenizer's leading special token)
    # and one id is kept per view, instead of a hard-coded count of four.
    # NOTE(review): assumes every view name encodes to exactly one token --
    # confirm against the configured tokenizer's vocabulary.
    view_ids = view_ids[1 : 1 + n_views]

    def modulate(text):
        """Return the view probabilities at the [MASK] slot for ``text``."""
        prompt_vd = f"This image is depicting a [MASK] view of {text}"
        tokens = tokenizer(
            prompt_vd,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        mask_idx = torch.where(tokens.input_ids == tokenizer.mask_token_id)[1]
        # Inference only: do not record an autograd graph for each pass.
        with torch.no_grad():
            logits = model(**tokens).logits
        vocab_probs = F.softmax(logits[0, mask_idx], dim=-1)
        view_probs = vocab_probs[0, view_ids]
        # Renormalize so the probabilities of the views sum to one.
        return view_probs / view_probs.sum()

    words = prompt.split(" ")
    # Each view gets its own word list so removals stay independent.
    prompts = [list(words) for _ in range(n_views)]
    full_probe = modulate(prompt)

    prompt_debiasing_mask_ids = self.cfg.prompt_debiasing_mask_ids
    if prompt_debiasing_mask_ids is None:
        # No explicit selection configured: every word is a candidate.
        prompt_debiasing_mask_ids = list(range(len(words)))

    words_to_debias = [words[idx] for idx in prompt_debiasing_mask_ids]
    threestudio.info(f"Words that can potentially be removed: {words_to_debias}")

    for idx in prompt_debiasing_mask_ids:
        prompt_ = " ".join(words[:idx] + words[(idx + 1) :])
        part_probe = modulate(prompt_)

        # lerp(..., 0.5) is the midpoint of the two probes, so the ratio
        # compares the full-prompt probability against that midpoint.
        pmi = full_probe / torch.lerp(part_probe, full_probe, 0.5)
        for view_idx, score in enumerate(pmi.tolist()):
            if score < 0.95:
                # Blank the word; blanks are filtered out when joining below.
                prompts[view_idx][idx] = ""

    debiased_prompts = [" ".join(word for word in p if word) for p in prompts]
    for d, debiased_prompt in zip(views, debiased_prompts):
        threestudio.info(f"Debiased prompt of the {d} view is [{debiased_prompt}]")

    # Drop the references before cleanup() so the model can be reclaimed.
    del tokenizer, model
    cleanup()

    return debiased_prompts