data/datasets.py:
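"""Dataset wrappers for multimodal (image + text) chat data.

`BaseDataset` handles chat templating, image preprocessing, and building a
loss mask over assistant turns. `VQADataset` produces training samples with
shifted next-token labels; `MMStarDataset` produces evaluation samples for
the MMStar benchmark, with the answer hidden from the model inputs.
"""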

import torch
from PIL import Image
from torch.utils.data import Dataset


class BaseDataset(Dataset):
    def __init__(self, dataset, tokenizer, image_processor, mp_image_token_length):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.mp_image_token_length = mp_image_token_length

        self.prefix_len = self._get_prefix_len()

    def __len__(self):
        return len(self.dataset)

    def _get_prefix_len(self):
        # Measure how many tokens the chat template inserts before assistant
        # content: template a sentinel string, locate it, and count the tokens
        # that precede it.
        random_string_5_letters = "xzyvd"
        random_string_chat_templated = self.tokenizer.apply_chat_template(
            [{"role": "assistant", "content": random_string_5_letters}],
            tokenize=False,
            add_special_tokens=False,
        )
        random_string_location = random_string_chat_templated.find(random_string_5_letters)
        return len(self.tokenizer.encode(random_string_chat_templated[:random_string_location]))

    def _get_messages(self, item, image_count=0):
        messages = []
        for text in item['texts']:
            messages.append({"role": "user", "content": text['user']})
            messages.append({"role": "assistant", "content": text['assistant']})

        if image_count > 0:
            # Prepend one placeholder image token per projected image position.
            messages[0]["content"] = (
                self.tokenizer.image_token * image_count * self.mp_image_token_length
                + messages[0]["content"]
            )

        return messages

    def _process_images(self, images):
        processed_images = []
        for image in images:
            if isinstance(image, Image.Image):
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                processed_image = self.image_processor(image)
                processed_images.append(processed_image)
            else:
                raise ValueError(f"Expected a PIL image, got {type(image)}")
        return processed_images

    def _prepare_inputs_and_loss_mask(self, messages):
        conv_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_special_tokens=False,
            return_dict=True,
        )
        mask = [0] * len(conv_ids["input_ids"])

        # Locate each assistant turn and flip its mask to 1
        cursor = 0
        for msg in messages:
            segment_ids = self.tokenizer.apply_chat_template(
                [msg], tokenize=True, add_special_tokens=False
            )
            seg_len = len(segment_ids)

            if msg["role"] == "assistant":
                start = cursor + self.prefix_len
                end = cursor + seg_len
                mask[start:end] = [1] * (end - start)  # compute the loss on these tokens

            cursor += seg_len

        return (
            torch.tensor(conv_ids["input_ids"]),
            torch.tensor(mask).to(torch.bool),
            torch.tensor(conv_ids["attention_mask"]),
        )


class VQADataset(BaseDataset):  # Visual Question Answering Dataset
    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Handle images (should be a list)
        images_data = item['images']
        if not isinstance(images_data, list):
            images_data = [images_data]

        # Now process the images
        processed_images = self._process_images(images_data)

        messages = self._get_messages(item, len(processed_images))

        input_ids, mask, attention_mask = self._prepare_inputs_and_loss_mask(messages)
        labels = self._get_labels(input_ids, mask)

        return {
            "images": processed_images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _get_labels(self, input_ids, mask):
        labels = input_ids.clone().masked_fill(~mask, -100)  # ignore non-assistant tokens in the loss
        labels = labels.roll(-1)  # Shift labels for causal LM
        labels[-1] = -100  # Last token has no target
        return labels


class MMStarDataset(BaseDataset):  # https://huggingface.co/datasets/Lin-Chen/MMStar
    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']
        processed_images = self._process_images([image])

        item['texts'] = [{
            "user": item['question'] + "\nAnswer only with the letter!",
            "assistant": item['answer'],
        }]

        messages = self._get_messages(item, image_count=len(processed_images))

        input_ids, mask, attention_mask = self._prepare_inputs_and_loss_mask(messages)
        labels = self._get_labels(input_ids, mask)
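
        # Hide the ground-truth answer from the model at eval time: the answer
        # tokens located by `mask` survive only in `labels`; below they are
        # replaced with padding in the inputs and zeroed in the attention mask.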
        input_ids = input_ids.masked_fill(mask, self.tokenizer.pad_token_id)
        attention_mask = attention_mask.masked_fill(mask, 0)

        return {
            "images": processed_images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _get_labels(self, input_ids, mask):
        # Unlike VQADataset, these labels are not shifted: positions outside
        # the answer hold pad tokens, and only the raw answer token ids remain.
        labels = input_ids.clone().masked_fill(~mask, self.tokenizer.pad_token_id)
        return labels
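
# ---------------------------------------------------------------------------
# Minimal usage sketch. Names like `raw_split`, `tokenizer`, and
# `image_processor` are hypothetical stand-ins: the tokenizer must expose
# `image_token` / `pad_token_id` and support `apply_chat_template`, and the
# image processor must map a PIL image to a tensor.
#
#   from torch.utils.data import DataLoader
#
#   train_ds = VQADataset(
#       dataset=raw_split,            # items shaped like {"images": [...], "texts": [{"user": ..., "assistant": ...}]}
#       tokenizer=tokenizer,
#       image_processor=image_processor,
#       mp_image_token_length=64,     # image tokens per image after modality projection (example value)
#   )
#   loader = DataLoader(train_ds, batch_size=8)  # variable-length samples need a custom collate_fn
# ---------------------------------------------------------------------------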