in vision/m4/models/vgpt2/evaluation_perplexity_in_context_vgpt2.py [0:0]
def prepare_webdoc_ds(self, exs: Dict) -> Dict:
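    """Turn a batch of raw web documents (interleaved images and texts) into
    model-ready examples: token ids with interleaved image tokens, an attention
    mask, and a zero-padded stack of transformed images per example."""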
    images_batch = exs[self.image_column_name]
    texts_batch = exs[self.text_column_name]
    tokenizer = self.tokenizer

    last_was_image = False
    all_images = []
    all_texts = []
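    # First pass over the batch: split each document on the end-of-document marker,
    # then collect the transformed images and the token ids of every sub-document
    # into `all_images` / `all_texts`.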
    for raw_images, raw_texts in zip(images_batch, texts_batch):
        inds_of_texts_to_split = [
            i
            for i, text in enumerate(raw_texts)
            if text is not None and isinstance(text, str) and "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" in text
        ]
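        # Split the document at every text node containing the end-of-document
        # marker, keeping the images and texts of each sub-document aligned.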
        if inds_of_texts_to_split:
            splitted_raw_images, splitted_raw_texts = [], []
            previous_i = 0
            for i in inds_of_texts_to_split:
                splitting = raw_texts[i].split("END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED")
                part1, part2 = splitting[0], splitting[-1]

                sub_doc_images = raw_images[previous_i:i] + [None]
                sub_doc_texts = raw_texts[previous_i:i] + [part1.strip()]
                if not any(sub_doc_images):  # This can happen if all images in raw_images[previous_i:i] are None
                    continue

                splitted_raw_images.append(sub_doc_images)
                splitted_raw_texts.append(sub_doc_texts)

                if part2.strip() == "":
                    previous_i = i + 1
                else:
                    raw_texts[i] = part2.strip()
                    previous_i = i
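            # Keep whatever remains after the last split point, provided it still
            # contains at least one image.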
            if previous_i < len(raw_images) and any(raw_images[previous_i:]):
                splitted_raw_images.append(raw_images[previous_i:])
                splitted_raw_texts.append(raw_texts[previous_i:])
        else:
            splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts]

        # Sanity check
        if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
            raise ValueError(
                "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                " Something core went wrong during the splitting and needs to be fixed."
            )
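        # Flatten each sub-document into a single string: every image contributes a
        # fake-token/image-token pair (and its transformed tensor), and texts are
        # concatenated with single spaces in between.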
        for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
            # Reset per-sub-document state so a trailing image from the previous
            # sub-document cannot leak a spurious fake token into this one.
            images, web_text, last_was_image = [], "", False
            for image, text in zip(s_r_ims, s_r_txts):
                if text is None and image is None:
                    continue

                if image is not None:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                    images.append(self.image_transform(image))
                    last_was_image = True
                elif text is not None:
                    if last_was_image:
                        web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                        last_was_image = False
                    else:
                        web_text += f" {text}" if web_text != "" else text
            if last_was_image:
                web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"

            web_text = web_text.strip(" ")

            # This is mostly a sanity check. Cases like that should not happen at that point.
            if web_text == "" or len(images) == 0:
                continue

            images = torch.stack(images)
            all_images.append(images)

            web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
            if self.add_end_of_doc_token:
                web_text_ids += [tokenizer.eos_token_id]

            if self.add_begin_of_doc_token:
                web_text_ids = [tokenizer.bos_token_id] + web_text_ids

            all_texts.append(web_text_ids)
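    # Second pass: pad/truncate every token sequence to `tokenizer_max_seq_len` and
    # every image stack to `max_num_images`, so the whole batch can be stacked into
    # fixed-size tensors.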
    output_input_ids = []
    output_images = []
    output_attention_masks = []
    for images, text in zip(all_images, all_texts):
        padded_input_ids = [tokenizer.pad_token_id] * self.tokenizer_max_seq_len
        unpadded_seq_len = len(text)
        padded_input_ids[:unpadded_seq_len] = text[: self.tokenizer_max_seq_len]

        attention_mask = torch.zeros((self.tokenizer_max_seq_len,), dtype=torch.long)
        attention_mask[:unpadded_seq_len] = 1
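        # Only keep as many images as there are image tokens left after truncation,
        # capped at `max_num_images`; the remaining slots stay zero-padded.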
        image_count = padded_input_ids.count(self.image_token_id)
        local_max_num_images = min(image_count, self.max_num_images)
        current_images = images[:local_max_num_images]
        padded_image_tensor = torch.zeros(self.max_num_images, *current_images.size()[1:])
        padded_image_tensor[: current_images.size(0)] = current_images

        output_images.append(padded_image_tensor)
        output_input_ids.append(torch.tensor(padded_input_ids))
        output_attention_masks.append(attention_mask)
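    # Stack the per-example tensors into batch tensors, then hand them back as
    # per-example lists in the output dict.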
    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)

    example_ids: List[int] = exs["id"]

    return {
        "example_ids": example_ids,
        "input_ids": [input_ids for input_ids in output_input_ids],
        "attention_mask": [attention_masks for attention_masks in output_attention_masks],
        "pixel_values": [pixels for pixels in output_images],
    }
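
# Per-example output shapes, as implied by the code above: `input_ids` and
# `attention_mask` are 1-D tensors of length `tokenizer_max_seq_len`, and
# `pixel_values` has `max_num_images` as its leading dimension, with the trailing
# dimensions set by `image_transform` (typically C x H x W). A plausible call site,
# not shown in this excerpt and purely illustrative, would be a batched
# `datasets.Dataset.map`:
#
#     ds = ds.map(evaluator.prepare_webdoc_ds, batched=True, batch_size=32)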