in lmms_eval/models/qwen_vl.py [0:0]
# NOTE: assumes the full file's module-level imports: os, uuid, torch,
# `from typing import List, Tuple`, `from tqdm import tqdm`,
# `from lmms_eval.api.instance import Instance`, plus the `make_context`
# helper from Qwen-VL's qwen_generation_utils.
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
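    """Return a list of (loss, is_greedy) tuples, one per (context, continuation) request."""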
    res = []
    pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
    for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
        # Resolve the continuation: doc_to_target is either a literal string or a
        # callable applied to the underlying doc.
        if isinstance(doc_to_target, str):
            continuation = doc_to_target
        else:
            continuation = doc_to_target(self.task_dict[task][split][doc_id])
        visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
        visuals = self.flatten(visuals)
        query = []
        visual_paths = []
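        # Qwen-VL's tokenizer consumes image *paths* via from_list_format, so each
        # visual is first written out to a uniquely named temp PNG.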
        for visual in visuals:
            name = uuid.uuid4().hex.upper()[0:6]
            visual.save(f"/tmp/{name}.png")
            visual_paths.append(f"/tmp/{name}.png")
            query.append({"image": f"/tmp/{name}.png"})
        # Copy the image-only query so the context (the prefix that will be masked
        # out of the loss) can be tokenized separately from context + continuation.
        context_query = list(query)
        context_query.append({"text": contexts})
        query.append({"text": contexts + continuation})
        context_query = self.tokenizer.from_list_format(context_query)
        query = self.tokenizer.from_list_format(query)
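        # make_context applies Qwen's chat template and returns (raw text, token ids);
        # tokenizing the context alone tells us where the continuation begins.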
        raw_context_text, context_tokens = make_context(
            self.tokenizer,
            context_query,
            history=None,
            system="You are a helpful assistant",
            max_window_size=self.model.generation_config.max_window_size,
            chat_format=self.model.generation_config.chat_format,
        )
        context_tokens = torch.tensor([context_tokens])
        raw_continuation_text, continuation_tokens = make_context(
            self.tokenizer,
            query,
            history=None,
            system="You are a helpful assistant",
            max_window_size=self.model.generation_config.max_window_size,
            chat_format=self.model.generation_config.chat_format,
        )
        continuation_tokens = torch.tensor([continuation_tokens]).to(self.model.device)
        attn_mask = torch.ones_like(continuation_tokens).to(self.model.device)
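        # Mask the context positions with -100 so the LM loss is computed only over
        # the continuation tokens (Hugging Face ignores labels of -100).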
        labels = continuation_tokens.clone().to(self.model.device)
        labels[:, : context_tokens.shape[1]] = -100
        with torch.inference_mode():
            outputs = self.model(input_ids=continuation_tokens, labels=labels, attention_mask=attn_mask)
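        # outputs.loss is the mean cross-entropy over the unmasked (continuation) labels.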
        loss = outputs.loss
        logits = outputs["logits"]
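        # Exact-match check: does greedy decoding reproduce the continuation verbatim?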
        greedy_tokens = logits.argmax(dim=-1)
        cont_toks = continuation_tokens[:, context_tokens.shape[1] :]
        # Logits at position i predict token i + 1, so shift the greedy window back
        # by one to align predictions with the continuation tokens.
        greedy_tokens = greedy_tokens[:, context_tokens.shape[1] - 1 : continuation_tokens.shape[1] - 1]  # [1, seq]
        max_equal = (greedy_tokens == cont_toks).all()
        res.append((float(loss.item()), bool(max_equal)))
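        # Best-effort cleanup of the temp images saved above; assumes they are no
        # longer needed once the forward pass has completed.
        for visual_path in visual_paths:
            try:
                os.remove(visual_path)
            except OSError:
                pass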
        pbar.update(1)
    pbar.close()
    return res