in src/pixparse/task/task_cruller_eval_cord.py [0:0]
def step(self, batch):
    """
    Runs one evaluation step on a CORD batch.
    Current limitation: samples are decoded one at a time (greedy decoding).
    """
    metrics = {}
    batch_accs = []
    for image, label in zip(batch["image"], batch["label"]):
        decoded_gt = self.tokenizer.trunk.decode(label)
        ground_truth = token2json(decoded_gt)
        with torch.inference_mode():
            # Add a batch dimension and run the image through the encoder once;
            # the encoder output is reused at every decoding step.
            tensor_image = image.unsqueeze(0).to(self.device_env.device)
            output = self.model.image_encoder(tensor_image)

            current_string = "<s_cord>"
            input_ids = (
                torch.tensor(self.tokenizer.trunk.encode("<s_cord>", add_special_tokens=False))
                .unsqueeze(0)  # add a batch dimension
                .to(self.device_env.device)
            )

            max_steps = 512  # hard cap on the number of generated tokens
            for _ in range(max_steps):
                inputs = self.prepare_inputs_for_inference(input_ids=input_ids, encoder_outputs=output)
                decoder_outputs = self.model.text_decoder(**inputs)
                # Greedy decoding: argmax over the logits at the last position
                # (argmax over logits equals argmax over their softmax).
                next_token_id = decoder_outputs["logits"][0, -1].argmax().item()
                next_token = self.tokenizer.trunk.decode([next_token_id])
                current_string += next_token
                if next_token == "</s>":
                    break
                # Append the new token id directly instead of re-encoding the whole
                # string; decode/encode round trips are not guaranteed to be stable.
                next_id = torch.tensor([[next_token_id]], device=self.device_env.device)
                input_ids = torch.cat([input_ids, next_id], dim=1)

        predicted_json = token2json(current_string)
        self.all_predictions.append(predicted_json)
        self.all_ground_truths.append(ground_truth)

        acc = self.evaluator.cal_acc(predicted_json, ground_truth)
        self.acc_list.append(acc)
        batch_accs.append(acc)

    # Report the mean accuracy over the batch, not just the last sample's accuracy.
    metrics["batch_accuracy"] = sum(batch_accs) / len(batch_accs)
    return metrics
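
The sample-by-sample loop above is the stated limitation. Below is a minimal sketch of a batched greedy-decoding variant. It assumes (these are assumptions, not the repo's confirmed API) that `model.image_encoder` and `model.text_decoder` accept batched tensors, that `text_decoder` can be called directly with `input_ids` and `encoder_outputs`, that every sample starts from the same "<s_cord>" prompt (so no padding or attention mask is needed), and that the tokenizer behaves like a Hugging Face tokenizer:

import torch

def batched_greedy_decode(model, tokenizer, images, device, max_steps=512):
    """Sketch: decode a whole batch at once instead of sample by sample."""
    batch_size = images.shape[0]
    with torch.inference_mode():
        encoder_outputs = model.image_encoder(images.to(device))
        prompt = torch.tensor(
            tokenizer.encode("<s_cord>", add_special_tokens=False), device=device
        )
        # All samples share the same prompt, so sequence lengths stay aligned.
        input_ids = prompt.unsqueeze(0).expand(batch_size, -1).clone()
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        eos_id = tokenizer.encode("</s>", add_special_tokens=False)[0]

        for _ in range(max_steps):
            logits = model.text_decoder(
                input_ids=input_ids, encoder_outputs=encoder_outputs
            )["logits"]
            next_ids = logits[:, -1].argmax(dim=-1)  # (batch,)
            # Once a sequence has emitted EOS, keep feeding it EOS so all
            # sequences stay the same length without an attention mask.
            next_ids = torch.where(finished, torch.full_like(next_ids, eos_id), next_ids)
            input_ids = torch.cat([input_ids, next_ids.unsqueeze(1)], dim=1)
            finished |= next_ids == eos_id
            if finished.all():
                break

    # Trailing EOS repeats can be stripped before passing to token2json.
    return [tokenizer.decode(ids.tolist()) for ids in input_ids]

Continuing finished sequences with EOS keeps the batch rectangular; with variable-length prompts, left-padding and an explicit attention mask would be needed instead.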