def prepare_dataset()

in vision/m4/models/vgpt2/evaluation_classification_in_context_vgpt2.py [0:0]


    def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
        """
        Prepare batch of examples.
        Each example (X, y) where y is among (y1, y2, ..., yN) - the labels options -
        is turned into [(X, y1), (X, y2), ... (X, yN)], and every resulting sequence is prefixed
        with `num_shots` priming examples drawn from the support dataset.

        All per-sequence outputs are laid out class-major: for a batch x1..xB and tested labels A, B, ...
        the order is [x1,A; x2,A; ...; xB,A; x1,B; x2,B; ...].

        Args:
            exs: batch of examples to evaluate (column name -> list of values). Must contain an "id"
                column, the image column(s) and the label column, and also "vision_encoder_embeddings"
                when shots are not selected at random.
            **kwargs:
                support_dataset: dataset the priming shots are drawn from.
                support_dataset_vision_encoder_embeddings: embeddings of the support dataset; required
                    when `shot_selection_mode` is not random and `num_shots > 0`.
                num_shots: number of priming shots per tested example.
                shot_selection_mode: `ShotSelectionMode.random` samples shots uniformly; any other mode
                    picks the support examples whose embeddings have the largest dot product with the
                    tested example's embedding.
                prompt_template_id: id of the prompt template to use.

        Returns:
            A dict whose values are lists of size `batch_size * nb_tested_labels_per_ex`:
            "example_ids", "true_labels" (-1 when the batch has no labels), "tested_labels",
            "relevance_scores" (0.0 when no relevance column is configured), "input_ids",
            "attention_mask", "pixel_values", "pixel_attention_masks" and "buckets".

        Raises:
            ValueError: if there are both several image columns and several images in one column, or if
                similarity-based shot selection is requested without support dataset embeddings.
        """
        support_dataset: Dataset = kwargs["support_dataset"]
        support_dataset_vision_encoder_embeddings: Optional[np.ndarray] = kwargs.get(
            "support_dataset_vision_encoder_embeddings", None
        )
        num_shots: int = kwargs["num_shots"]
        shot_selection_mode: ShotSelectionMode = kwargs["shot_selection_mode"]
        prompt_template_id: int = kwargs["prompt_template_id"]

        # Apply mapping from class names to prompt names (the set removes duplicated prompt names)
        prompted_class_names = sorted(
            {self._get_class_name_value(prompt_template_id, class_name) for class_name in self.class_names}
        )
        class_prompt2int = {
            self._get_class_name_value(prompt_template_id, class_name): self.class_str2int(class_name)
            for class_name in self.class_names
        }
        class_int2prompt = {
            self.class_str2int(class_name): self._get_class_name_value(prompt_template_id, class_name)
            for class_name in self.class_names
        }

        nb_exs = len(exs["id"])
        # If the first image column is a list. We use it as the only image column, and <image> tokens are hardcoded in the dataset
        multiple_images_in_single_column = isinstance(support_dataset[0][self.image_column_names[0]], list)
        if multiple_images_in_single_column and len(self.image_column_names) > 1:
            raise ValueError(
                "We can either have multiple image columns, or multiple images in one column but not both"
            )

        def images_of_row(row):
            """Return the images of one example (a mapping column name -> value), in prompt order."""
            if multiple_images_in_single_column:
                return row[self.image_column_names[0]]
            return [row[image_column_name] for image_column_name in self.image_column_names]

        def images_of_tested_ex(idx_ex):
            """Return the images of the `idx_ex`-th example of the batch `exs`, in prompt order."""
            if multiple_images_in_single_column:
                return exs[self.image_column_names[0]][idx_ex]
            return [exs[image_column_name][idx_ex] for image_column_name in self.image_column_names]

        def split_and_transform(images):
            """Split each image into sub-images and apply `self.image_transform` to every sub-image."""
            return [
                self.image_transform(sub_image)
                for image in images
                for sub_image in self.simpler_get_splitted_images_and_corresponding_text(image=image)[0]
            ]

        if not self.tested_labels_column_name:
            # Every prompted class name is tested for every example
            nb_tested_labels_per_ex = len(prompted_class_names)
            tested_labels_exs = [[class_name for _ in range(nb_exs)] for class_name in prompted_class_names]
            tested_labels: List[int] = [
                class_prompt2int[label] for _tested_label in tested_labels_exs for label in _tested_label
            ]
        else:
            # Per-example candidate labels; assumes every example has the same number of candidates
            # as the first one - TODO confirm with the dataset schema
            nb_tested_labels_per_ex = len(exs[self.tested_labels_column_name][0])
            tested_labels_exs = [
                [exs[self.tested_labels_column_name][idx_ex][idx_class] for idx_ex in range(nb_exs)]
                for idx_class in range(nb_tested_labels_per_ex)
            ]
            tested_labels: List[int] = [
                class_prompt2int[exs[self.tested_labels_column_name][idx_ex][idx_class]]
                for idx_class in range(nb_tested_labels_per_ex)
                for idx_ex in range(nb_exs)
            ]

        if self.relevance_scores_column_name:
            relevance_scores = [
                exs[self.relevance_scores_column_name][idx_ex][idx_class]
                for idx_class in range(nb_tested_labels_per_ex)
                for idx_ex in range(nb_exs)
            ]
        else:
            # Fake variable to match the common signature
            relevance_scores = [0.0] * nb_exs * nb_tested_labels_per_ex

        def retrieve_idx_closest_examples(ref_embedding, embeddings_to_compare, num_examples):
            "Returns the indices of the `num_examples` closest embeddings, sorted by increasing similarity"
            sim = np.dot(embeddings_to_compare, ref_embedding)
            # We can achieve linear complexity because we don't need to sort all the numbers,
            # but only find the `num_examples` largest ones
            idx_closest_ex = np.argpartition(sim, -num_examples)[-num_examples:]
            idx_closest_ex = idx_closest_ex[np.argsort(sim[idx_closest_ex])].tolist()
            return idx_closest_ex

        if (shot_selection_mode == ShotSelectionMode.random) or (num_shots == 0):
            idx_shots = [random.sample(range(len(support_dataset)), num_shots) for _ in range(nb_exs)]
        else:
            if support_dataset_vision_encoder_embeddings is None:
                raise ValueError(
                    "`support_dataset_vision_encoder_embeddings` is required when shots are not selected at random"
                )
            idx_shots = [
                retrieve_idx_closest_examples(ref_embedding, support_dataset_vision_encoder_embeddings, num_shots)
                for ref_embedding in exs["vision_encoder_embeddings"]
            ]

        def create_shot_prompt(idx_shot):
            """Build the prompt of a single priming shot taken from the support dataset."""
            # Fetch the row once and reuse it for the label, the images and the contexts: indexing a row
            # decodes all its columns, and `support_dataset[column][idx]` can materialize the whole column
            # just to read one value.
            shot = support_dataset[idx_shot]
            return self._create_example_prompt(
                prompt_template_id=prompt_template_id,
                class_name=class_int2prompt[shot[self.label_column_name]],
                images=images_of_row(shot),
                multiple_images_in_single_column=multiple_images_in_single_column,
                contexts=(
                    [
                        (context_column_name, shot[context_column_name])
                        for context_column_name in self.context_column_names
                    ]
                    if self.context_column_names
                    else None
                ),
                excluded_context_columns=[],
            )

        # Prepare text shots
        texts_shots = ["".join([create_shot_prompt(idx_shot) for idx_shot in idx_shots_ex]) for idx_shots_ex in idx_shots]
        texts_shots = (
            texts_shots * nb_tested_labels_per_ex
        )  # These are the priming text shots - size: batch_size * nb_of_labels
        tested_label_prompts = [
            self._create_example_prompt(
                prompt_template_id=prompt_template_id,
                class_name=tested_labels_exs[idx_class][idx_ex],
                images=images_of_tested_ex(idx_ex),
                multiple_images_in_single_column=multiple_images_in_single_column,
                contexts=(
                    [
                        (context_column_name, exs[context_column_name][idx_ex])
                        for context_column_name in self.context_column_names
                    ]
                    if self.context_column_names
                    else None
                ),
                excluded_context_columns=(
                    self.tested_ex_excluded_context_columns if self.tested_ex_excluded_context_columns else []
                ),
            )
            for idx_class in range(nb_tested_labels_per_ex)
            for idx_ex in range(nb_exs)
        ]  # These are the tested labels - size: batch_size * nb_of_labels
        tot_texts = [
            self._create_prefix_prompt(prompt_template_id=prompt_template_id) + text_shot + tested_label_prompt
            for text_shot, tested_label_prompt in zip(texts_shots, tested_label_prompts)
        ]  # These are the concatenation of the priming text shots and tested labels - size: batch_size * nb_of_labels
        # Ignoring their associated priming shots, the list has the following order: [x1,A; x2,A; ... xN,A; x1,B; x2,B; ...]

        tot_texts = [text.strip() for text in tot_texts]

        # Tokenize and masks
        tokens = self.tokenizer(
            tot_texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.tokenizer_max_seq_len,
            padding=True,
            add_special_tokens=False,
        )
        input_ids = [tokens.input_ids[idx] for idx in range(len(tot_texts))]
        attention_mask = [tokens.attention_mask[idx] for idx in range(len(tot_texts))]

        # Priming shot images, one list per tested example - size: batch_size
        pixel_values_shots = [
            [
                pixel_value
                for idx_shot in idx_shots_ex
                for pixel_value in split_and_transform(images_of_row(support_dataset[idx_shot]))
            ]
            for idx_shots_ex in idx_shots
        ]

        # These are the tested images - size: batch_size
        tested_pixel_values = [split_and_transform(images_of_tested_ex(idx_ex)) for idx_ex in range(nb_exs)]

        # Zero-pad every image of a sequence (shots first, then tested images) to the largest height/width
        # of that sequence; the boolean mask marks the non-padded pixels
        pixel_values = []
        pixel_attention_masks = []
        for pv_shots, pv in zip(pixel_values_shots, tested_pixel_values):
            num_images = len(pv_shots) + len(pv)
            max_height = max([im.size(1) for im in pv_shots] + [im.size(1) for im in pv])
            max_width = max([im.size(2) for im in pv_shots] + [im.size(2) for im in pv])
            padded_image_tensor = torch.zeros(num_images, 3, max_height, max_width)
            padded_pixel_attention_masks = torch.zeros(num_images, max_height, max_width, dtype=torch.bool)

            for idx, im in enumerate(pv_shots + pv):
                im_height, im_width = im.size(1), im.size(2)
                padded_image_tensor[idx, :, :im_height, :im_width] = im
                padded_pixel_attention_masks[idx, :im_height, :im_width] = True

            pixel_values.append(padded_image_tensor)
            pixel_attention_masks.append(padded_pixel_attention_masks)

        pixel_values = pixel_values * nb_tested_labels_per_ex  # size: batch_size * nb_of_labels
        pixel_attention_masks = pixel_attention_masks * nb_tested_labels_per_ex

        example_ids: List[int] = exs["id"] * nb_tested_labels_per_ex

        true_labels = exs[self.label_column_name]
        # Handle the case where the true labels are not provided
        labels_are_none = all(label is None for label in true_labels)
        if not labels_are_none:
            true_labels = [
                class_prompt2int[self._get_class_name_value(prompt_template_id, self.class_int2str(label_id))]
                for label_id in true_labels
            ]
            true_labels: List[int] = true_labels * nb_tested_labels_per_ex
        else:
            true_labels: List[int] = [-1] * len(true_labels) * nb_tested_labels_per_ex
        if self.buckets_keys:

            def bucket_infos_to_str(bucket_infos):
                return "/".join(f"{info_type}={info}" for info, info_type in zip(bucket_infos, self.buckets_keys))

            columns_to_concatenate = [exs[key] for key in self.buckets_keys]
            # Repeat once per tested label (not per prompted class name) so that the buckets stay aligned
            # with `example_ids` when the tested labels come from `self.tested_labels_column_name`
            buckets = [
                bucket_infos_to_str(bucket_infos) for bucket_infos in zip(*columns_to_concatenate)
            ] * nb_tested_labels_per_ex
        else:
            buckets = [None] * len(example_ids)

        return {
            "example_ids": example_ids,
            "true_labels": true_labels,
            "tested_labels": tested_labels,
            "relevance_scores": relevance_scores,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "pixel_attention_masks": pixel_attention_masks,
            "buckets": buckets,
        }