in src/pixparse/task/task_cruller_eval_ocr.py [0:0]
def step(self, sample):
    """
    Run one OCR evaluation step on a single batch.

    Args:
        sample: A 3-item sequence ``(image_input, text_input, text_target)``.
            ``image_input`` must support ``.to(device, non_blocking=True)``.
            ``text_target`` is an iterable whose items hold a tensor at
            index 0; those tensors are stacked along dim 0 into one batch.
            ``text_input`` is unpacked to keep the batch layout but is not
            consumed by this step.

    Returns:
        dict: ``{"ocr_reconstruction": ocr_metrics}``, where ``ocr_metrics``
        is the first of the two values returned by ``get_ocr_metrics``.
    """
    image_input, _unused_text_input, text_target = sample
    device = self.device_env.device

    # Only the targets feed the metric computation below (they are passed as
    # `text_input=`), so the raw text inputs are neither stacked nor copied
    # to the device: that would be a wasted host-to-device transfer.
    text_target = torch.stack(
        [item[0] for item in text_target], dim=0
    ).to(device, non_blocking=True)
    image_input = image_input.to(device, non_blocking=True)

    # Add OCR-related metrics and generation
    ocr_metrics, _ = get_ocr_metrics(
        model=self.model,
        tokenizer=self.tokenizer,
        image_input=image_input,
        text_input=text_target,
        device_env=self.device_env,
        max_recursion_length=self.max_recursion_length,
        prompt_token=self.task_start_token,
    )

    metrics = {"ocr_reconstruction": ocr_metrics}
    # TODO Add other metrics relevant for eval step
    #
    # metrics['metric_category'] = ...
    return metrics