in src/pixparse/task/task_donut_eval_ocr.py [0:0]
def step(self, sample):
    """
    Run one evaluation step of OCR reconstruction for a Donut model.

    Returns a dict with CER/WER metrics under 'ocr_reconstruction', or
    (None, None) when no usable text remains after cleaning.
    """
    metrics = {}
    image_input, text_input, text_target = sample
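    # Take the first element of each text item and stack the token id
    # tensors into (batch, seq_len) tensors on the evaluation device.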
    text_input = [item[0] for item in text_input]
    text_input = torch.stack(text_input, dim=0).to(
        self.device_env.device, non_blocking=True
    )
    text_target = [item[0] for item in text_target]
    text_target = torch.stack(text_target, dim=0).to(
        self.device_env.device, non_blocking=True
    )
    # Compute OCR metrics for Donut
    decoder_input_ids = self.processor.tokenizer(
        self.task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = self.processor(
        [im.convert('RGB') for im in image_input], return_tensors="pt"
    ).pixel_values
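    # Greedy decoding (num_beams=1), one image at a time, seeded with the
    # task prompt; the unk token is disallowed via bad_words_ids.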
    with torch.inference_mode():
        outputs = [
            self.model.generate(
                pixel_value.unsqueeze(0).to(self.device_env.device),
                decoder_input_ids=decoder_input_ids.to(self.device_env.device),
                max_length=self.max_position_embeddings,
                early_stopping=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )
            for pixel_value in pixel_values
        ]
    generated_text = [
        self.clean_text(self.processor.decode(greedy_outputs.sequences[0]))
        for greedy_outputs in outputs
    ]
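    # Replace -100 (the loss ignore index) with the pad token id so the
    # masked positions in the reference token ids can be decoded.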
    text_input[text_input == -100] = self.processor.tokenizer.pad_token_id
    raw_decoded_texts = self.processor.tokenizer.batch_decode(text_input)
    decoded_texts = [self.clean_text(t) for t in raw_decoded_texts]
    # FIXME sometimes we are decoding no text at all after cleaning
    filtered = [
        (ref, pred)
        for ref, pred in zip(decoded_texts, generated_text)
        if ref and pred
    ]
    if not filtered:
        return None, None
    decoded_texts, ocr_predictions = zip(*filtered)
    decoded_texts = list(decoded_texts)
    ocr_predictions = list(ocr_predictions)
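    # Truncate each prediction to the length of its reference so that text
    # generated beyond the reference length is not counted against CER/WER.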
    ocr_predictions = [
        text[: len(reftext)]
        for text, reftext in zip(ocr_predictions, decoded_texts)
    ]
metrics["ocr_reconstruction"] = get_cer_wer_metrics(
self.cer_transforms,
self.wer_transforms,
dict(),
ocr_predictions,
decoded_texts,
)
# TODO Add other metrics relevant for eval step
#
# metrics['metric_category'] = ...
return metrics