in src/pixparse/task/task_cruller_pretrain.py [0:0]
def get_train_ocr_metrics(self, sample):
    """
    Build OCR reconstruction metrics and sample data for training-time logging.

    In cruller_pretrain, this task returns a few utility logs that are useful
    for monitoring training. Typically this is a handful of sample images with
    their generated OCR text, so they can be logged to a TensorBoard gallery
    in the log step.

    Args:
        sample: a 3-item sequence ``(image_input, text_input, text_target)``.
            Only ``image_input`` and ``text_target`` are used here.
            NOTE(review): both are presumably batch-first tensors; the
            ``[:, 1:]`` slice below requires ``text_target`` to be at least 2-D.

    Returns:
        A ``(metrics, eval_data)`` tuple of dicts. When ``get_ocr_metrics``
        returns truthy results, ``metrics['ocr_reconstruction']`` and
        ``eval_data['ocr_reconstruction_data']`` are populated. Otherwise both
        dicts are returned empty and an info message is logged.
    """
    metrics = {}
    eval_data = {}
    # text_input is not consumed below, so it is neither sliced nor moved to
    # the device (avoids a pointless host-to-device copy).
    image_input, _, text_target = sample
    image_input = image_input.to(self.device_env.device, non_blocking=True)
    # Drop the first target position (presumably a BOS/prompt token — TODO confirm).
    text_target = text_target[:, 1:].to(self.device_env.device, non_blocking=True)

    # Add OCR-related metrics and generation.
    ocr_metrics, ocr_reconstructed_sample = get_ocr_metrics(
        model=self.model,
        tokenizer=self.tokenizer,
        image_input=image_input,
        text_input=text_target,
        device_env=self.device_env,
        max_recursion_length=self.max_recursion_length,
    )
    if ocr_metrics and ocr_reconstructed_sample:
        metrics['ocr_reconstruction'] = ocr_metrics
        eval_data['ocr_reconstruction_data'] = ocr_reconstructed_sample
    else:
        _logger.info("Can't generate text from current batch. Skipping metrics...")
    # TODO: add other metrics relevant for the eval step, e.g.
    # metrics['metric_category'] = ...
    return metrics, eval_data