# vision/m4/models/vgpt2/evaluation_captioning_in_context_vgpt2.py
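"""In-context (few-shot) image captioning evaluation tasks for VGPT2.

`Vgpt2ImageCaptioningInContext` builds few-shot prompts from a support dataset, tokenizes
them together with the query example, pads and batches the corresponding images, generates
captions with beam search, and feeds them to an image-captioning metric. The subclasses at
the end of the file (TextCaps, CommonGen, NoCaps, COCO, IIIT-5K) only override dataset,
metric, and prompt configuration.
"""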

import os
import random
import re
from typing import Dict, List, Optional

import numpy as np
import torch
from accelerate.utils import extract_model_from_parallel
from datasets import Dataset
from deepspeed.runtime.engine import DeepSpeedEngine
from transformers import AutoTokenizer

from m4.evaluation.config import ShotSelectionMode
from m4.evaluation.custom_metrics.unfolded_image_captioning_metrics import ImageCaptioningMetrics
from m4.evaluation.tasks import BaseTaskImageCaptioning, Predictor
from m4.evaluation.utils import EvaluationVersion
from m4.training.packing import get_splitted_images_and_corresponding_text
from m4.training.utils import (
    FAKE_TOKEN_AROUND_IMAGE_V1,
    FAKE_TOKEN_AROUND_IMAGE_V2,
    IMAGE_TOKEN,
    build_image_transform,
)


class Vgpt2ImageCaptioningInContext(BaseTaskImageCaptioning):
    model_class: str = "VGPT2LMHeadModel"
    predictor_class: Predictor = Predictor.in_contexter
    target_keys: List[str] = ["reference_captions", "example_ids"]
    stop_words = [
        "Caption",
        "Description",
        "User",
        "Image",
        "task",
        "<end_of_utterance>",
        "<row_",
        "apiro",
        "\u2500lrow_",
        "row_1",
    ]
    tokenizer_max_seq_len = 1024

    prompt_templates_dict = {
        0: {
            "prefix": None,
            "example": "<image>Output: {caption}\n",
        },
        1: {
            "prefix": "{bos_token}",
            "example": "<image>Output: {caption}\n",
        },
        2: {
            "prefix": (
                "{bos_token}This is a conversation between a human, User and an intelligent visual AI, Bot. The"
                " user sends images, and Bot describes the images sent by the user.\n"
            ),
            "example": "User:<image>\nBot: {caption}\n",
        },
        3: {
            "prefix": (
                "{bos_token}This is a conversation between a human, User, and an intelligent visual AI, Bot. The"
                " user sends images, and Bot describes them. The bot"
                " should reply as accurately as possible.\n"
            ),
            "example": "User:<image>\nBot: {caption}\n",
        },
        4: {
            "prefix": "{bos_token}",
            "example": "Image to describe:<image>Description: {caption}\n",
        },
        5: {
            "prefix": None,
            "example": "{bos_token}<image>{caption}{eos_token}",
        },
        6: {
            "prefix": "{bos_token}Instruction: provide a short caption of the input image.\n",
            "example": "Image:<image>Image description: {caption}\n",
        },
        7: {
            "prefix": "{bos_token}",
            "example": "Image:<image>Caption: {caption}\n",
        },
        8: {
            "prefix": "{bos_token}Instruction: caption the image in details.\n",
            "example": "Image to caption:<image>Image caption: {caption}\n",
        },
    }
    prompt_templates_dict_instruct = {
        7: {
            "prefix": "{bos_token}",
            "example": (
                "User:<image>Describe the image briefly.<end_of_utterance>\nAssistant: {caption}<end_of_utterance>\n"
            ),
        },
    }
    bool_instruct_templates = False

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.tokenizer_name = kwargs.pop("tokenizer_name")
        evaluation_version = kwargs.pop("evaluation_version")
        tokenizer_use_fast = kwargs.pop("tokenizer_use_fast", False)
        self.vision_encoder_max_image_size = kwargs.pop("vision_encoder_max_image_size")
        vision_encoder_type = kwargs.pop("vision_encoder_type")
        self.image_seq_len = kwargs.pop("image_seq_len")
        self.image_transform = build_image_transform(
            max_image_size=self.vision_encoder_max_image_size,
            image_size=None,
            eval=True,
            vision_encoder_type=vision_encoder_type,
        )
        self.scale_up_images = kwargs.pop("scale_up_images")
        self.image_size_after_scaling = kwargs.pop("image_size_after_scaling")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name,
            truncation_side="left",
            use_fast=tokenizer_use_fast,
            token=os.getenv("HF_TOKEN", True),
        )
        self.tokenizer.padding_side = "left"
        self.image_token = IMAGE_TOKEN
        if evaluation_version == EvaluationVersion.v1:
            self.token_around_image = FAKE_TOKEN_AROUND_IMAGE_V1
        elif evaluation_version == EvaluationVersion.v2:
            self.token_around_image = FAKE_TOKEN_AROUND_IMAGE_V2
        else:
            raise ValueError(f"Invalid evaluation version: {evaluation_version}")

    def simpler_get_splitted_images_and_corresponding_text(self, image):
        splitted_images_array, text_splitted_images = get_splitted_images_and_corresponding_text(
            image=image,
            vision_encoder_max_image_size=self.vision_encoder_max_image_size,
            max_image_size=self.image_size_after_scaling,
            pre_split_scale_up_max=None,
            pre_split_scale_up_frequency=None,
            image_seq_len=self.image_seq_len,
            # Any value sufficiently high such that the image will always be resized to max_image_size
            scale_up_factor=100 if self.scale_up_images else 1,
        )
        return splitted_images_array, text_splitted_images

    def get_info_from_dataset(self, dataset):
        pass

    def get_data_collator(self, **kwargs):
        def data_collator(batch):
            exs = {key: [ex[key] for ex in batch] for key in batch[0].keys()}
            batch = self.prepare_dataset(exs, **kwargs)
            return batch

        return data_collator

    def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
        """
        Prepare a batch of examples.
        """
        support_dataset: Dataset = kwargs["support_dataset"]
        support_dataset_vision_encoder_embeddings: Optional[np.ndarray] = kwargs.get(
            "support_dataset_vision_encoder_embeddings", None
        )
        num_shots: int = kwargs["num_shots"]
        shot_selection_mode: ShotSelectionMode = kwargs["shot_selection_mode"]
        prompt_template_id: int = kwargs["prompt_template_id"]

        nb_exs = len(exs["id"])

        def retrieve_idx_closest_examples(ref_embedding, embeddings_to_compare, num_examples):
            "Returns the indices of the `num_examples` closest embeddings in ascending order"
            sim = np.dot(embeddings_to_compare, ref_embedding)
            # We can achieve linear complexity because we don't need to sort all the numbers,
            # but only find the `num_examples` largest ones
            idx_closest_ex = np.argpartition(sim, -num_examples)[-num_examples:]
            idx_closest_ex = idx_closest_ex[np.argsort(sim[idx_closest_ex])].tolist()
            return idx_closest_ex

        if (shot_selection_mode == ShotSelectionMode.random) or (num_shots == 0):
            idx_shots = [random.sample(range(len(support_dataset)), num_shots) for _ in range(nb_exs)]
        elif shot_selection_mode == ShotSelectionMode.first_without_image:
            idx_shots = [list(range(num_shots)) for _ in range(nb_exs)]
        else:
            idx_shots = [
                retrieve_idx_closest_examples(ref_embedding, support_dataset_vision_encoder_embeddings, num_shots)
                for ref_embedding in exs["vision_encoder_embeddings"]
            ]

        # Prepare text shots
        # These are the priming text shots - size: batch_size
        texts_shots = [
            "".join(
                [
                    self._create_example_prompt(
                        prompt_template_id=prompt_template_id,
                        caption=random.choice(support_dataset[idx_shot][self.reference_captions_column_name]),
                        image=support_dataset[idx_shot][self.image_column_name],
                        context=(
                            support_dataset[idx_shot][self.context_column_name]
                            if self.context_column_name
                            else None
                        ),
                        without_image=shot_selection_mode == ShotSelectionMode.first_without_image,
                        eos_token=self.tokenizer.eos_token,
                    )
                    for idx_shot in idx_shots_ex
                ]
            )
            for idx_shots_ex in idx_shots
        ]

        # These are the tested examples - size: batch_size
        tested_exs = [
            self._create_example_prompt(
                prompt_template_id=prompt_template_id,
                image=exs[self.image_column_name][idx],
                context=exs[self.context_column_name][idx] if self.context_column_name else None,
                eos_token="",
            )
            for idx in range(nb_exs)
        ]
        if self.bool_instruct_templates:
            tested_exs = [ex[: -len("<end_of_utterance>\n")].strip() for ex in tested_exs]

        # These are the concatenation of the priming text shots and the tested example - size: batch_size
        tot_texts = [
            self._create_prefix_prompt(prompt_template_id=prompt_template_id) + text_shot + tested_ex
            for text_shot, tested_ex in zip(texts_shots, tested_exs)
        ]
        tot_texts = [text.strip() for text in tot_texts]

        # Tokenize and build the attention masks
        tokens = self.tokenizer(
            tot_texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.tokenizer_max_seq_len,
            padding=True,
            add_special_tokens=False,
        )
        input_ids = [tokens.input_ids[idx] for idx in range(len(tot_texts))]
        attention_mask = [tokens.attention_mask[idx] for idx in range(len(tot_texts))]

        # Prepare image shots
        # These are the priming image shots - size: batch_size
        if shot_selection_mode == ShotSelectionMode.first_without_image:
            pixel_values_shots = [[] for _ in range(nb_exs)]
        else:
            pixel_values_shots = [
                [
                    self.image_transform(sub_image)
                    for idx_shot in idx_shots_ex
                    for sub_image in self.simpler_get_splitted_images_and_corresponding_text(
                        image=support_dataset[idx_shot][self.image_column_name],
                    )[0]
                ]
                for idx_shots_ex in idx_shots
            ]

        # These are the tested images - size: batch_size
        tested_pixel_values = [
            [
                self.image_transform(sub_image)
                for sub_image in self.simpler_get_splitted_images_and_corresponding_text(image=image)[0]
            ]
            for image in exs[self.image_column_name]
        ]

        # These are the concatenation of the priming image shots and tested images - size: batch_size
        pixel_values = []
        pixel_attention_masks = []
        for pv_shots, pv in zip(pixel_values_shots, tested_pixel_values):
            num_images = len(pv_shots) + len(pv)
            max_height = max([im.size(1) for im in pv_shots] + [im.size(1) for im in pv])
            max_width = max([im.size(2) for im in pv_shots] + [im.size(2) for im in pv])
            padded_image_tensor = torch.zeros(num_images, 3, max_height, max_width)
            padded_pixel_attention_masks = torch.zeros(num_images, max_height, max_width, dtype=torch.bool)
            for idx, im in enumerate(pv_shots + pv):
                im_height, im_width = im.size(1), im.size(2)
                padded_image_tensor[idx, :, :im_height, :im_width] = im
                padded_pixel_attention_masks[idx, :im_height, :im_width] = True
            pixel_values.append(padded_image_tensor)
            pixel_attention_masks.append(padded_pixel_attention_masks)

        example_ids: List[int] = exs["id"]

        reference_captions = exs[self.reference_captions_column_name]
        if isinstance(reference_captions[0], str):
            reference_captions = [[ref_cap] for ref_cap in reference_captions]

        return {
            "example_ids": example_ids,
            "reference_captions": reference_captions,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "pixel_attention_masks": pixel_attention_masks,
        }
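
    # Note on the image batching above: within each sample, the shot images and the query
    # image are padded to the largest height/width found in that sample, and the valid
    # (non-padded) region of every image is marked True in `pixel_attention_masks`.
    # For example (hypothetical sizes), a 224x224 shot image batched with a 224x448 query
    # image gives two 3x224x448 tensors, and the right half of the shot image's attention
    # mask stays False.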

    def _create_example_prompt(
        self, prompt_template_id, image, eos_token, caption="", context=None, without_image=False
    ):
        if self.bool_instruct_templates:
            prompt_templates_dict = self.prompt_templates_dict_instruct
        else:
            prompt_templates_dict = self.prompt_templates_dict
        prompt_template = prompt_templates_dict[prompt_template_id]["example"]
        prompt_kwargs = {}
        prompt = prompt_template.format(
            bos_token=self.tokenizer.bos_token,
            eos_token=eos_token,
            # The `eos_token` case differs from the `bos_token` one: when we include bos/eos in the shots, both are
            # always present (hence the use of `tokenizer.bos_token`), but for the query example we add a
            # `bos_token` and no `eos_token`, so that the model can continue the text.
            context=context,
            caption=caption,
            **prompt_kwargs,
        )
        prompt = prompt.replace("<image>", "<IMAGE>")
        _, text_splitted_images = self.simpler_get_splitted_images_and_corresponding_text(image=image)
        prompt = prompt.replace("<IMAGE>", text_splitted_images, 1)
        return prompt

    def _create_prefix_prompt(self, prompt_template_id):
        if self.bool_instruct_templates:
            prompt_templates_dict = self.prompt_templates_dict_instruct
        else:
            prompt_templates_dict = self.prompt_templates_dict
        prompt_template = prompt_templates_dict[prompt_template_id]["prefix"]
        if prompt_template is None:
            return ""
        else:
            prompt = prompt_template.format(
                bos_token=self.tokenizer.bos_token,
                eos_token=self.tokenizer.eos_token,
            )
            return prompt

    def generate_tokens(self, **kwargs):
        # Flamingo: Beam search with a beam size of 3
        model = kwargs["model"]
        input_ids = torch.stack(kwargs["input_ids"]).to(model.device)
        attention_mask = torch.stack(kwargs["attention_mask"]).to(model.device)

        total_batch_size = len(kwargs["pixel_values"])
        max_num_images = max([i.size(0) for i in kwargs["pixel_values"]])
        max_height = max([i.size(2) for i in kwargs["pixel_values"]])
        max_width = max([i.size(3) for i in kwargs["pixel_values"]])
        pixel_values = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
        pixel_attention_mask = torch.zeros(total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool)
        for idx, (sample_images, sample_pixel_attention_mask) in enumerate(
            zip(kwargs["pixel_values"], kwargs["pixel_attention_masks"])
        ):
            im_batch_height, im_batch_width = sample_images.size()[2:]
            pixel_values[idx, : sample_images.shape[0], :, :im_batch_height, :im_batch_width] = sample_images
            pixel_attention_mask[idx, : sample_pixel_attention_mask.shape[0], :im_batch_height, :im_batch_width] = (
                sample_pixel_attention_mask
            )
        pixel_values = pixel_values.to(model.device)
        pixel_attention_mask = pixel_attention_mask.to(model.device)

        num_beams = kwargs["num_beams"]
        no_repeat_ngram_size = kwargs["no_repeat_ngram_size"]
        max_new_tokens = kwargs["max_new_tokens"]

        bad_words = ["\n", "\n\n", self.image_token, self.token_around_image]
        bad_words_ids = self.tokenizer(bad_words, add_special_tokens=False)["input_ids"]

        unwrapped_model = extract_model_from_parallel(model)
        is_deepspeed_model = isinstance(model, DeepSpeedEngine)
        if is_deepspeed_model:
            if model.zero_optimization_partition_weights():
                # Enable automated discovery of external parameters by indicating that
                # we are in a forward pass.
                for module in model.module.modules():
                    module._parameters._in_forward = True

        with torch.no_grad():
            generated_tokens = unwrapped_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                num_beams=num_beams,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
                bad_words_ids=bad_words_ids,
                use_cache=True,
                early_stopping=True,
                synced_gpus=is_deepspeed_model,
            )
        generated_tokens = generated_tokens[:, input_ids.shape[1] :]
        return generated_tokens

    def format_tokens_to_texts(self, tokens) -> List[str]:
        texts = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
        stop_words_pattern = r"|".join(self.stop_words)
        texts = [re.split(stop_words_pattern, text)[0] for text in texts]
        return texts

    def add_batch_metric(self, metric, **kwargs):
        generated_tokens = self.generate_tokens(**kwargs)
        generated_captions = self.format_tokens_to_texts(generated_tokens)
        metric.add_batch(
            generated_captions=generated_captions,
            **{key: kwargs[key] for key in self.target_keys},
        )
        return metric
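

# The subclasses below specialize the generic in-context captioning task for individual
# benchmarks. They only override class-level configuration: the Hub dataset name (and
# config), the metrics to compute, the default query/support split names, the image,
# reference-caption, and context column names, and, for CommonGen and IIIT-5K,
# task-specific prompt templates and stop words.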
Output: {caption}", } } class CommonGenVgpt2ImageCaptioningInContextBleuCiderMeteorRouge(CommonGenVgpt2ImageCaptioningInContextTextGenMetrics): metric_kwargs = { "metrics": [ ImageCaptioningMetrics.BLEU_4, ImageCaptioningMetrics.CIDER, ImageCaptioningMetrics.METEOR, ImageCaptioningMetrics.ROUGE_L, ] } class NoCapsVgpt2ImageCaptioningInContextTextGenMetrics(Vgpt2ImageCaptioningInContext): dataset_name: str = "HuggingFaceM4/NoCaps" metric_name: str = "UnfoldedImageCaptioningMetrics" metric_kwargs = { "metrics": [ ImageCaptioningMetrics.BLEU_4, ImageCaptioningMetrics.CIDER, ImageCaptioningMetrics.METEOR, ImageCaptioningMetrics.ROUGE_L, # ImageCaptioningMetrics.SPICE, ] } default_query_split_name: str = "validation" default_support_split_name: str = "train" # This does not exist yet... it would require adding a training split to the dataset (see `create_sample_evaluation_datasets_simplified.py`) image_column_name: str = "image" reference_captions_column_name: str = "annotations_captions" class NoCapsSampleVgpt2ImageCaptioningInContextTextGenMetrics(NoCapsVgpt2ImageCaptioningInContextTextGenMetrics): dataset_name: str = "HuggingFaceM4/NoCaps-Sample" class CocoVgpt2ImageCaptioningInContextBleuCiderMeteorRouge(Vgpt2ImageCaptioningInContext): dataset_name: str = "HuggingFaceM4/COCO" dataset_config = "2014_captions" metric_name: str = "UnfoldedImageCaptioningMetrics" metric_kwargs = { "metrics": [ ImageCaptioningMetrics.BLEU_4, ImageCaptioningMetrics.CIDER, ImageCaptioningMetrics.METEOR, ImageCaptioningMetrics.ROUGE_L, ] } default_query_split_name: str = "validation" default_support_split_name: str = "train" image_column_name: str = "image" reference_captions_column_name: str = "sentences_raw" class CocoSampleVgpt2ImageCaptioningInContextBleuCiderMeteorRouge( CocoVgpt2ImageCaptioningInContextBleuCiderMeteorRouge ): dataset_name: str = "HuggingFaceM4/COCO-2014_captions-Sample" dataset_config = None class IIIT5KVgpt2ImageCaptioningInContextExactMatch(Vgpt2ImageCaptioningInContext): dataset_name: str = "HuggingFaceM4/IIIT-5K" metric_name: str = "UnfoldedImageCaptioningMetrics" metric_kwargs = {"metrics": [ImageCaptioningMetrics.EXACT_MATCH]} default_query_split_name: str = "test" default_support_split_name: str = "train" image_column_name: str = "image" reference_captions_column_name: str = "label" stop_words = ["A photo"] prompt_templates_dict = { 0: { "prefix": None, "example": "<image>A photo where it is written {caption}", } } class IIIT5KSampleVgpt2ImageCaptioningInContextExactMatch(IIIT5KVgpt2ImageCaptioningInContextExactMatch): dataset_name: str = "HuggingFaceM4/IIIT-5K-Sample"