def collate_fn()

in src/pixparse/task/task_cruller_finetune_RVLCDIP.py [0:0]


    def collate_fn(self, batch):
        """
        Collate a list of samples into a dict of batched tensors.

        Basic collator for PIL images, as returned by the rvlcdip dataloader
        (among others). Every element of ``batch`` is indexed with the keys
        ``"image"`` and ``"label"``; the label is used as a key into
        ``self.int2str``.

        For each sample, the text
        ``task_start_token + "<" + class name + "/>" + eos_token`` is
        tokenized with ``self.tokenizer.trunk`` (padded/truncated to
        ``max_length=5``) and the images go through
        ``self.image_preprocess_train``.

        Returns a dict with:
            "image": the stacked, preprocessed images.
            "label": the stacked token ids with the last column removed.
            "text_target": the stacked outputs of
                ``self.text_input_to_target`` (one call per row of token
                ids) with the first column removed.
        """
        pil_images = [sample["image"] for sample in batch]
        class_ids = [sample["label"] for sample in batch]

        def encode(prompt):
            # FIXME move this batcher/tokenizer elsewhere
            encoded = self.tokenizer.trunk(
                prompt,
                add_special_tokens=False,
                return_tensors='pt',
                max_length=5,
                padding='max_length',
                truncation=True,
            )
            # One string is encoded per call; keep its single row of ids.
            return encoded.input_ids[0]

        token_rows = []
        for class_id in class_ids:
            prompt = (
                self.task_start_token
                + "<"
                + self.int2str[class_id]
                + "/>"
                + self.tokenizer.trunk.eos_token
            )
            token_rows.append(encode(prompt))

        preprocess = self.image_preprocess_train
        image_batch = torch.stack([preprocess(pil_image) for pil_image in pil_images])

        token_batch = torch.stack(token_rows)
        # Targets are derived from the full rows, before either tensor is sliced.
        target_batch = torch.stack(
            [self.text_input_to_target(token_row) for token_row in token_batch]
        )

        # Shift by one position: inputs drop the last token, targets drop the first.
        return {
            "image": image_batch,
            "label": token_batch[:, :-1],
            "text_target": target_batch[:, 1:],
        }