in src/pixparse/task/task_cruller_finetune_xent.py [0:0]
def collate_fn(self, batch):
    """
    Collate a batch of samples into a dict of batched tensors.

    Basic collator for PIL images, as returned by the rvlcdip dataloader
    (among others). Each sample is indexed with ``'image'`` and ``'label'``.
    Every image is passed through ``self.image_preprocess_train``, and the
    results are stacked along a new leading dimension. The labels are packed
    into an int64 tensor.

    Args:
        batch: sequence of samples, each supporting ``sample['image']`` and
            ``sample['label']``.

    Returns:
        dict with ``'image'`` (the stacked, transformed images) and
        ``'label'`` (a ``torch.int64`` tensor of the labels, in batch order).
    """
    raw_images = [sample['image'] for sample in batch]
    targets = [sample['label'] for sample in batch]
    # NOTE(review): the train-time preprocessing is applied unconditionally;
    # confirm this collator is not also used for evaluation.
    preprocess = self.image_preprocess_train
    stacked_images = torch.stack(list(map(preprocess, raw_images)))
    label_tensor = torch.tensor(targets, dtype=torch.int64)
    return {'image': stacked_images, 'label': label_tensor}