data/collators.py:
import torch


class BaseCollator(object):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def _pad_batch(self, batch, max_length):
        # Left-pad every sequence to max_length so the batch can be stacked into rectangular tensors
        batch["input_ids"] = [torch.nn.functional.pad(ids, (max_length - len(ids), 0), value=self.tokenizer.pad_token_id) for ids in batch["input_ids"]]
        batch["labels"] = [torch.nn.functional.pad(labels, (max_length - len(labels), 0), value=self.tokenizer.pad_token_id) for labels in batch["labels"]]
        batch["attention_mask"] = [torch.nn.functional.pad(attention_mask, (max_length - len(attention_mask), 0), value=0) for attention_mask in batch["attention_mask"]]

    def prepare_batch(self, batch, max_length=None):
        # batch is a list of dicts, each containing "input_ids", "attention_mask", "labels", "images";
        # convert it to a dict of lists of tensors
        batch = {k: [item[k] for item in batch] for k in batch[0]}

        if max_length is not None:
            batch = self._discard_samples_that_are_too_long(batch, max_length)

        # Pad samples to max_length if one was given, otherwise to the longest sequence in the batch
        if max_length is not None:
            max_len = max_length
        else:
            max_len = max(map(len, batch["input_ids"]))
        self._pad_batch(batch, max_len)  # mutates the lists inside the batch dict in place

        return {
            "input_ids": torch.stack(batch["input_ids"]),
            "attention_mask": torch.stack(batch["attention_mask"]),
            "images": batch["images"],
            "labels": torch.stack(batch["labels"]),
        }

    def _discard_samples_that_are_too_long(self, batch, max_length):
        filtered = [
            (ids, label, attn, img)
            for ids, label, attn, img in zip(batch["input_ids"], batch["labels"], batch["attention_mask"], batch["images"])
            if len(ids) <= max_length
        ]
        if not filtered:
            # Keep the return type consistent with the non-empty case: a dict of (empty) lists
            return {"input_ids": [], "labels": [], "attention_mask": [], "images": []}
        batch_token_ids, batch_labels, batch_attentions, batch_images = zip(*filtered)
        return {"input_ids": list(batch_token_ids), "labels": list(batch_labels), "attention_mask": list(batch_attentions), "images": list(batch_images)}


class VQACollator(BaseCollator):  # Visual Question Answering Collator
    def __init__(self, tokenizer, max_length):
        self.max_length = max_length
        super().__init__(tokenizer)

    def _pad_batch(self, batch, max_length):  # Reimplementing to use -100 as the pad value for labels, so that it's ignored by the loss
        batch["input_ids"] = [torch.nn.functional.pad(ids, (max_length - len(ids), 0), value=self.tokenizer.pad_token_id) for ids in batch["input_ids"]]
        batch["labels"] = [torch.nn.functional.pad(labels, (max_length - len(labels), 0), value=-100) for labels in batch["labels"]]
        batch["attention_mask"] = [torch.nn.functional.pad(attention_mask, (max_length - len(attention_mask), 0), value=0) for attention_mask in batch["attention_mask"]]

    def __call__(self, batch):
        batch = self.prepare_batch(batch, max_length=self.max_length)
        return batch


class MMStarCollator(BaseCollator):  # https://huggingface.co/datasets/Lin-Chen/MMStar
    def __call__(self, batch):
        batch = self.prepare_batch(batch)
        return batch
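

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative addition, not part of the original file).
# It assumes each sample is a dict of 1-D token tensors plus an image tensor;
# a stub object stands in for the tokenizer, since only pad_token_id is used.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    stub_tokenizer = SimpleNamespace(pad_token_id=0)
    samples = [
        {
            "input_ids": torch.tensor([5, 6, 7]),
            "attention_mask": torch.tensor([1, 1, 1]),
            "labels": torch.tensor([5, 6, 7]),
            "images": torch.zeros(3, 224, 224),
        },
        {
            "input_ids": torch.tensor([8, 9]),
            "attention_mask": torch.tensor([1, 1]),
            "labels": torch.tensor([8, 9]),
            "images": torch.zeros(3, 224, 224),
        },
    ]

    collator = VQACollator(stub_tokenizer, max_length=4)
    out = collator(samples)

    # Sequences are left-padded to max_length; label padding uses -100, which
    # PyTorch's cross-entropy loss ignores by default (ignore_index=-100).
    print(out["input_ids"].shape)  # torch.Size([2, 4])
    print(out["labels"][1])        # values: [-100, -100, 8, 9]
    # In training, the collator is typically passed to a DataLoader as collate_fn.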