def step()

in src/pixparse/task/task_cruller_eval_ocr.py [0:0]


    def step(self, sample):
        """
        Run one OCR evaluation step on a single batch.

        Args:
            sample: A 3-tuple ``(image_input, text_input, text_target)``.
                ``image_input`` must support ``.to(device)``.
                ``text_target`` is an iterable whose items each hold a
                tensor at index 0; those tensors are stacked along a new
                leading dimension. ``text_input`` is part of the sample
                layout but is not consumed by this step.

        Returns:
            dict: ``{"ocr_reconstruction": ocr_metrics}``, where
            ``ocr_metrics`` is the first element of the pair returned by
            ``get_ocr_metrics``.
        """
        device = self.device_env.device

        # The second element (text_input) is never read by the metric
        # computation below, so it is deliberately not stacked or copied to
        # the device; doing so would only spend transfer time and memory.
        image_input, _unused_text_input, text_target = sample

        text_target = torch.stack([item[0] for item in text_target], dim=0).to(
            device, non_blocking=True
        )
        image_input = image_input.to(device, non_blocking=True)

        # OCR generation and metrics. The target tokens are what gets passed
        # as the helper's ``text_input`` argument.
        # NOTE(review): presumably the helper treats them as the reference
        # text to score generations against -- confirm against get_ocr_metrics.
        ocr_metrics, _ = get_ocr_metrics(
            model=self.model,
            tokenizer=self.tokenizer,
            image_input=image_input,
            text_input=text_target,
            device_env=self.device_env,
            max_recursion_length=self.max_recursion_length,
            prompt_token=self.task_start_token,
        )

        metrics = {"ocr_reconstruction": ocr_metrics}

        # TODO Add other metrics relevant for eval step
        #
        # metrics['metric_category'] = ...
        return metrics