in data/collators.py [0:0]
import torch  # module-level import in data/collators.py


def prepare_batch(self, batch, max_length=None):
    # batch is a list of dicts, each containing "input_ids", "attention_mask",
    # "labels", and "images"; convert it to a dict of lists of tensors
    batch = {k: [item[k] for item in batch] for k in batch[0]}
    if max_length is not None:
        batch = self._discard_samples_that_are_too_long(batch, max_length)
    # Pad every sample to a common length: the requested max_length if given,
    # otherwise the longest sequence in the batch
    if max_length is not None:
        max_len = max_length
    else:
        max_len = max(map(len, batch["input_ids"]))
    # _pad_batch mutates the dict in place (dicts are mutable, so no return
    # value is needed)
    self._pad_batch(batch, max_len)
    return {
        "input_ids": torch.stack(batch["input_ids"]),
        "attention_mask": torch.stack(batch["attention_mask"]),
        "images": batch["images"],  # kept as a list rather than stacked
        "labels": torch.stack(batch["labels"]),
    }
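
The two helper methods referenced above are not shown in this excerpt. Below is a minimal sketch of what they might look like, assuming right-padding, a pad_token_id attribute on the collator, and the usual -100 ignore index for labels; all of these are assumptions, not confirmed by the source file.

def _discard_samples_that_are_too_long(self, batch, max_length):
    # Keep only the samples whose input_ids fit within max_length,
    # rebuilding the dict-of-lists with the surviving indices
    keep = [i for i, ids in enumerate(batch["input_ids"]) if len(ids) <= max_length]
    return {k: [v[i] for i in keep] for k, v in batch.items()}

def _pad_batch(self, batch, max_len):
    # Right-pad every sample in place up to max_len; self.pad_token_id and
    # the -100 label fill value are assumptions, not taken from the source
    for i in range(len(batch["input_ids"])):
        pad = max_len - len(batch["input_ids"][i])
        if pad <= 0:
            continue
        batch["input_ids"][i] = torch.cat(
            [batch["input_ids"][i], torch.full((pad,), self.pad_token_id, dtype=torch.long)]
        )
        batch["attention_mask"][i] = torch.cat(
            [batch["attention_mask"][i], torch.zeros(pad, dtype=torch.long)]
        )
        batch["labels"][i] = torch.cat(
            [batch["labels"][i], torch.full((pad,), -100, dtype=torch.long)]
        )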
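
A quick usage sketch follows; the sample shapes and the collator instance are invented for illustration and are not part of the source.

def make_sample(n):
    # toy sample with n tokens and one dummy image tensor
    return {
        "input_ids": torch.arange(n),
        "attention_mask": torch.ones(n, dtype=torch.long),
        "labels": torch.arange(n),
        "images": torch.rand(3, 224, 224),
    }

# collator: an instance of the class defining prepare_batch, constructed elsewhere
out = collator.prepare_batch([make_sample(5), make_sample(8)])
print(out["input_ids"].shape)  # torch.Size([2, 8]); padded to the longest sample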