in lmms_eval/models/llava.py [0:0]
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
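    """
    For each request, score the target continuation given its context (and any
    visuals): run one forward pass, take the loss over the continuation tokens,
    and check whether greedy decoding reproduces the continuation exactly.
    Returns one (loss, is_greedy) tuple per request.
    """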
    # TODO
    res = []
    pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
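    # Unpack each request's args: the prompt context, the target (string or callable), the visual loader, and doc/task/split identifiers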
    for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
        # Resolve the target continuation: either a literal string or a callable applied to the source document
        if isinstance(doc_to_target, str):
            continuation = doc_to_target
        else:
            continuation = doc_to_target(self.task_dict[task][split][doc_id])
        visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
        visuals = self.flatten(visuals)
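        # Preprocess visuals into pixel tensors for the vision tower and cast to fp16 on the model device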
        if visuals:
            image = process_images(visuals, self._image_processor, self._config)
            if isinstance(image, list):
                image = [_image.to(dtype=torch.float16, device=self.device) for _image in image]
            else:
                image = image.to(dtype=torch.float16, device=self.device)
        else:
            image = None
        prompts_input = contexts[0]
        if image is not None and len(image) != 0 and DEFAULT_IMAGE_TOKEN not in prompts_input:
            """
            Three scenarios:
            1. No image, and therefore no image token should be added.
            2. The image token is already specified in the context, so we don't need to add it.
            3. The image token is not specified in the context and there are image inputs, so we need to add it.
               In this case, we add the image token at the beginning of the context, followed by a new line.
            """
            image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visuals)
            image_tokens = " ".join(image_tokens)
            prompts_input = image_tokens + "\n" + contexts[0]
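        # Build the conversation prompt: the user turn holds the context, the assistant turn is left empty for now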
        conv = conv_templates[self.conv_template].copy()
        conv.append_message(conv.roles[0], prompts_input)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
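        # Tokenize the context-only prompt; its token length marks where the continuation begins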
        contxt_id = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
        # Fill in the assistant turn with the target continuation and re-render the full prompt
        conv.messages[1][1] = continuation
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
        labels = input_ids.clone()
        # Mask out the context tokens with -100 so the loss covers only the continuation
        labels[0, : contxt_id.shape[1]] = -100
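        # One forward pass; with labels supplied, the model returns the cross-entropy loss over the unmasked (continuation) tokens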
        with torch.inference_mode():
            outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True)
        loss = outputs["loss"]
        # loss = torch.exp(loss)
        logits = outputs["logits"]
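        # Greedy-match check: compare argmax predictions over the continuation span against the reference continuation tokens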
        greedy_tokens = logits.argmax(dim=-1)
        cont_toks = input_ids[:, contxt_id.shape[1] :]  # [1, seq]
        greedy_tokens = greedy_tokens[:, contxt_id.shape[1] : input_ids.shape[1]]  # [1, seq]
        max_equal = (greedy_tokens == cont_toks).all()
        res.append((float(loss.item()), bool(max_equal)))
        pbar.update(1)
    pbar.close()
    return res