# vision/m4/models/vgpt2/evaluation_open_ended_vqa_in_context_vgpt2.py
def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
    """
    Prepare a batch of examples: build the in-context prompts, tokenize them, and
    transform/pad the corresponding images.
    """
    support_dataset: Dataset = kwargs["support_dataset"]
    support_dataset_vision_encoder_embeddings: Optional[np.ndarray] = kwargs.get(
        "support_dataset_vision_encoder_embeddings", None
    )
    num_shots: int = kwargs["num_shots"]
    shot_selection_mode: ShotSelectionMode = kwargs["shot_selection_mode"]
    prompt_template_id: int = kwargs["prompt_template_id"]

    nb_exs = len(exs["id"])
    multiple_images_dataset = isinstance(support_dataset[0][self.image_column_name], list)
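    # Helper for similarity-based shot selection: given one query embedding, pick the
    # `num_examples` support examples whose vision-encoder embeddings have the largest
    # dot-product similarity with it.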
    def retrieve_idx_closest_examples(ref_embedding, embeddings_to_compare, num_examples):
        "Return the indices of the `num_examples` most similar embeddings, ordered by increasing similarity"
        sim = np.dot(embeddings_to_compare, ref_embedding)
        # We can achieve linear complexity because we don't need to sort all the similarities,
        # only find the `num_examples` largest ones
        idx_closest_ex = np.argpartition(sim, -num_examples)[-num_examples:]
        idx_closest_ex = idx_closest_ex[np.argsort(sim[idx_closest_ex])].tolist()
        return idx_closest_ex
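    # Illustrative walk-through of the helper above (comment only, made-up values): with
    #   sim = [0.1, 0.9, 0.4, 0.7] and num_examples = 2,
    # np.argpartition(sim, -2)[-2:] returns the indices of the two largest similarities,
    # {1, 3}, in arbitrary order; the final np.argsort re-orders them by similarity so the
    # most similar shot comes last, i.e. [3, 1].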
    if (shot_selection_mode == ShotSelectionMode.random) or (num_shots == 0):
        idx_shots = [random.sample(range(len(support_dataset)), num_shots) for _ in range(nb_exs)]
    elif shot_selection_mode == ShotSelectionMode.first_without_image:
        idx_shots = [list(range(num_shots)) for _ in range(nb_exs)]
    else:
        idx_shots = [
            retrieve_idx_closest_examples(ref_embedding, support_dataset_vision_encoder_embeddings, num_shots)
            for ref_embedding in exs["vision_encoder_embeddings"]
        ]
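    # Each shot built below uses the most frequent human answer
    # (Counter(...).most_common(1)[0][0]) as its target and passes the tokenizer's EOS token,
    # presumably to delimit consecutive shots within the prompt.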
    # Prepare text shots
    # These are the priming text shots - size: batch_size
    texts_shots = [
        "".join(
            [
                self._create_example_prompt(
                    prompt_template_id=prompt_template_id,
                    question=support_dataset[idx_shot][self.question_column_name],
                    answer=Counter(support_dataset[idx_shot][self.answers_column_name]).most_common(1)[0][0],
                    image=support_dataset[idx_shot][self.image_column_name],
                    eos_token=self.tokenizer.eos_token,
                    without_image=shot_selection_mode == ShotSelectionMode.first_without_image,
                    multiple_images_dataset=multiple_images_dataset,
                    contexts=(
                        [
                            (context_column_name, support_dataset[context_column_name][idx_shot])
                            for context_column_name in self.context_column_names
                        ]
                        if self.context_column_names
                        else None
                    ),
                )
                for idx_shot in idx_shots_ex
            ]
        )
        for idx_shots_ex in idx_shots
    ]
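    # The query prompts below are rendered with no answer and with eos_token="", so each
    # prompt ends right after the question (and optional contexts) and the model is left to
    # generate the answer.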
    # These are the tested examples - size: batch_size
    tested_exs = [
        self._create_example_prompt(
            prompt_template_id=prompt_template_id,
            question=question,
            image=exs[self.image_column_name][idx_ex],
            eos_token="",
            multiple_images_dataset=multiple_images_dataset,
            contexts=(
                [
                    (context_column_name, exs[context_column_name][idx_ex])
                    for context_column_name in self.context_column_names
                ]
                if self.context_column_names
                else None
            ),
        ).strip()
        for idx_ex, question in enumerate(exs[self.question_column_name])
    ]
    if self.bool_instruct_templates:
        tested_exs = [ex[: -len("<end_of_utterance>\n")].strip() for ex in tested_exs]
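    # For instruct-style templates, the trailing "<end_of_utterance>\n" marker is stripped above,
    # presumably so the model completes the current utterance with its answer rather than
    # starting a new one.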
    # These are the concatenations of the priming text shots and the tested example - size: batch_size
    tot_texts = [
        self._create_prefix_prompt(prompt_template_id=prompt_template_id) + text_shot + tested_ex
        for text_shot, tested_ex in zip(texts_shots, tested_exs)
    ]
    tot_texts = [text.strip() for text in tot_texts]

    # Tokenize and build the attention masks
    tokens = self.tokenizer(
        tot_texts,
        return_tensors="pt",
        truncation=True,
        max_length=self.tokenizer_max_seq_len,
        padding=True,
        add_special_tokens=False,
    )
    input_ids = [tokens.input_ids[idx] for idx in range(len(tot_texts))]
    attention_mask = [tokens.attention_mask[idx] for idx in range(len(tot_texts))]
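    # input_ids / attention_mask are split back into per-example rows above; re-batching them
    # together with the images is presumably left to the downstream collator.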
    # Prepare image shots
    # These are the priming image shots - size: batch_size
    if shot_selection_mode == ShotSelectionMode.first_without_image:
        pixel_values_shots = [[] for _ in range(nb_exs)]
    elif multiple_images_dataset:
        pixel_values_shots = [
            [
                self.image_transform(sub_image)
                for idx_shot in idx_shots_ex
                for img in support_dataset[idx_shot][self.image_column_name]
                for sub_image in self.simpler_get_splitted_images_and_corresponding_text(image=img)[0]
            ]
            for idx_shots_ex in idx_shots
        ]
    else:
        pixel_values_shots = [
            [
                self.image_transform(sub_image)
                for idx_shot in idx_shots_ex
                for sub_image in self.simpler_get_splitted_images_and_corresponding_text(
                    image=support_dataset[idx_shot][self.image_column_name],
                )[0]
            ]
            for idx_shots_ex in idx_shots
        ]
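    # In both branches above, each support image is first split into sub-images by
    # `simpler_get_splitted_images_and_corresponding_text` (only the returned images, index [0],
    # are used here), and each sub-image is then run through `image_transform`.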
    # These are the tested images - size: batch_size
    if multiple_images_dataset:
        tested_pixel_values = [
            [
                self.image_transform(sub_image)
                for image in images
                for sub_image in self.simpler_get_splitted_images_and_corresponding_text(image=image)[0]
            ]
            for images in exs[self.image_column_name]
        ]
    else:
        tested_pixel_values = [
            [
                self.image_transform(sub_image)
                for sub_image in self.simpler_get_splitted_images_and_corresponding_text(image=image)[0]
            ]
            for image in exs[self.image_column_name]
        ]
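    # Sketch of the padding performed below (illustrative shapes): two transformed images of
    # sizes (3, 224, 224) and (3, 224, 336) are packed into a single tensor of shape
    # (2, 3, 224, 336), with the boolean mask set to True only over each image's original
    # height x width region.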
    # These are the concatenations of the priming image shots and the tested images - size: batch_size
    pixel_values = []
    pixel_attention_masks = []
    for pv_shots, pv in zip(pixel_values_shots, tested_pixel_values):
        num_images = len(pv_shots) + len(pv)
        max_height = max([im.size(1) for im in pv_shots] + [im.size(1) for im in pv])
        max_width = max([im.size(2) for im in pv_shots] + [im.size(2) for im in pv])

        padded_image_tensor = torch.zeros(num_images, 3, max_height, max_width)
        padded_pixel_attention_masks = torch.zeros(num_images, max_height, max_width, dtype=torch.bool)
        for idx, im in enumerate(pv_shots + pv):
            im_height, im_width = im.size(1), im.size(2)
            padded_image_tensor[idx, :, :im_height, :im_width] = im
            padded_pixel_attention_masks[idx, :im_height, :im_width] = True

        pixel_values.append(padded_image_tensor)
        pixel_attention_masks.append(padded_pixel_attention_masks)
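    # The remaining fields collect example ids, reference answers and, when `self.buckets_keys`
    # is set, bucket labels of the form "<key1>=<value1>/<key2>=<value2>" that downstream
    # evaluation can presumably use to report per-bucket metrics.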
    example_ids: List[int] = exs["id"]
    answers = exs[self.answers_column_name]

    if self.buckets_keys:

        def bucket_infos_to_str(bucket_infos):
            name = []
            for info, info_type in zip(bucket_infos, self.buckets_keys):
                name.append(f"{info_type}={info}")
            return "/".join(name)

        columns_to_concatenate = [exs[key] for key in self.buckets_keys]
        buckets = [bucket_infos_to_str(bucket_infos) for bucket_infos in zip(*columns_to_concatenate)]
    else:
        buckets = [""] * len(example_ids)
    return {
        "example_ids": example_ids,
        "answers": answers,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "pixel_attention_masks": pixel_attention_masks,
        "buckets": buckets,
    }