def collate_fn()

in src/pixparse/task/task_cruller_eval_rvlcdip.py [0:0]


    def collate_fn(self, batch):
        """
        Basic collator for PIL images, as returned by the rvlcdip dataloader (among others).

        ``None`` samples are dropped. Each remaining image is passed through
        ``self.safe_image_transform``. Samples whose transform returns ``None`` are
        also dropped, together with their labels, so images and labels stay aligned.

        Args:
            batch: Iterable of samples. Each sample is either ``None`` or a mapping
                with ``'image'`` and ``'label'`` keys. Labels are presumably integer
                class ids (they are converted to an int64 tensor); confirm against
                the dataset.

        Returns:
            ``{'image': Tensor, 'label': Tensor[int64]}``, with the surviving images
            stacked along a new leading dimension. Returns ``None`` when no sample
            survives, i.e. the batch is empty, all ``None``, or every transform failed.
        """
        # Materialize once so a one-shot iterable is not exhausted by the first pass.
        items = [item for item in batch if item is not None]
        if not items:
            return None

        images = [item['image'] for item in items]
        labels = [item['label'] for item in items]

        # Keep only (image, label) pairs whose transform succeeded.
        kept = [
            (tensor, label)
            for tensor, label in zip(map(self.safe_image_transform, images), labels)
            if tensor is not None
        ]
        if not kept:
            # Every transform failed; torch.stack([]) would raise, so signal an
            # empty batch the same way as above.
            return None

        return {
            'image': torch.stack([tensor for tensor, _ in kept]),
            'label': torch.tensor([label for _, label in kept], dtype=torch.int64),
        }