def prepare_dataset()

in vision/m4/models/vgpt2/evaluation_image_caption_matching_vgpt2.py


    def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
        """
        Prepare a batch of examples for image-caption matching: build a prompt
        for every (caption, image) combination, tokenize the prompts, and
        transform the images.
        """

        prompt_template_id: int = kwargs["prompt_template_id"]

        nb_exs = len(exs["id"])
        nb_captions = len(self.caption_column_names)
        nb_images = len(self.image_column_names)

        # If caption_column_names = ["caption_0", "caption_1"] and image_column_names = ["image_0", "image_1"],
        # we get the caption sequence [caption_0, caption_0, caption_1, caption_1] (each caption repeated once per image).
        general_dict = {"tested_prompts": [], "caption_ids": [], "image_ids": [], "ex_ids": []}
        for idx_ex in range(nb_exs):
            for caption_idx, caption_column in enumerate(self.caption_column_names):
                for image_idx in range(nb_images):
                    tested_prompt = self._create_example_prompt(
                        prompt_template_id=prompt_template_id,
                        caption=exs[caption_column][idx_ex],
                    )
                    general_dict["tested_prompts"].append(tested_prompt)
                    general_dict["caption_ids"].append(caption_idx)
                    general_dict["image_ids"].append(image_idx)
                    general_dict["ex_ids"].append(exs["id"][idx_ex])

        tot_texts = [
            self._create_prefix_prompt(prompt_template_id=prompt_template_id) + tested_prompt
            for tested_prompt in general_dict["tested_prompts"]
        ]

        tot_texts = [text.strip() for text in tot_texts]
        # Tokenize the prompts and build the attention masks
        tokens = self.tokenizer(
            tot_texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.tokenizer_max_seq_len,
            padding=True,
            add_special_tokens=False,
        )

        general_dict["input_ids"] = [tokens.input_ids[idx] for idx in range(len(tot_texts))]
        general_dict["attention_mask"] = [tokens.attention_mask[idx] for idx in range(len(tot_texts))]

        # If caption_column_names = ["caption_0", "caption_1"] and image_column_names = ["image_0", "image_1"],
        # we get the image sequence [image_0, image_1, image_0, image_1] (the images cycle once per caption).
        pixel_values_dict = {"pixel_values": [], "caption_ids": [], "image_ids": [], "ex_ids": []}
        for idx_ex in range(nb_exs):
            for caption_idx in range(nb_captions):
                for image_idx, col in enumerate(self.image_column_names):
                    pixel_values_dict["pixel_values"].append(self.image_transform(exs[col][idx_ex]).unsqueeze(0))
                    pixel_values_dict["caption_ids"].append(caption_idx)
                    pixel_values_dict["image_ids"].append(image_idx)
                    pixel_values_dict["ex_ids"].append(exs["id"][idx_ex])

        # ---- Sanity check ----
        assert pixel_values_dict["ex_ids"] == general_dict["ex_ids"]
        nb_combinations = nb_captions * nb_images
        sample_pixel_captions_ids = pixel_values_dict["caption_ids"][:nb_combinations]
        sample_pixel_image_ids = pixel_values_dict["image_ids"][:nb_combinations]
        sample_general_captions_ids = general_dict["caption_ids"][:nb_combinations]
        sample_general_image_ids = general_dict["image_ids"][:nb_combinations]
        for idx in range(nb_combinations):
            expected_caption_idx, expected_image_idx = self.captions_images_order_per_ex[idx]
            assert sample_pixel_captions_ids[idx] == expected_caption_idx
            assert sample_general_captions_ids[idx] == expected_caption_idx
            assert sample_pixel_image_ids[idx] == expected_image_idx
            assert sample_general_image_ids[idx] == expected_image_idx
        # ---- Sanity check ----

        general_dict["ex_ids"] = self._split_array(general_dict["ex_ids"], nb_exs)
        general_dict["caption_ids"] = self._split_array(general_dict["caption_ids"], nb_exs)
        general_dict["image_ids"] = self._split_array(general_dict["image_ids"], nb_exs)
        general_dict["input_ids"] = self._split_array(general_dict["input_ids"], nb_exs)
        pixel_values_dict["pixel_values"] = self._split_array(pixel_values_dict["pixel_values"], nb_exs)
        general_dict["attention_mask"] = self._split_array(general_dict["attention_mask"], nb_exs)

        return {
            "example_ids": general_dict["ex_ids"],
            "caption_ids": general_dict["caption_ids"],
            "image_ids": general_dict["image_ids"],
            "input_ids": general_dict["input_ids"],
            "attention_mask": general_dict["attention_mask"],
            "pixel_values": pixel_values_dict["pixel_values"],
        }
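
The `_split_array` helper is not shown in this file. A minimal sketch consistent with how it is used above (regrouping a flat list of nb_exs * nb_captions * nb_images entries into nb_exs equal, contiguous chunks) could look like the following; this is an assumption, not the repository's actual implementation:

    def _split_array(self, array: List, nb_exs: int) -> List[List]:
        # Hypothetical sketch: split a flat per-combination list into nb_exs
        # equal, contiguous chunks, one sub-list per example.
        chunk_size = len(array) // nb_exs
        return [array[i * chunk_size : (i + 1) * chunk_size] for i in range(nb_exs)]

Since `prepare_dataset` takes a batched `exs: Dict` of columns, a plausible (hypothetical, not shown here) call site would be a batched Hugging Face `datasets.map`, e.g. `dataset.map(evaluator.prepare_dataset, batched=True, fn_kwargs={"prompt_template_id": 0})`.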