in src/pixparse/task/task_cruller_eval_rvlcdip.py [0:0]
def step(self, sample):
    """
    Does one step of evaluation for classification on RVLCDIP.

    Greedily decodes up to a few tokens per image, starting from the
    ``<s_rvlcdip>`` task prompt. A sample counts as correct only if the text it
    generates before its FIRST ``</s>`` is exactly ``<{label}/>``. Samples that
    never emit ``</s>`` within the step budget count as incorrect.

    Args:
        sample: batch mapping with ``"image"`` (an iterable of image tensors
            that are stacked into one batch, so presumably all the same
            shape; confirm against the dataloader) and ``"label"`` (integer
            class ids, mapped to class names through ``self.int2str``).

    Returns:
        ``{"classification": {"correct_samples": int, "n_valid_samples": int}}``
    """
    metrics = {}
    metrics["classification"] = dict()
    correct_samples = 0
    ground_truths = [self.int2str[int(gt)] for gt in sample["label"]]
    # A sample is finished as soon as it emits its first "</s>". Only that
    # first end-of-sequence decides correctness, so later tokens cannot turn
    # a wrong prediction into a counted one.
    finished = [False] * len(ground_truths)
    with torch.inference_mode():
        device = self.device_env.device
        tensor_images = torch.stack([im for im in sample["image"]]).to(device)
        output = self.model.image_encoder(tensor_images)
        n_images = tensor_images.shape[0]
        current_strings = ["<s_rvlcdip>" for _ in range(n_images)]
        # Index 1 of the encoding is the task-prompt id (index 0 is presumably
        # the tokenizer's leading special token). Seed every sequence with it.
        task_token_id = self.tokenizer.trunk.encode("<s_rvlcdip>")[1]
        input_ids = torch.full(
            (n_images, 1), task_token_id, dtype=torch.long, device=device
        )
        max_steps = 5  # Few steps for RVL CDIP, we have to predict at most 3 tokens
        for _ in range(max_steps):
            inputs = self.prepare_inputs_for_inference(
                input_ids=input_ids,
                encoder_outputs=output,
            )
            decoder_outputs = self.model.text_decoder(**inputs)
            # Greedy decoding: argmax of the last position's logits. Softmax
            # is monotonic, so normalising first would not change the choice.
            next_token_ids = decoder_outputs["logits"][:, -1, :].argmax(dim=-1)
            for idx, next_token_id in enumerate(next_token_ids.tolist()):
                if finished[idx]:
                    continue
                next_token = self.tokenizer.trunk.decode([next_token_id])
                current_strings[idx] += next_token
                if next_token == "</s>":
                    finished[idx] = True
                    generated_label = (
                        current_strings[idx]
                        .replace("<s_rvlcdip>", "")
                        .replace("</s>", "")
                        .replace("<s>", "")
                        .strip()
                    )
                    if generated_label == f"<{ground_truths[idx]}/>":
                        correct_samples += 1
            if all(finished):
                break
            # Extend the sequences with the predicted ids directly instead of
            # re-tokenizing the decoded strings. This keeps the batch
            # rectangular (ragged re-encodings would make torch.tensor fail)
            # and avoids decode/encode round-trip drift.
            input_ids = torch.cat([input_ids, next_token_ids.unsqueeze(1)], dim=1)
    # TODO Add other metrics relevant for eval step
    metrics["classification"]["correct_samples"] = correct_samples
    metrics["classification"]["n_valid_samples"] = len(sample["label"])
    return metrics