# Excerpt: step() from src/pixparse/task/task_cruller_eval_docvqa.py



    def step(self, batch):
        """Run one evaluation step for DocVQA.

        Greedily decodes an answer for every sample in the batch, one sample
        at a time (current limitation: no batched decoding). Ground truths and
        predictions are accumulated on ``self.all_ground_truths`` /
        ``self.all_predictions``; metrics are computed elsewhere from those
        lists, so the returned dict is empty.

        Args:
            batch: mapping with keys ``'images'``, ``'questions'``,
                ``'ground_truth_answers'`` and ``'question_ids'``.
                # presumably produced by the eval dataloader — confirm schema.

        Returns:
            dict: empty metrics dict (accumulation happens on ``self``).
        """
        metrics = {}
        image_outputs = self.model.image_encoder(batch['images'].to(self.device_env.device))
        for output, question, answers, question_id in zip(
            image_outputs, batch['questions'], batch['ground_truth_answers'], batch['question_ids']
        ):
            self.all_ground_truths.append(answers)
            with torch.inference_mode():
                # Build the prompt: task token + question, then open the answer
                # tag so the decoder generates the answer as a continuation.
                current_string = self.task_start_token + "<s_question>" + question + "</s_question>" + "<s_answer>"
                input_ids = torch.tensor(
                    self.tokenizer.trunk.encode(current_string, add_special_tokens=False)
                ).unsqueeze(0).to(self.device_env.device)  # (1, seq): batch dim for the decoder
                max_steps = 512  # hard cap on generated tokens

                # NOTE: loop variable renamed from `step`, which shadowed this method.
                for _ in range(max_steps):
                    inputs = self.prepare_inputs_for_inference(input_ids=input_ids, encoder_outputs=output)
                    decoder_outputs = self.model.text_decoder(**inputs)

                    # argmax over logits == argmax over softmax(logits); the
                    # softmax was redundant work inside the decode loop.
                    next_token_id = torch.argmax(decoder_outputs['logits'][0, -1]).item()
                    next_token = self.tokenizer.trunk.decode([next_token_id])
                    current_string += next_token

                    if next_token == "</s>":
                        break

                    # Append the new token id directly instead of re-encoding the
                    # whole string every step: avoids O(n^2) re-tokenization and
                    # the risk that a decode->encode round-trip yields a different
                    # token sequence than the one actually generated.
                    input_ids = torch.cat(
                        [input_ids, torch.tensor([[next_token_id]], device=input_ids.device)],
                        dim=1,
                    )

                predicted_json = token2json(current_string)
            # Malformed generations may lack an 'answer' field; record an empty
            # prediction so predictions stay aligned with ground truths.
            if 'answer' in predicted_json:
                self.all_predictions.append(predicted_json['answer'])
            else:
                self.all_predictions.append("")
        return metrics