def step()

in src/pixparse/task/task_donut_eval_ocr.py


    def step(self, sample):
        """
        Does one step of evaluation for OCR.
        """
        metrics = {}
        image_input, text_input, text_target = sample
        # Take the first element of each text entry and stack into a batch tensor on the eval device.
        text_input = [item[0] for item in text_input]
        text_input = torch.stack(text_input, dim=0).to(
            self.device_env.device, non_blocking=True
        )
        text_target = [item[0] for item in text_target]
        text_target = torch.stack(text_target, dim=0).to(
            self.device_env.device, non_blocking=True
        )

        # Compute OCR metrics for Donut

        decoder_input_ids = self.processor.tokenizer(
            self.task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids

        pixel_values = self.processor(
            [im.convert('RGB') for im in image_input], return_tensors="pt"
        ).pixel_values

        # Greedy decoding (num_beams=1), one image at a time, conditioned on the task prompt ids.
        with torch.inference_mode():
            outputs = [self.model.generate(
                pixel_value.unsqueeze(0).to(self.device_env.device),
                decoder_input_ids=decoder_input_ids.to(self.device_env.device),
                max_length=self.max_position_embeddings,
                early_stopping=True,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
                bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            ) for pixel_value in pixel_values]
        generated_text = [
            self.clean_text(self.processor.decode(greedy_outputs.sequences[0]))
            for greedy_outputs in outputs
        ]
        # Replace the -100 ignore index in the targets with the pad token id before decoding.
        text_input[text_input == -100] = self.processor.tokenizer.pad_token_id
        raw_decoded_texts = self.processor.tokenizer.batch_decode(text_input)
        decoded_texts = [self.clean_text(t) for t in raw_decoded_texts]
        # FIXME sometimes we are decoding no text at all after cleaning
        filtered = [
            (ref, pred)
            for ref, pred in zip(decoded_texts, generated_text)
            if ref and pred
        ]

        if not filtered:
            # No usable reference/prediction pairs in this batch; skip it.
            return None

        decoded_texts, ocr_predictions = zip(*filtered)

        decoded_texts = list(decoded_texts)
        ocr_predictions = list(ocr_predictions)
        
        # Truncate each prediction to its reference length before scoring.
        ocr_predictions = [
            text[: len(reftext)]
            for text, reftext in zip(ocr_predictions, decoded_texts)
        ]

        metrics["ocr_reconstruction"] = get_cer_wer_metrics(
            self.cer_transforms,
            self.wer_transforms,
            dict(),
            ocr_predictions,
            decoded_texts,
        )
        # TODO Add other metrics relevant for eval step
        #
        # metrics['metric_category'] = ...
        return metrics
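
For context, here is a minimal sketch of how this step might be driven from an evaluation loop. It assumes an eval loader yielding (image_input, text_input, text_target) samples and an already-constructed task instance; the names run_eval, task, and eval_loader are placeholders rather than pixparse API, and the 'cer'/'wer' keys inside metrics["ocr_reconstruction"] are an assumption about what get_cer_wer_metrics returns, not confirmed by the code above.

    import numpy as np

    def run_eval(task, eval_loader):
        # `task` is assumed to expose the step() method shown above;
        # `eval_loader` yields (image_input, text_input, text_target) samples.
        cer_scores, wer_scores = [], []
        for sample in eval_loader:
            metrics = task.step(sample)
            if metrics is None:
                # Batch had no usable reference/prediction pairs after cleaning.
                continue
            ocr = metrics["ocr_reconstruction"]
            # Assumption: get_cer_wer_metrics fills 'cer' and 'wer' entries.
            cer_scores.append(float(ocr["cer"]))
            wer_scores.append(float(ocr["wer"]))
        return {
            "cer": float(np.mean(cer_scores)) if cer_scores else None,
            "wer": float(np.mean(wer_scores)) if wer_scores else None,
        }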