vision/m4/models/vgpt2/evaluation_image_caption_matching_vgpt2.py (237 lines of code) (raw):

import os
from itertools import chain
from typing import Dict, List, Optional

import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer

from m4.evaluation.custom_metrics.image_caption_matching_metrics import MetricsImageCaptionMatching
from m4.evaluation.tasks import BaseTaskImageCaptionMatching, Predictor
from m4.evaluation.utils import EvaluationVersion
from m4.training.utils import (
    FAKE_TOKEN_AROUND_IMAGE_V1,
    FAKE_TOKEN_AROUND_IMAGE_V2,
    IMAGE_TOKEN,
    build_image_transform,
)


class Vgpt2ImageCaptionMatching(BaseTaskImageCaptionMatching):
    """Image/caption matching evaluation task for VGPT2 models.

    For every example, each (caption, image) combination is turned into one
    prompt + image pair. The model scores each pair with the log-likelihood of
    the prompt tokens, and the per-example scores are handed to the metric.

    Concrete subclasses are expected to provide `caption_column_names`,
    `image_column_names`, `length_normalize` and `prompt_templates_dict`
    (see `WinogroundVgpt2ImageCaptionMatchingAccWithKLAndEntropy`).
    """

    model_class: str = "VGPT2LMHeadModel"
    predictor_class: Predictor = Predictor.in_contexter
    target_keys: List[str] = ["example_ids", "caption_ids", "image_ids"]
    # Buckets are optionally populated for classification in context. They are only useful when results are
    # needed per bucket. A bucket is typically a certain slice of the dataset (for instance, all instances
    # where age=30).
    buckets_keys: List[str] = []
    mapping_class_names_to_prompt_names: Optional[Dict[str, str]] = None
    prompt_templates_dict: Dict[int, Dict[str, str]] = {}
    mapping_class_prompt_name_id_to_prompt_template_id: Optional[Dict[int, int]] = None
    tokenizer_max_seq_len = 1024

    def __init__(self, **kwargs) -> None:
        """Build the tokenizer, the image transform and the (caption, image) ordering.

        Required kwargs: `tokenizer_name`, `evaluation_version`, `image_size`,
        `vision_encoder_type`, `image_seq_len`. Optional: `tokenizer_use_fast`
        (defaults to False).

        Raises:
            KeyError: if a required kwarg is missing.
            ValueError: if `evaluation_version` is neither v1 nor v2.
        """
        # The base class receives the full kwargs; the keys are only popped afterwards.
        super().__init__(**kwargs)
        self.tokenizer_name = kwargs.pop("tokenizer_name")
        evaluation_version = kwargs.pop("evaluation_version")
        tokenizer_use_fast = kwargs.pop("tokenizer_use_fast", False)
        image_size = kwargs.pop("image_size")
        vision_encoder_type = kwargs.pop("vision_encoder_type")
        self.image_seq_len = kwargs.pop("image_seq_len")

        self.image_transform = build_image_transform(
            max_image_size=image_size, image_size=None, eval=True, vision_encoder_type=vision_encoder_type
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            truncation_side="left",
            use_fast=tokenizer_use_fast,
            token=os.getenv("HF_TOKEN", True),
        )
        self.tokenizer.padding_side = "left"

        self.image_token = IMAGE_TOKEN
        if evaluation_version == EvaluationVersion.v1:
            self.token_around_image = FAKE_TOKEN_AROUND_IMAGE_V1
        elif evaluation_version == EvaluationVersion.v2:
            self.token_around_image = FAKE_TOKEN_AROUND_IMAGE_V2
        else:
            raise ValueError(f"Invalid evaluation version: {evaluation_version}")

        # Canonical order of the (caption_idx, image_idx) combinations inside one example:
        # captions vary slowest, images vary fastest.
        nb_captions = len(self.caption_column_names)
        nb_images = len(self.image_column_names)
        self.captions_images_order_per_ex = [
            (caption_idx, image_idx) for caption_idx in range(nb_captions) for image_idx in range(nb_images)
        ]

    def get_info_from_dataset(self, dataset):
        """No dataset-level information is needed for this task."""
        pass

    def get_data_collator(self, **kwargs):
        """Return a collator that turns a list of examples into a prepared batch.

        `kwargs` are forwarded to `prepare_dataset` (which requires `prompt_template_id`).
        """

        def data_collator(batch):
            # Transpose the list of example dicts into a dict of columns.
            exs = {key: [ex[key] for ex in batch] for key in batch[0].keys()}
            batch = self.prepare_dataset(exs, **kwargs)
            return batch

        return data_collator

    def _split_array(self, array, nb_combinations):
        """Split `array` into consecutive chunks of `len(array) // nb_combinations` elements.

        Despite its name, `nb_combinations` is the number of chunks to produce; the callers in
        this class pass the number of examples so that each chunk holds one example's combinations.
        If `len(array)` is not a multiple of `nb_combinations`, a shorter trailing chunk is appended.

        Returns:
            A list of slices of `array`; an empty list if `array` is empty.

        Raises:
            ZeroDivisionError: if `nb_combinations` is 0.
            ValueError: if `array` has fewer elements than `nb_combinations`.
        """
        total_elements = len(array)
        elements_per_chunk = total_elements // nb_combinations
        if total_elements == 0:
            return []
        if elements_per_chunk == 0:
            # Without this guard, `range` would fail with an opaque "arg 3 must not be zero" error.
            raise ValueError(f"Cannot split {total_elements} element(s) into {nb_combinations} chunks")
        return [array[start : start + elements_per_chunk] for start in range(0, total_elements, elements_per_chunk)]

    def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
        """
        Prepare batch of examples.

        Every example is expanded into `nb_captions * nb_images` (prompt, image) pairs, ordered as in
        `self.captions_images_order_per_ex`. Each value of the returned dict is a list with one entry
        per example, each entry being the list of that example's combinations.

        Args:
            exs: dict of columns; must contain "id" plus the caption and image columns.
            **kwargs: must contain `prompt_template_id`.
        """
        prompt_template_id: int = kwargs["prompt_template_id"]
        nb_exs = len(exs["id"])
        nb_captions = len(self.caption_column_names)
        nb_images = len(self.image_column_names)

        # If we have caption_column_names = ["caption_0", "caption_1"] and image_column_names= ["image_0", "image_1"].
        # We get the sequence [caption_0, caption_0, caption_1, caption_1]
        general_dict = {"tested_prompts": [], "caption_ids": [], "image_ids": [], "ex_ids": []}
        for idx_ex in range(nb_exs):
            for caption_idx, caption_column in enumerate(self.caption_column_names):
                # The prompt only depends on the caption, so build it once and reuse it for every image.
                tested_prompt = self._create_example_prompt(
                    prompt_template_id=prompt_template_id,
                    caption=exs[caption_column][idx_ex],
                )
                for image_idx in range(nb_images):
                    general_dict["tested_prompts"].append(tested_prompt)
                    general_dict["caption_ids"].append(caption_idx)
                    general_dict["image_ids"].append(image_idx)
                    general_dict["ex_ids"].append(exs["id"][idx_ex])

        # The prefix is identical for all texts of the batch.
        prefix_prompt = self._create_prefix_prompt(prompt_template_id=prompt_template_id)
        tot_texts = [(prefix_prompt + tested_prompt).strip() for tested_prompt in general_dict["tested_prompts"]]

        # Tokenize and masks
        tokens = self.tokenizer(
            tot_texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.tokenizer_max_seq_len,
            padding=True,
            add_special_tokens=False,
        )
        general_dict["input_ids"] = [tokens.input_ids[idx] for idx in range(len(tot_texts))]
        general_dict["attention_mask"] = [tokens.attention_mask[idx] for idx in range(len(tot_texts))]

        # If we have caption_column_names = ["caption_0", "caption_1"] and image_column_names= ["image_0", "image_1"].
        # We get the sequence image_0, image_1, image_0, image_1
        pixel_values_dict = {"pixel_values": [], "caption_ids": [], "image_ids": [], "ex_ids": []}
        for idx_ex in range(nb_exs):
            for caption_idx in range(nb_captions):
                for image_idx, col in enumerate(self.image_column_names):
                    # unsqueeze(0) adds a leading "number of images" dimension (one image per sample).
                    pixel_values_dict["pixel_values"].append(self.image_transform(exs[col][idx_ex]).unsqueeze(0))
                    pixel_values_dict["caption_ids"].append(caption_idx)
                    pixel_values_dict["image_ids"].append(image_idx)
                    pixel_values_dict["ex_ids"].append(exs["id"][idx_ex])

        # ---- Sanity check ----
        # Texts and images are built by two independent loops: make sure both follow the canonical
        # (caption, image) order, checked on the first example of the batch.
        assert pixel_values_dict["ex_ids"] == general_dict["ex_ids"]
        nb_combinations = nb_captions * nb_images
        sample_pixel_captions_ids = pixel_values_dict["caption_ids"][:nb_combinations]
        sample_pixel_image_ids = pixel_values_dict["image_ids"][:nb_combinations]
        sample_general_captions_ids = general_dict["caption_ids"][:nb_combinations]
        sample_general_image_ids = general_dict["image_ids"][:nb_combinations]
        for idx in range(nb_combinations):
            expected_caption_idx, expected_image_idx = self.captions_images_order_per_ex[idx]
            assert sample_pixel_captions_ids[idx] == expected_caption_idx
            assert sample_general_captions_ids[idx] == expected_caption_idx
            assert sample_pixel_image_ids[idx] == expected_image_idx
            assert sample_general_image_ids[idx] == expected_image_idx
        # ---- Sanity check ----

        # Regroup the flat lists so that each entry holds all the combinations of one example.
        general_dict["ex_ids"] = self._split_array(general_dict["ex_ids"], nb_exs)
        general_dict["caption_ids"] = self._split_array(general_dict["caption_ids"], nb_exs)
        general_dict["image_ids"] = self._split_array(general_dict["image_ids"], nb_exs)
        general_dict["input_ids"] = self._split_array(general_dict["input_ids"], nb_exs)
        pixel_values_dict["pixel_values"] = self._split_array(pixel_values_dict["pixel_values"], nb_exs)
        general_dict["attention_mask"] = self._split_array(general_dict["attention_mask"], nb_exs)

        return {
            "example_ids": general_dict["ex_ids"],
            "caption_ids": general_dict["caption_ids"],
            "image_ids": general_dict["image_ids"],
            "input_ids": general_dict["input_ids"],
            "attention_mask": general_dict["attention_mask"],
            "pixel_values": pixel_values_dict["pixel_values"],
        }

    def _create_example_prompt(self, prompt_template_id, caption):
        """Format the "example" template of `prompt_template_id` with the image tokens and `caption`.

        NOTE(review): `bool_instruct_templates` and `prompt_templates_dict_instruct` are not defined in
        this file; presumably they are provided by the base class - confirm.
        """
        if self.bool_instruct_templates:
            prompt_templates_dict = self.prompt_templates_dict_instruct
        else:
            prompt_templates_dict = self.prompt_templates_dict
        prompt_template = prompt_templates_dict[prompt_template_id]["example"]

        template_kwargs = {
            "bos_token": self.tokenizer.bos_token,
            "eos_token": self.tokenizer.eos_token,
            # The image placeholder token is repeated once per position of the image sequence.
            "image_token": self.image_token * self.image_seq_len,
            "token_around_image": self.token_around_image,
            "caption": caption,
        }
        prompt = prompt_template.format(**template_kwargs)
        return prompt

    def _create_prefix_prompt(self, prompt_template_id):
        """Format the "prefix" template of `prompt_template_id`; return "" when the template has no prefix."""
        if self.bool_instruct_templates:
            prompt_templates_dict = self.prompt_templates_dict_instruct
        else:
            prompt_templates_dict = self.prompt_templates_dict
        prompt_template = prompt_templates_dict[prompt_template_id]["prefix"]

        if prompt_template is None:
            return ""
        return prompt_template.format(
            bos_token=self.tokenizer.bos_token,
            eos_token=self.tokenizer.eos_token,
        )

    def predict(self, **kwargs):
        """Run the model on a prepared batch and return its outputs.

        Expects `model`, `input_ids`, `attention_mask` and `pixel_values` in `kwargs`, the last three
        being nested per example as produced by `prepare_dataset`. The returned outputs object
        additionally carries `input_ids` and `attention_mask` so that scores can be computed later.
        """
        model = kwargs["model"]
        # Flatten the per-example nesting: one row per (caption, image) combination.
        input_ids = torch.stack(list(chain.from_iterable(kwargs["input_ids"]))).to(model.device)
        attention_mask = torch.stack(list(chain.from_iterable(kwargs["attention_mask"]))).to(model.device)

        pv = list(chain.from_iterable(kwargs["pixel_values"]))
        total_batch_size = len(pv)
        max_num_images = max(i.size(0) for i in pv)
        max_height = max(i.size(2) for i in pv)
        max_width = max(i.size(3) for i in pv)

        # Images may have different sizes: zero-pad them to a common size and record in a boolean
        # mask which pixels are real.
        pixel_values = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
        pixel_attention_mask = torch.zeros(total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool)
        for idx, sample_images in enumerate(pv):
            im_batch_height, im_batch_width = sample_images.size()[2:]
            pixel_values[idx, :, :, :im_batch_height, :im_batch_width] = sample_images
            pixel_attention_mask[idx, :, :im_batch_height, :im_batch_width] = True
        pixel_values = pixel_values.to(model.device)
        pixel_attention_mask = pixel_attention_mask.to(model.device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
            )
        outputs.input_ids = input_ids
        outputs.attention_mask = attention_mask
        return outputs

    def format_model_outputs_to_predictions(self, outputs) -> List[List[float]]:
        """Turn model outputs into one log-likelihood score per (caption, image) combination.

        The score of a sequence is the sum of the log-probabilities of its tokens over the attended
        positions, divided by the number of attended positions when `self.length_normalize` is set.

        Returns:
            A list with one entry per example, each entry being the list of that example's scores.
        """
        batch_size = outputs.logits.shape[0]
        # Shift so that tokens < n predict n
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = outputs.input_ids[..., 1:].contiguous()
        shift_attention_mask = outputs.attention_mask[:, 1:].contiguous()

        # Flatten the tokens
        # The unreduced cross-entropy is the negative log-probability of each label token.
        loss_fct = CrossEntropyLoss(reduction="none")
        log_probs = -loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        log_probs = log_probs.view(batch_size, -1)

        # Padding positions do not contribute to the score.
        masked_log_probs = shift_attention_mask * log_probs
        score_per_example = masked_log_probs.sum(dim=-1)
        if self.length_normalize:
            score_per_example = score_per_example / shift_attention_mask.sum(dim=-1)

        nb_combinations = len(self.image_column_names) * len(self.caption_column_names)
        nb_exs = batch_size // nb_combinations
        splitted_scores_per_example = self._split_array(score_per_example.tolist(), nb_exs)
        return splitted_scores_per_example

    def add_batch_metric(self, metric, **kwargs):
        """Score one batch and add the scores, together with the `target_keys` columns, to `metric`."""
        outputs = self.predict(**kwargs)
        splitted_scores_per_example = self.format_model_outputs_to_predictions(outputs)
        additional_args = {key: kwargs[key] for key in self.target_keys}
        metric.add_batch(
            splitted_scores_per_example=splitted_scores_per_example,
            **additional_args,
        )
        return metric


class WinogroundVgpt2ImageCaptionMatchingAccWithKLAndEntropy(Vgpt2ImageCaptionMatching):
    """Image/caption matching on `facebook/winoground`, reporting text, image and group scores."""

    dataset_name: str = "facebook/winoground"
    metric_name: str = "ImageCaptionMatchingMetrics"
    metric_kwargs = {
        "metrics": [
            MetricsImageCaptionMatching.TEXT_SCORE,
            MetricsImageCaptionMatching.IMAGE_SCORE,
            MetricsImageCaptionMatching.GROUP_SCORE,
        ]
    }
    # support split names are never used for this dataset
    default_query_split_name: str = "test"
    default_support_split_name: str = "test"
    test_support_split_name: str = "test"
    image_column_names: List[str] = ["image_0", "image_1"]
    id_column_name: str = "id"
    caption_column_names: List[str] = ["caption_0", "caption_1"]
    length_normalize: bool = True

    prompt_templates_dict = {
        0: {
            "prefix": None,
            "example": "{token_around_image}{image_token}{token_around_image}{caption}",
        },
    }