in lmms_eval/models/fuyu.py [0:0]
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
    # NOTE: returns the mean cross-entropy loss over the continuation tokens
    # (not a summed log-likelihood) together with an exact-greedy-match flag.
    res = []
    pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
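    # Each request's args carry the prompt, the target (a string or an extractor),
    # the visual extractor, and the document's id/task/split metadata.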
    for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
        # Resolve the continuation: doc_to_target is either a literal string
        # or a callable that extracts the target from the document.
        if isinstance(doc_to_target, str):
            continuation = doc_to_target
        else:
            continuation = doc_to_target(self.task_dict[task][split][doc_id])
        visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
        visuals = self.flatten(visuals)
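        # Build two text variants: the context alone and context + continuation;
        # the tokenized length of the former marks where the continuation starts.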
        formatted_contexts = [f"{contexts}\n"]
        formatted_continuation = [f"{contexts}\n{continuation}"]
        model_inputs = self.processor(text=formatted_continuation, images=visuals, device=self.device)
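        # Move every processor output to the target device; image_patches comes
        # back as a list of tensors rather than a single tensor.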
        for k, v in model_inputs.items():
            model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v]
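        # Cast the image patches to the model's parameter dtype (e.g. bfloat16).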
        for index in range(len(model_inputs["image_patches"])):
            model_inputs["image_patches"][index] = model_inputs["image_patches"][index].to(dtype=next(self.model.parameters()).dtype)
        labels = model_inputs["input_ids"].clone()
        contxt_id = self.processor(text=formatted_contexts, return_tensors="pt")["input_ids"]
        # contxt_id has shape [1, ctx_len], so mask by column, not by row
        # (labels[: len(contxt_id)] would blank out the entire first sequence).
        labels[:, : contxt_id.shape[1]] = -100
        with torch.inference_mode():
            outputs = self.model(**model_inputs, labels=labels)
        loss = outputs["loss"]
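        # Greedy-match check: logits at position i predict token i + 1, so the
        # prediction slice is shifted left by one relative to the continuation.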
        logits = outputs["logits"]
        greedy_tokens = logits.argmax(dim=-1)
        cont_toks = model_inputs["input_ids"][:, contxt_id.shape[1] :]  # [1, seq]
        greedy_tokens = greedy_tokens[:, contxt_id.shape[1] - 1 : model_inputs["input_ids"].shape[1] - 1]  # [1, seq]
        max_equal = (greedy_tokens == cont_toks).all()
        res.append((float(loss.item()), bool(max_equal)))
        pbar.update(1)
    pbar.close()
    return res
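
The same mask-and-greedy-check pattern, shown on a text-only causal LM for reference (a minimal, self-contained sketch; the model name and prompt are illustrative, not part of lmms-eval):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

context, continuation = "The capital of France is", " Paris"
ctx_ids = tok(context, return_tensors="pt")["input_ids"]
full_ids = tok(context + continuation, return_tensors="pt")["input_ids"]

labels = full_ids.clone()
labels[:, : ctx_ids.shape[1]] = -100  # loss only over the continuation

with torch.inference_mode():
    out = model(input_ids=full_ids, labels=labels)

# logits[:, i] predict token i + 1, hence the one-position shift.
greedy = out.logits.argmax(dim=-1)[:, ctx_ids.shape[1] - 1 : -1]
cont_toks = full_ids[:, ctx_ids.shape[1] :]
print(float(out.loss), bool((greedy == cont_toks).all()))
```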