def prepare_dataset()

in vision/m4/models/vgpt2/evaluation_captioning_in_context_vgpt2.py [0:0]


    def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
        """
        Prepare a batch of examples.
        """
        support_dataset: Dataset = kwargs["support_dataset"]
        support_dataset_vision_encoder_embeddings: Optional[np.ndarray] = kwargs.get(
            "support_dataset_vision_encoder_embeddings", None
        )
        num_shots: int = kwargs["num_shots"]
        shot_selection_mode: ShotSelectionMode = kwargs["shot_selection_mode"]
        prompt_template_id: int = kwargs["prompt_template_id"]

        nb_exs = len(exs["id"])

        def retrieve_idx_closest_examples(ref_embedding, embeddings_to_compare, num_examples):
            "Returns the indices of the `num_examples` closest embeddings in ascending order"
            sim = np.dot(embeddings_to_compare, ref_embedding)
            # We can achieve linear complexity because we don't need to sort all the numbers,
            # but only find the `num_examples` largest ones
            idx_closest_ex = np.argpartition(sim, -num_examples)[-num_examples:]
            idx_closest_ex = idx_closest_ex[np.argsort(sim[idx_closest_ex])].tolist()
            return idx_closest_ex
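        # Illustrative only (hypothetical values): with similarities [0.1, 0.9, 0.4, 0.8]
        # and num_examples=2, np.argpartition keeps the two largest without a full sort,
        # and the final argsort orders them ascending by similarity, giving indices [3, 1].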

        if (shot_selection_mode == ShotSelectionMode.random) or (num_shots == 0):
            idx_shots = [random.sample(range(len(support_dataset)), num_shots) for _ in range(nb_exs)]
        elif shot_selection_mode == ShotSelectionMode.first_without_image:
            idx_shots = [list(range(num_shots)) for _ in range(nb_exs)]
        else:
            idx_shots = [
                retrieve_idx_closest_examples(ref_embedding, support_dataset_vision_encoder_embeddings, num_shots)
                for ref_embedding in exs["vision_encoder_embeddings"]
            ]
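        # In every branch, idx_shots is a list of length nb_exs, each entry holding the
        # num_shots support-set indices used to prime the corresponding tested example
        # (e.g., with nb_exs=2 and num_shots=3 it could look like [[4, 17, 2], [8, 0, 11]]).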

        # Prepare text shots
        # These are the priming text shots - size: batch_size
        texts_shots = [
            "".join(
                [
                    self._create_example_prompt(
                        prompt_template_id=prompt_template_id,
                        caption=random.choice(support_dataset[idx_shot][self.reference_captions_column_name]),
                        image=support_dataset[idx_shot][self.image_column_name],
                        context=(
                            support_dataset[idx_shot][self.context_column_name] if self.context_column_name else None
                        ),
                        without_image=shot_selection_mode == ShotSelectionMode.first_without_image,
                        eos_token=self.tokenizer.eos_token,
                    )
                    for idx_shot in idx_shots_ex
                ]
            )
            for idx_shots_ex in idx_shots
        ]
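        # Each entry of texts_shots concatenates num_shots example prompts, one per selected
        # support index, with the caption drawn at random from that example's references;
        # with num_shots=0 the entry is simply the empty string.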

        # These are the tested examples - size: batch_size
        tested_exs = [
            self._create_example_prompt(
                prompt_template_id=prompt_template_id,
                image=exs[self.image_column_name][idx],
                context=exs[self.context_column_name][idx] if self.context_column_name else None,
                eos_token="",
            )
            for idx in range(nb_exs)
        ]
        if self.bool_instruct_templates:
            tested_exs = [ex[: -len("<end_of_utterance>\n")].strip() for ex in tested_exs]

        # These are the concatenations of the priming text shots and the tested example - size: batch_size
        tot_texts = [
            self._create_prefix_prompt(prompt_template_id=prompt_template_id) + text_shot + tested_ex
            for text_shot, tested_ex in zip(texts_shots, tested_exs)
        ]

        tot_texts = [text.strip() for text in tot_texts]
        # Tokenize and build attention masks
        tokens = self.tokenizer(
            tot_texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.tokenizer_max_seq_len,
            padding=True,
            add_special_tokens=False,
        )
        input_ids = [tokens.input_ids[idx] for idx in range(len(tot_texts))]
        attention_mask = [tokens.attention_mask[idx] for idx in range(len(tot_texts))]
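        # input_ids and attention_mask are kept as per-example 1-D tensors (all padded to the
        # same length, at most tokenizer_max_seq_len) rather than a single stacked batch tensor.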

        # Prepare image shots
        # These are the priming image shots - size: batch_size
        if shot_selection_mode == ShotSelectionMode.first_without_image:
            pixel_values_shots = [[] for _ in range(nb_exs)]
        else:
            pixel_values_shots = [
                [
                    self.image_transform(sub_image)
                    for idx_shot in idx_shots_ex
                    for sub_image in self.simpler_get_splitted_images_and_corresponding_text(
                        image=support_dataset[idx_shot][self.image_column_name],
                    )[0]
                ]
                for idx_shots_ex in idx_shots
            ]

        # These are the tested images - size: batch_size
        tested_pixel_values = [
            [
                self.image_transform(sub_image)
                for sub_image in self.simpler_get_splitted_images_and_corresponding_text(image=image)[0]
            ]
            for image in exs[self.image_column_name]
        ]

        # These are the concatenation of the priming image shots and tested images - size: batch_size
        pixel_values = []
        pixel_attention_masks = []
        for pv_shots, pv in zip(pixel_values_shots, tested_pixel_values):
            num_images = len(pv_shots) + len(pv)
            max_height = max([im.size(1) for im in pv_shots] + [im.size(1) for im in pv])
            max_width = max([im.size(2) for im in pv_shots] + [im.size(2) for im in pv])
            padded_image_tensor = torch.zeros(num_images, 3, max_height, max_width)
            padded_pixel_attention_masks = torch.zeros(num_images, max_height, max_width, dtype=torch.bool)
            for idx, im in enumerate(pv_shots + pv):
                im_height, im_width = im.size(1), im.size(2)
                padded_image_tensor[idx, :, :im_height, :im_width] = im
                padded_pixel_attention_masks[idx, :im_height, :im_width] = True
            pixel_values.append(padded_image_tensor)
            pixel_attention_masks.append(padded_pixel_attention_masks)
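        # pixel_values[i] has shape (num_images_i, 3, max_height_i, max_width_i), and
        # pixel_attention_masks[i] is True exactly over the unpadded pixels of each image.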

        example_ids: List[int] = exs["id"]
        reference_captions = exs[self.reference_captions_column_name]
        if isinstance(reference_captions[0], str):
            reference_captions = [[ref_cap] for ref_cap in reference_captions]
        return {
            "example_ids": example_ids,
            "reference_captions": reference_captions,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "pixel_attention_masks": pixel_attention_masks,
        }
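
The per-example image padding above can be exercised in isolation. Below is a minimal, self-contained sketch (the `pad_images` helper and the random tensors are hypothetical, not part of the repository) showing how variable-sized (3, H, W) tensors are zero-padded to a common size together with a boolean pixel mask, mirroring the loop over pv_shots + pv.

import torch

def pad_images(images):
    # Zero-pad (3, H, W) tensors to the max height/width and record which pixels are real.
    max_h = max(im.size(1) for im in images)
    max_w = max(im.size(2) for im in images)
    padded = torch.zeros(len(images), 3, max_h, max_w)
    mask = torch.zeros(len(images), max_h, max_w, dtype=torch.bool)
    for idx, im in enumerate(images):
        h, w = im.size(1), im.size(2)
        padded[idx, :, :h, :w] = im
        mask[idx, :h, :w] = True
    return padded, mask

# Two images of different sizes (random data for illustration).
images = [torch.rand(3, 224, 224), torch.rand(3, 196, 280)]
pixel_values, pixel_attention_mask = pad_images(images)
assert pixel_values.shape == (2, 3, 224, 280)
assert pixel_attention_mask[0, :224, :224].all() and not pixel_attention_mask[0, :, 224:].any()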