def get_train_ocr_metrics()

in src/pixparse/task/task_cruller_pretrain.py [0:0]


    def get_train_ocr_metrics(self, sample):
        """
        Collect OCR monitoring metrics and generated samples for one batch.

        In cruller_pretrain this is used to monitor training: the batch is run
        through ``get_ocr_metrics`` and the results are returned so that they
        can be logged in the log step (e.g. a few images and their generated
        OCR shown in a tensorboard gallery).

        Args:
            sample: a 3-item batch ``(image_input, text_input, text_target)``.
                Only ``image_input`` and ``text_target`` are used here;
                ``text_target`` is indexed as ``[:, 1:]``, so it must have at
                least two dimensions.

        Returns:
            A ``(metrics, eval_data)`` tuple of dicts. When ``get_ocr_metrics``
            returns two truthy values, ``metrics['ocr_reconstruction']`` holds
            the metrics and ``eval_data['ocr_reconstruction_data']`` holds the
            reconstructed sample. Otherwise both dicts are empty and an info
            message is logged.
        """
        metrics = {}
        eval_data = {}

        # The batch's text_input is intentionally ignored: it was previously
        # sliced and copied to the device without ever being read.
        image_input, _, text_target = sample

        device = self.device_env.device
        image_input = image_input.to(device, non_blocking=True)
        # Drops the first position along dim 1 -- presumably the
        # start-of-sequence token; TODO confirm against the data pipeline.
        text_target = text_target[:, 1:].to(device, non_blocking=True)

        # Add OCR-related metrics and generation.
        # NOTE(review): the shifted target is what gets passed as `text_input`;
        # presumably get_ocr_metrics uses it as the reference text -- confirm.
        ocr_metrics, ocr_reconstructed_sample = get_ocr_metrics(
            model=self.model,
            tokenizer=self.tokenizer,
            image_input=image_input,
            text_input=text_target,
            device_env=self.device_env,
            max_recursion_length=self.max_recursion_length,
        )
        if ocr_metrics and ocr_reconstructed_sample:
            metrics['ocr_reconstruction'] = ocr_metrics
            eval_data['ocr_reconstruction_data'] = ocr_reconstructed_sample
        else:
            _logger.info("Can't generate text from current batch. Skipping metrics...")

        # TODO Add other metrics relevant for eval step
        #
        # metrics['metric_category'] = ...
        return metrics, eval_data