in src/pixparse/task/task_cruller_finetune_RVLCDIP.py [0:0]
def collate_fn(self, batch):
    """
    Collate a list of samples into a single batch of tensors.

    Intended as a basic collator for PIL images, as returned by the
    rvlcdip dataloader (among others).

    Args:
        batch: sequence of items, each indexable by ``"image"`` and
            ``"label"``. Every label must be a key of ``self.int2str``.

    Returns:
        dict with three entries:
            ``"image"``: stack of ``self.image_preprocess_train(image)``
                results, one per item.
            ``"label"``: stacked label token ids with the last position
                of every sequence removed.
            ``"text_target"``: stack of ``self.text_input_to_target``
                applied to each full token sequence, with the first
                position of every sequence removed.
    """
    raw_images = [sample["image"] for sample in batch]
    class_ids = [sample["label"] for sample in batch]

    def encode_label(class_id):
        # FIXME move this batcher/tokenizer elsewhere
        # Prompt layout: <task start token> "<" class name "/>" <eos token>.
        prompt = (
            self.task_start_token
            + "<"
            + self.int2str[class_id]
            + "/>"
            + self.tokenizer.trunk.eos_token
        )
        encoding = self.tokenizer.trunk(
            prompt,
            add_special_tokens=False,
            return_tensors='pt',
            max_length=5,
            padding='max_length',
            truncation=True,
        )
        # The tokenizer output carries a leading batch axis of one; keep
        # only the single sequence of ids.
        return encoding.input_ids[0]

    token_rows = []
    for class_id in class_ids:
        token_rows.append(encode_label(class_id))

    preprocess = self.image_preprocess_train
    pixel_batch = torch.stack([preprocess(picture) for picture in raw_images])

    token_batch = torch.stack(token_rows)
    target_batch = torch.stack(
        [self.text_input_to_target(row) for row in token_batch]
    )

    # Inputs lose their final position and targets lose their first one, so
    # position i of "label" lines up with position i + 1 of the target.
    return {
        "image": pixel_batch,
        "label": token_batch[:, :-1],
        "text_target": target_batch[:, 1:],
    }