vision/m4/models/vgpt2/evaluation_image_caption_matching_vgpt2.py
def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
"""
Prepare batch of examples.
"""
prompt_template_id: int = kwargs["prompt_template_id"]
nb_exs = len(exs["id"])
nb_captions = len(self.caption_column_names)
nb_images = len(self.image_column_names)
# If caption_column_names = ["caption_0", "caption_1"] and image_column_names = ["image_0", "image_1"], we get the caption sequence [caption_0, caption_0, caption_1, caption_1] (each caption repeated once per image).
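# Combined with the inner image loop below, the flat per-example (caption_id, image_id) order is caption-major: [(0, 0), (0, 1), (1, 0), (1, 1)] for the 2x2 case above.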
general_dict = {"tested_prompts": [], "caption_ids": [], "image_ids": [], "ex_ids": []}
for idx_ex in range(nb_exs):
for caption_idx, caption_column in enumerate(self.caption_column_names):
for image_idx in range(nb_images):
tested_prompt = self._create_example_prompt(
prompt_template_id=prompt_template_id,
caption=exs[caption_column][idx_ex],
)
general_dict["tested_prompts"].append(tested_prompt)
general_dict["caption_ids"].append(caption_idx)
general_dict["image_ids"].append(image_idx)
general_dict["ex_ids"].append(exs["id"][idx_ex])
tot_texts = [
self._create_prefix_prompt(prompt_template_id=prompt_template_id) + tested_prompt
for tested_prompt in general_dict["tested_prompts"]
]
tot_texts = [text.strip() for text in tot_texts]
# Tokenize the prompts and build the attention masks
tokens = self.tokenizer(
tot_texts,
return_tensors="pt",
truncation=True,
max_length=self.tokenizer_max_seq_len,
padding=True,
add_special_tokens=False,
)
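# padding=True pads every prompt to the longest one in this flat batch; truncation caps prompts at self.tokenizer_max_seq_len.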
general_dict["input_ids"] = [tokens.input_ids[idx] for idx in range(len(tot_texts))]
general_dict["attention_mask"] = [tokens.attention_mask[idx] for idx in range(len(tot_texts))]
# If caption_column_names = ["caption_0", "caption_1"] and image_column_names = ["image_0", "image_1"], we get the image sequence [image_0, image_1, image_0, image_1] (the images cycle within each caption).
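# The flat (caption_id, image_id) order is therefore the same caption-major order as for the text prompts, so prompt i and pixel_values entry i refer to the same (caption, image) pair.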
pixel_values_dict = {"pixel_values": [], "caption_ids": [], "image_ids": [], "ex_ids": []}
for idx_ex in range(nb_exs):
for caption_idx in range(nb_captions):
for image_idx, col in enumerate(self.image_column_names):
pixel_values_dict["pixel_values"].append(self.image_transform(exs[col][idx_ex]).unsqueeze(0))
pixel_values_dict["caption_ids"].append(caption_idx)
pixel_values_dict["image_ids"].append(image_idx)
pixel_values_dict["ex_ids"].append(exs["id"][idx_ex])
# ---- Sanity check ----
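# Verify the text side and the image side agree: same example ids, and the per-example (caption_id, image_id) pairs match self.captions_images_order_per_ex, which, given the loop nesting above, is expected to be caption-major, e.g. [(0, 0), (0, 1), (1, 0), (1, 1)].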
assert pixel_values_dict["ex_ids"] == general_dict["ex_ids"]
nb_combinations = nb_captions * nb_images
sample_pixel_captions_ids = pixel_values_dict["caption_ids"][:nb_combinations]
sample_pixel_image_ids = pixel_values_dict["image_ids"][:nb_combinations]
sample_general_captions_ids = general_dict["caption_ids"][:nb_combinations]
sample_general_image_ids = general_dict["image_ids"][:nb_combinations]
for idx in range(nb_combinations):
expected_caption_idx, expected_image_idx = self.captions_images_order_per_ex[idx]
assert sample_pixel_captions_ids[idx] == expected_caption_idx
assert sample_general_captions_ids[idx] == expected_caption_idx
assert sample_pixel_image_ids[idx] == expected_image_idx
assert sample_general_image_ids[idx] == expected_image_idx
# ---- End sanity check ----
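# Regroup the flat per-combination lists into nb_exs sub-lists, one per example (assuming self._split_array(lst, n) chunks lst into n equal consecutive sub-lists).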
general_dict["ex_ids"] = self._split_array(general_dict["ex_ids"], nb_exs)
general_dict["caption_ids"] = self._split_array(general_dict["caption_ids"], nb_exs)
general_dict["image_ids"] = self._split_array(general_dict["image_ids"], nb_exs)
general_dict["input_ids"] = self._split_array(general_dict["input_ids"], nb_exs)
pixel_values_dict["pixel_values"] = self._split_array(pixel_values_dict["pixel_values"], nb_exs)
general_dict["attention_mask"] = self._split_array(general_dict["attention_mask"], nb_exs)
return {
"example_ids": general_dict["ex_ids"],
"caption_ids": general_dict["caption_ids"],
"image_ids": general_dict["image_ids"],
"input_ids": general_dict["input_ids"],
"attention_mask": general_dict["attention_mask"],
"pixel_values": pixel_values_dict["pixel_values"],
}