def step()

in src/pixparse/task/task_cruller_eval_rvlcdip.py [0:0]


    def step(self, sample):
        """
        Run one greedy-decoding evaluation step for document classification
        on RVL-CDIP.

        Images are encoded once, then the text decoder is unrolled for a few
        steps; when a sample emits the end token ``</s>`` its decoded label is
        compared against the ground truth.

        Args:
            sample: Mapping with at least ``"image"`` (sequence of image
                tensors, stackable along dim 0) and ``"label"`` (sequence of
                integer class ids indexable into ``self.int2str``).

        Returns:
            dict: ``{"classification": {"correct_samples": int,
            "n_valid_samples": int}}``.
        """
        metrics = {"classification": {}}
        correct_samples = 0
        # Map integer class ids to their string names.
        ground_truths = [self.int2str[int(gt)] for gt in sample["label"]]
        # Guard so each sample is counted at most once, even if "</s>" shows
        # up again in later decoding steps.
        already_counted = [False] * len(ground_truths)

        with torch.inference_mode():
            tensor_images = torch.stack(list(sample["image"])).to(
                self.device_env.device
            )
            encoder_output = self.model.image_encoder(tensor_images)

            # Running decoded text per sample, seeded with the task prompt.
            current_strings = ["<s_rvlcdip>" for _ in range(tensor_images.shape[0])]

            # encode() presumably returns [BOS, <s_rvlcdip>, ...]; index 1
            # picks the prompt token id — TODO confirm tokenizer convention.
            input_ids = (
                torch.tensor(self.tokenizer.trunk.encode("<s_rvlcdip>")[1])
                .unsqueeze(0)
                .repeat(tensor_images.shape[0], 1)
                .to(self.device_env.device)
            )

            # RVL-CDIP class labels decode in at most ~3 tokens; 5 is a safe cap.
            max_steps = 5

            # NOTE: loop variable renamed from `step` — it shadowed this method.
            for decode_step in range(max_steps):
                inputs = self.prepare_inputs_for_inference(
                    input_ids=input_ids,
                    encoder_outputs=encoder_output,
                )
                decoder_outputs = self.model.text_decoder(**inputs)
                # Greedy decoding: take the most likely token at each position.
                # (softmax is monotonic, so argmax over probabilities equals
                # argmax over logits; kept for clarity.)
                probabilities = F.softmax(decoder_outputs["logits"], dim=-1)
                next_token_ids = torch.argmax(probabilities, dim=-1)

                for idx in range(next_token_ids.shape[0]):
                    # Only the prediction at the last position is the newly
                    # generated token.
                    next_token_id = next_token_ids[idx, -1].item()
                    next_token = self.tokenizer.trunk.decode([next_token_id])

                    current_strings[idx] += next_token

                    if next_token == "</s>":
                        # Strip prompt/special tokens to recover the bare
                        # "<classname/>" label string.
                        generated_label = (
                            current_strings[idx]
                            .replace("<s_rvlcdip>", "")
                            .replace("</s>", "")
                            .replace("<s>", "")
                            .strip()
                        )
                        ground_truth_label = "<" + ground_truths[idx] + "/>"
                        if (
                            generated_label == ground_truth_label
                            and not already_counted[idx]
                        ):
                            correct_samples += 1
                            already_counted[idx] = True

                # Re-encode the full running strings as the next decoder input.
                # NOTE(review): `[1:]` strips what is presumably a BOS token;
                # if encode() also appends EOS this would accumulate — confirm
                # against the tokenizer before relying on long generations.
                input_ids = torch.tensor(
                    [self.tokenizer.trunk.encode(s)[1:] for s in current_strings]
                ).to(self.device_env.device)

        # TODO Add other metrics relevant for eval step (e.g. per-class accuracy).
        metrics["classification"]["correct_samples"] = correct_samples
        metrics["classification"]["n_valid_samples"] = len(sample["label"])
        return metrics