# vision/m4/models/vgpt2/evaluation_perplexity_in_context_vgpt2.py
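"""In-context perplexity evaluation tasks for VGPT2.

Defines a base task that computes per-example perplexity on image-caption pair
or interleaved web-document datasets, plus dataset-specific subclasses
(TextCaps, CommonGen, NoCaps, COCO, IIIT-5K, mini-GPT-captions).
"""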

import os
from typing import Dict, List

import torch
from accelerate.utils import extract_model_from_parallel
from deepspeed.runtime.engine import DeepSpeedEngine
from transformers import AutoTokenizer

from m4.evaluation.custom_metrics.perplexity_metrics import MetricsPerplexity
from m4.evaluation.tasks import BaseTask, Predictor
from m4.evaluation.utils import EvaluationVersion
from m4.training.types import DatasetTypes
from m4.training.utils import (
    FAKE_TOKEN_AROUND_IMAGE_V1,
    FAKE_TOKEN_AROUND_IMAGE_V2,
    IMAGE_TOKEN,
    build_image_transform,
)

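# Evaluation flow: the in-context predictor uses `get_data_collator` to build
# batches, `prepare_dataset` dispatches on `ds_type` to one of the two
# preparation paths, and `add_batch_metric` feeds per-example perplexities into
# the `PerplexityMetrics` metric keyed by `example_ids`.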
class Vgpt2PerplexityInContext(BaseTask):
    model_class: str = "VGPT2LMHeadModel"
    predictor_class: Predictor = Predictor.in_contexter
    target_keys: List[str] = ["example_ids"]
    image_column_name: str = None
    text_column_name: str = None
    context_column_name: str = None
    ds_type: DatasetTypes = DatasetTypes.IMAGE_CAPTION_PAIRS
    add_end_of_doc_token: bool = True
    add_begin_of_doc_token: bool = False
    tokenizer_max_seq_len = 1024
    max_num_images = 70

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.tokenizer_name = kwargs.pop("tokenizer_name")
        evaluation_version = kwargs.pop("evaluation_version")
        tokenizer_use_fast = kwargs.pop("tokenizer_use_fast", False)
        vision_encoder_type = kwargs.pop("vision_encoder_type")
        self.image_seq_len = kwargs.pop("image_seq_len")
        image_size = kwargs.pop("image_size")
        self.image_transform = build_image_transform(
            max_image_size=image_size, image_size=None, eval=True, vision_encoder_type=vision_encoder_type
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.tokenizer_name, truncation_side="left", use_fast=tokenizer_use_fast, token=os.getenv("HF_TOKEN", True)
        )
        self.image_token = IMAGE_TOKEN
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
        self.eos_token = self.tokenizer.eos_token
        self.pad_token_id = self.tokenizer.pad_token_id

        if evaluation_version == EvaluationVersion.v1:
            self.token_around_image = FAKE_TOKEN_AROUND_IMAGE_V1
        elif evaluation_version == EvaluationVersion.v2:
            self.token_around_image = FAKE_TOKEN_AROUND_IMAGE_V2
        else:
            raise ValueError(f"Invalid evaluation version: {evaluation_version}")

        raise NotImplementedError(
            "Padding for various size images has not been implemented for that class yet. Ask Victor to do it. He's"
            " unsure the last time we used this and as such, won't be spending time on something we might not even"
            " touch in the future."
        )

    def get_info_from_dataset(self, dataset):
        pass

    def get_data_collator(self, **kwargs):
        def data_collator(batch):
            exs = {key: [ex[key] for ex in batch] for key in batch[0].keys()}
            batch = self.prepare_dataset(exs, **kwargs)
            return batch

        return data_collator

    def prepare_image_caption_pair_ds(self, exs: Dict) -> Dict:
        nb_exs = len(exs["id"])
        tot_texts = [
            self._create_image_caption_pair_prompt(
                caption=(exs[self.text_column_name][idx][0]),
                context=exs[self.context_column_name][idx] if self.context_column_name else None,
            )
            for idx in range(nb_exs)
        ]
        # These are the tested examples - size: batch_size
        tot_texts = [self._add_special_tokens_to_prompt(text) for text in tot_texts]
        tot_texts = [text.strip() for text in tot_texts]

        # Tokenize and mask tokens
        tokens = self.tokenizer(
            tot_texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.tokenizer_max_seq_len,
            padding=True,
            add_special_tokens=False,
        )

        # These are the tested images - size: batch_size
        pixel_values = [self.image_transform(img).unsqueeze(0) for img in exs[self.image_column_name]]

        example_ids: List[int] = exs["id"]
        return {
            "example_ids": example_ids,
            "input_ids": [tokens.input_ids[idx] for idx in range(len(tot_texts))],
            "attention_mask": [tokens.attention_mask[idx] for idx in range(len(tot_texts))],
            "pixel_values": pixel_values,
        }

    def _create_image_caption_pair_prompt(self, caption="", context=None):
        if context is not None:
            raise NotImplementedError("Context not implemented for this task")
        prompt = f"{self.token_around_image}{self.image_token * self.image_seq_len}{self.token_around_image}{caption}"
        return prompt

    def _add_special_tokens_to_prompt(self, prompt):
        if self.add_end_of_doc_token:
            prompt = f"{prompt}{self.tokenizer.eos_token}"
        if self.add_begin_of_doc_token:
            prompt = f"{self.tokenizer.bos_token}{prompt}"
        return prompt
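
    # The web-document path below splits each document on the
    # "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" marker, interleaves images and text
    # fragments into a single string using FAKE_TOKEN_AROUND_IMAGE_V2 and
    # IMAGE_TOKEN, then tokenizes and pads/truncates each example to
    # `tokenizer_max_seq_len` tokens and `max_num_images` images.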
    def prepare_webdoc_ds(self, exs: Dict) -> Dict:
        images_batch = exs[self.image_column_name]
        texts_batch = exs[self.text_column_name]
        tokenizer = self.tokenizer
        last_was_image = False

        all_images = []
        all_texts = []
        for raw_images, raw_texts in zip(images_batch, texts_batch):
            inds_of_texts_to_split = [
                i
                for i, text in enumerate(raw_texts)
                if text is not None and isinstance(text, str) and "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED" in text
            ]
            if inds_of_texts_to_split:
                splitted_raw_images, splitted_raw_texts = [], []
                previous_i = 0
                for i in inds_of_texts_to_split:
                    splitting = raw_texts[i].split("END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED")
                    part1, part2 = splitting[0], splitting[-1]

                    sub_doc_images = raw_images[previous_i:i] + [None]
                    sub_doc_texts = raw_texts[previous_i:i] + [part1.strip()]
                    if not any(sub_doc_images):  # This can happen if all images in raw_images[0:i] are all None
                        continue

                    splitted_raw_images.append(sub_doc_images)
                    splitted_raw_texts.append(sub_doc_texts)

                    if part2.strip() == "":
                        previous_i = i + 1
                    else:
                        raw_texts[i] = part2.strip()
                        previous_i = i

                if previous_i < len(raw_images) and any(raw_images[previous_i:]):
                    splitted_raw_images.append(raw_images[previous_i:])
                    splitted_raw_texts.append(raw_texts[previous_i:])
            else:
                splitted_raw_images, splitted_raw_texts = [raw_images], [raw_texts]

            # Sanity check
            if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
                raise ValueError(
                    "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                    " Something core went wrong during the splitting and needs to be fixed."
                )

            for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
                images, web_text = [], ""
                for image, text in zip(s_r_ims, s_r_txts):
                    if text is None and image is None:
                        continue

                    if image is not None:
                        web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                        images.append(self.image_transform(image))
                        last_was_image = True
                    elif text is not None:
                        if last_was_image:
                            web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                            last_was_image = False
                        else:
                            web_text += f" {text}" if web_text != "" else text

                if last_was_image:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"

                web_text = web_text.strip(" ")

                # This is mostly a sanity check. Cases like that should not happen at that point.
                if web_text == "" or len(images) == 0:
                    continue

                images = torch.stack(images)
                all_images.append(images)

                web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
                if self.add_end_of_doc_token:
                    web_text_ids += [tokenizer.eos_token_id]
                if self.add_begin_of_doc_token:
                    web_text_ids = [tokenizer.bos_token_id] + web_text_ids
                all_texts.append(web_text_ids)

        output_input_ids = []
        output_images = []
        output_attention_masks = []
        for images, text in zip(all_images, all_texts):
            padded_input_ids = [tokenizer.pad_token_id] * self.tokenizer_max_seq_len
            unpadded_seq_len = len(text)
            padded_input_ids[:unpadded_seq_len] = text[: self.tokenizer_max_seq_len]

            attention_mask = torch.zeros((self.tokenizer_max_seq_len,), dtype=torch.long)
            attention_mask[:unpadded_seq_len] = 1

            image_count = padded_input_ids.count(self.image_token_id)
            local_max_num_images = min(image_count, self.max_num_images)
            current_images = images[:local_max_num_images]
            padded_image_tensor = torch.zeros(self.max_num_images, *current_images.size()[1:])
            padded_image_tensor[: current_images.size(0)] = current_images

            output_images.append(padded_image_tensor)
            output_input_ids.append(torch.tensor(padded_input_ids))
            output_attention_masks.append(attention_mask)

        output_input_ids = torch.stack(output_input_ids)
        output_images = torch.stack(output_images)
        output_attention_masks = torch.stack(output_attention_masks)

        example_ids: List[int] = exs["id"]
        return {
            "example_ids": example_ids,
            "input_ids": [input_ids for input_ids in output_input_ids],
            "attention_mask": [attention_masks for attention_masks in output_attention_masks],
            "pixel_values": [pixels for pixels in output_images],
        }

    def prepare_dataset(self, exs: Dict, **kwargs) -> Dict:
        """
        Prepare batch of examples.
        """
        num_shots: int = kwargs["num_shots"]
        if num_shots != 0:
            raise ValueError(
                f"Invalid num_shots selection: num_shots should equal 0 for perplexity but here num_shots={num_shots}"
            )
        if self.ds_type == DatasetTypes.IMAGE_CAPTION_PAIRS:
            # We have an image-caption pair dataset
            return self.prepare_image_caption_pair_ds(exs)
        elif self.ds_type == DatasetTypes.WEB_DOCUMENTS:
            # We have a webdoc dataset
            return self.prepare_webdoc_ds(exs)
        else:
            raise ValueError(f"Invalid dataset type: {self.ds_type}")
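
    # Perplexity below is computed per example as
    #   ppl = exp( sum_t CE(logits_t, label_{t+1}) / N )
    # where the cross-entropy is summed over shifted target positions whose
    # label is neither an image token nor a padding token, and N is the number
    # of attended non-image target positions.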
    def get_perplexities(self, **kwargs):
        model = kwargs["model"]
        input_ids = torch.stack(kwargs["input_ids"]).to(model.device)
        attention_mask = torch.stack(kwargs["attention_mask"]).to(model.device)
        pixel_values = torch.stack(kwargs["pixel_values"]).to(model.device)

        unwrapped_model = extract_model_from_parallel(model)
        is_deepspeed_model = isinstance(model, DeepSpeedEngine)
        if is_deepspeed_model:
            if model.zero_optimization_partition_weights():
                # Enable automated discovery of external parameters by indicating that
                # we are in a forward pass.
                for module in model.module.modules():
                    module._parameters._in_forward = True
                    pass

        unwrapped_model.eval()
        with torch.no_grad():
            logits = unwrapped_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                labels=input_ids,
            )["logits"]

        shift_logits = logits[..., :-1, :].contiguous().float()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_attention_mask = attention_mask[..., 1:]
        shift_attention_mask[shift_labels == self.image_token_id] = 0

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        non_ignored_token_count = torch.sum(shift_attention_mask == 1, dim=1, keepdim=True).flatten()
        mask = (shift_labels.view(-1) != self.image_token_id) & (shift_labels.view(-1) != self.pad_token_id)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[mask], shift_labels.view(-1)[mask])
        # Note: this reshape assumes every example in the batch keeps the same
        # number of unmasked target tokens.
        loss = loss.reshape(logits.shape[0], -1).sum(dim=1)
        loss = loss / non_ignored_token_count
        perplexities = loss.exp()
        return perplexities

    def add_batch_metric(self, metric, **kwargs):
        perplexities = self.get_perplexities(**kwargs)
        metric.add_batch(
            perplexities=perplexities,
            **{key: kwargs[key] for key in self.target_keys},
        )
        return metric
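

# Dataset-specific tasks. The prompt fed to the model is, schematically:
#   default:   f"{token_around_image}{image_token * image_seq_len}{token_around_image}{caption}"
#   CommonGen: f"{token_around_image}{image_token}{token_around_image}Input: {context}. Output: {caption}"
#   IIIT-5K:   f"{token_around_image}{image_token}{token_around_image}A photo where it is written {caption}"
# with an EOS token appended when `add_end_of_doc_token` is True.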
class TextCapsVgpt2PerplexityInContext(Vgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/TextCaps"
    metric_name: str = "PerplexityMetrics"
    metric_kwargs = {"metrics": [MetricsPerplexity.PERPLEXITY]}
    default_query_split_name: str = "validation"
    image_column_name: str = "image"
    text_column_name: str = "reference_strs"


class TextCapsSampleVgpt2PerplexityInContext(TextCapsVgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/TextCaps-Sample"


class CommonGenVgpt2PerplexityInContext(Vgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/common_gen"
    metric_name: str = "PerplexityMetrics"
    metric_kwargs = {"metrics": [MetricsPerplexity.PERPLEXITY]}
    default_query_split_name: str = "validation"
    image_column_name: str = "image"
    context_column_name: str = "concepts"
    text_column_name: str = "target"

    def _create_image_caption_pair_prompt(self, caption="", context=""):
        return (
            f"{self.token_around_image}{self.image_token}{self.token_around_image}Input: {context}. Output: {caption}"
        )


class NoCapsVgpt2PerplexityInContext(Vgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/NoCaps"
    metric_name: str = "PerplexityMetrics"
    metric_kwargs = {"metrics": [MetricsPerplexity.PERPLEXITY]}
    default_query_split_name: str = "validation"
    # This does not exist yet... it would require adding a training split to the dataset
    # (see `create_sample_evaluation_datasets_simplified.py`)
    image_column_name: str = "image"
    text_column_name: str = "annotations_captions"


class NoCapsSampleVgpt2PerplexityInContext(NoCapsVgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/NoCaps-Sample"


class CocoVgpt2PerplexityInContext(Vgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/COCO"
    dataset_config = "2014_captions"
    metric_name: str = "PerplexityMetrics"
    metric_kwargs = {"metrics": [MetricsPerplexity.PERPLEXITY]}
    default_query_split_name: str = "validation"
    image_column_name: str = "image"
    text_column_name: str = "sentences_raw"


class CocoSampleVgpt2PerplexityInContext(CocoVgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/COCO-2014_captions-Sample"
    dataset_config = None


class IIIT5KVgpt2PerplexityInContext(Vgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/IIIT-5K"
    metric_name: str = "PerplexityMetrics"
    metric_kwargs = {"metrics": [MetricsPerplexity.PERPLEXITY]}
    default_query_split_name: str = "test"
    image_column_name: str = "image"
    text_column_name: str = "label"

    def _create_image_caption_pair_prompt(self, caption="", context=None):
        if context is not None:
            raise NotImplementedError("Context not implemented for this task")
        return (
            f"{self.token_around_image}{self.image_token}{self.token_around_image}A photo where"
            f" it is written {caption}"
        )


class IIIT5KSampleVgpt2PerplexityInContext(IIIT5KVgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/IIIT-5K-Sample"


class MiniGPTCaptionsVgpt2PerplexityInContext(Vgpt2PerplexityInContext):
    dataset_name: str = "HuggingFaceM4/mini-GPT-captions"
    metric_name: str = "PerplexityMetrics"
    metric_kwargs = {"metrics": [MetricsPerplexity.PERPLEXITY]}
    default_query_split_name: str = "train"
    image_column_name: str = "image"
    text_column_name: str = "reference_strs"
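

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal, hypothetical example of wiring up one of these tasks. The keyword
# names mirror the kwargs popped in `Vgpt2PerplexityInContext.__init__`, while
# the concrete values are placeholders. Note that the base `__init__` currently
# ends with `raise NotImplementedError`, so this cannot run as-is:
#
#     task = TextCapsVgpt2PerplexityInContext(
#         tokenizer_name="HuggingFaceM4/some-tokenizer",  # hypothetical checkpoint
#         evaluation_version=EvaluationVersion.v2,
#         tokenizer_use_fast=True,
#         vision_encoder_type="clip",                     # hypothetical value
#         image_seq_len=64,                               # hypothetical value
#         image_size=224,                                 # hypothetical value
#     )
#     collate_fn = task.get_data_collator(num_shots=0)    # perplexity requires num_shots == 0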