def collate_fn()

in src/pixparse/task/task_cruller_finetune_xent.py [0:0]


    def collate_fn(self, batch):
        """
        Collate a list of samples into a batched dict of tensors.

        Intended for samples whose 'image' entry is a PIL image, as returned by
        the rvlcdip dataloader (among others). Every image is passed through
        ``self.image_preprocess_train`` and the results are combined with
        ``torch.stack``, so the transform must yield equally-shaped tensors.

        Args:
            batch: sequence of per-sample mappings, each providing an 'image'
                and a 'label' entry.

        Returns:
            dict with 'image' (the stacked, preprocessed images) and 'label'
            (an int64 tensor built from the per-sample labels).
        """
        # Pull out the raw fields first, images then labels.
        raw_images = [sample['image'] for sample in batch]
        raw_labels = [sample['label'] for sample in batch]

        # NOTE(review): the train-time preprocessing is used unconditionally here.
        preprocess = self.image_preprocess_train

        pixel_batch = torch.stack(list(map(preprocess, raw_images)))
        label_batch = torch.tensor(raw_labels, dtype=torch.int64)
        return dict(image=pixel_batch, label=label_batch)