in src/pixparse/task/task_cruller_eval_docvqa.py
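# Assumes the module-level imports used below: torch, torch.nn.functional as F,
# and the token2json helper.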
def step(self, batch):
"""
Does one step of evaluation for DOCVQA.
Current limitation: sample-by-sample decoding.
"""
    metrics = {}
    image_outputs = self.model.image_encoder(batch['images'].to(self.device_env.device))
    for output, question, answers, question_id in zip(
            image_outputs, batch['questions'], batch['ground_truth_answers'], batch['question_ids']):
        self.all_ground_truths.append(answers)
        with torch.inference_mode():
            # Build the prompt: task start token, the question, and an opening
            # answer tag so the decoder generates the answer after it.
            current_string = self.task_start_token + "<s_question>" + question + "</s_question>" + "<s_answer>"
            input_ids = torch.tensor(
                self.tokenizer.trunk.encode(current_string, add_special_tokens=False)
            ).unsqueeze(0).to(self.device_env.device)  # add a batch dimension for the single sample
            max_steps = 512  # cap on the number of generated tokens
            for _ in range(max_steps):  # avoid shadowing the `step` method name
                inputs = self.prepare_inputs_for_inference(input_ids=input_ids, encoder_outputs=output)
                decoder_outputs = self.model.text_decoder(**inputs)
                probabilities = F.softmax(decoder_outputs['logits'], dim=-1)
                # Greedy choice: highest-probability token at the last position.
                next_token_id = torch.argmax(probabilities[0, -1]).item()
                next_token = self.tokenizer.trunk.decode([next_token_id])
                current_string += next_token
                if next_token == "</s>":
                    break
                # Re-encode the full string each step (simple but slow).
                input_ids = torch.tensor(
                    self.tokenizer.trunk.encode(current_string, add_special_tokens=False)
                ).unsqueeze(0).to(self.device_env.device)
            predicted_json = token2json(current_string)
        if 'answer' in predicted_json:
            self.all_predictions.append(predicted_json['answer'])
        else:
            self.all_predictions.append("")
    return metrics
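

# A minimal sketch (not part of the original task) of a token-id based greedy
# decode loop, as referenced in the docstring above. It appends the sampled id
# directly instead of re-encoding the growing string every step; softmax is
# skipped because argmax over logits picks the same token. The function name,
# the callable arguments and `eos_token_id` are illustrative assumptions, not
# the project's API.
import torch


def greedy_decode_ids(model, tokenizer, prepare_inputs_for_inference,
                      encoder_output, prompt_ids, eos_token_id, max_steps=512):
    """Greedily extend `prompt_ids` (shape [1, T]) until EOS or `max_steps`."""
    input_ids = prompt_ids
    with torch.inference_mode():
        for _ in range(max_steps):
            inputs = prepare_inputs_for_inference(
                input_ids=input_ids, encoder_outputs=encoder_output)
            logits = model.text_decoder(**inputs)['logits']
            # Greedy choice at the last position, taken directly from logits.
            next_token_id = logits[0, -1].argmax().item()
            input_ids = torch.cat(
                [input_ids, input_ids.new_tensor([[next_token_id]])], dim=1)
            if next_token_id == eos_token_id:
                break
    return tokenizer.decode(input_ids[0].tolist())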
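

# The prediction / ground-truth lists accumulated in step() are presumably
# reduced to ANLS (Average Normalized Levenshtein Similarity), the standard
# DocVQA metric, at the end of evaluation. A minimal, self-contained sketch of
# that reduction follows; the function name and signature are illustrative,
# not the project's implementation.
def anls_score(predictions, ground_truths, threshold=0.5):
    """ANLS over paired predictions and lists of acceptable answers."""
    def levenshtein(a, b):
        # Classic dynamic-programming edit distance.
        if len(a) < len(b):
            a, b = b, a
        prev = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            cur = [i]
            for j, cb in enumerate(b, 1):
                cur.append(min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (ca != cb)))
            prev = cur
        return prev[-1]

    total = 0.0
    for pred, answers in zip(predictions, ground_truths):
        best = 0.0
        for ans in answers:
            pred_n, ans_n = pred.strip().lower(), ans.strip().lower()
            denom = max(len(pred_n), len(ans_n)) or 1
            nl = levenshtein(pred_n, ans_n) / denom
            # Scores below the similarity threshold count as zero.
            best = max(best, 1.0 - nl if nl < threshold else 0.0)
        total += best
    return total / max(len(predictions), 1)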