def prepare_webdoc_ds()

in vision/m4/models/vgpt2/evaluation_perplexity_in_context_vgpt2.py


    def prepare_webdoc_ds(self, exs: Dict) -> Dict:
        images_batch = exs[self.image_column_name]
        texts_batch = exs[self.text_column_name]

        tokenizer = self.tokenizer

        last_was_image = False
        all_images = []
        all_texts = []
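        # Iterate over every raw web document in the batch and split it into sub-documents
        # wherever the end-of-document placeholder token appears in a text chunk.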
        for raw_images, raw_texts in zip(images_batch, texts_batch):
            inds_of_texts_to_split = [
                i
                for i, text in enumerate(raw_texts)
                if text is not None and isinstance(text, str) and "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" in text
            ]
            if inds_of_texts_to_split:
                splitted_raw_images, splitted_raw_texts = [], []
                previous_i = 0
                for i in inds_of_texts_to_split:
                    splitting = raw_texts[i].split("END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED")
                    part1, part2 = splitting[0], splitting[-1]
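                    # If the separator occurs more than once in this text, only the first and
                    # last pieces are kept; middle pieces are dropped.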

                    sub_doc_images = raw_images[previous_i:i] + [None]
                    sub_doc_texts = raw_texts[previous_i:i] + [part1.strip()]
                    if not any(sub_doc_images):  # Can happen if every image in raw_images[previous_i:i] is None
                        continue

                    splitted_raw_images.append(sub_doc_images)
                    splitted_raw_texts.append(sub_doc_texts)

                    if part2.strip() == "":
                        previous_i = i + 1
                    else:
                        raw_texts[i] = part2.strip()
                        previous_i = i

                if previous_i < len(raw_images) and any(raw_images[previous_i:]):
                    splitted_raw_images.append(raw_images[previous_i:])
                    splitted_raw_texts.append(raw_texts[previous_i:])

            else:
                splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts]

            # Sanity check
            if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
                raise ValueError(
                    "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                    " Something core went wrong during the splitting and needs to be fixed."
                )

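            # Linearize each sub-document: every image contributes a
            # `FAKE_TOKEN_AROUND_IMAGE_V2`/`IMAGE_TOKEN` marker pair to the text stream and its
            # transformed pixels to `images`; text right after an image closes the fake-token span.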
            for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
                images, web_text = [], ""
                for image, text in zip(s_r_ims, s_r_txts):
                    if text is None and image is None:
                        continue

                    if image is not None:
                        web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                        images.append(self.image_transform(image))
                        last_was_image = True
                    elif text is not None:
                        if last_was_image:
                            web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                            last_was_image = False
                        else:
                            web_text += f" {text}" if web_text != "" else text

                if last_was_image:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"
                    # Reset the flag so it does not leak into the next sub-document.
                    last_was_image = False

                web_text = web_text.strip(" ")

                # This is mostly a sanity check. Cases like that should not happen at that point.
                if web_text == "" or len(images) == 0:
                    continue

                images = torch.stack(images)
                all_images.append(images)

                web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
                if self.add_end_of_doc_token:
                    web_text_ids += [tokenizer.eos_token_id]

                if self.add_begin_of_doc_token:
                    web_text_ids = [tokenizer.bos_token_id] + web_text_ids
                all_texts.append(web_text_ids)

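        # Pad or truncate each token sequence to `tokenizer_max_seq_len`, build the matching
        # attention mask, and pad the stacked image tensor up to `max_num_images` images.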
        output_input_ids = []
        output_images = []
        output_attention_masks = []
        for images, text in zip(all_images, all_texts):
            padded_input_ids = [tokenizer.pad_token_id] * self.tokenizer_max_seq_len
            unpadded_seq_len = len(text)
            padded_input_ids[:unpadded_seq_len] = text[: self.tokenizer_max_seq_len]

            attention_mask = torch.zeros((self.tokenizer_max_seq_len,), dtype=torch.long)
            attention_mask[:unpadded_seq_len] = 1

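            # Keep at most as many images as there are image tokens remaining after
            # truncation, capped at `max_num_images`.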
            image_count = padded_input_ids.count(self.image_token_id)
            local_max_num_images = min(image_count, self.max_num_images)

            current_images = images[:local_max_num_images]

            padded_image_tensor = torch.zeros(self.max_num_images, *current_images.size()[1:])
            padded_image_tensor[: current_images.size(0)] = current_images

            output_images.append(padded_image_tensor)
            output_input_ids.append(torch.tensor(padded_input_ids))

            output_attention_masks.append(attention_mask)

        output_input_ids = torch.stack(output_input_ids)
        output_images = torch.stack(output_images)
        output_attention_masks = torch.stack(output_attention_masks)

        example_ids: List[int] = exs["id"]
        return {
            "example_ids": example_ids,
            "input_ids": list(output_input_ids),
            "attention_mask": list(output_attention_masks),
            "pixel_values": list(output_images),
        }
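
Usage sketch (not from the repo): assuming `evaluator` is an instance of the evaluation class
that owns this method and `ds` is a `datasets.Dataset` of web documents with the expected `id`,
image, and text columns, the preprocessing could be applied with `datasets.Dataset.map` in
batched mode. The names below are illustrative assumptions, not the repo's exact wiring.

    # Hypothetical wiring of prepare_webdoc_ds into a `datasets` pipeline.
    prepared_ds = ds.map(
        evaluator.prepare_webdoc_ds,
        batched=True,
        remove_columns=ds.column_names,  # keep only the keys returned by prepare_webdoc_ds
    )
    prepared_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "pixel_values"])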