def collate_fn()

in src/pixparse/task/task_cruller_eval_cord.py [0:0]


    def collate_fn(self, batch):
        """
        Collate CORD-style samples into image tensors and tokenized parse targets.

        Each item in ``batch`` must provide:
            "image": passed through ``self.image_preprocess_eval`` and stacked
                with ``torch.stack`` (presumably a PIL image -- confirm).
            "ground_truth": a string that ``literal_eval`` turns into a mapping
                with a ``"gt_parse"`` entry.

        The ``gt_parse`` value is serialized with ``json2token`` (keys not
        sorted), wrapped as ``task_start_token + tokens + eos_token``, and
        tokenized without extra special tokens to a fixed length of 512
        (padded to ``max_length`` and truncated), so rows can be stacked.

        Returns:
            dict with
                "image": stacked preprocessed images.
                "label": token ids with the last position dropped.
                "text_target": ``self.text_input_to_target`` applied to each
                    row of token ids, with the first position dropped, so
                    ``text_target[:, i]`` derives from position ``i + 1`` of
                    the tokenized text while ``label[:, i]`` is position ``i``.
        """
        trunk = self.tokenizer.trunk
        # Batch-invariant tokenizer attributes, fetched once instead of per sample.
        special_tokens = trunk.all_special_tokens
        eos_token = trunk.eos_token

        # TODO move this to a __getitem__ for pickling
        def tokenize(text):
            # Fixed-length encoding so every sample yields the same tensor size.
            return trunk(
                text,
                add_special_tokens=False,
                return_tensors="pt",
                max_length=512,
                padding="max_length",
                truncation=True,
            ).input_ids[0]

        images = [item["image"] for item in batch]
        raw_texts = [literal_eval(item["ground_truth"])["gt_parse"] for item in batch]

        inputs_to_stack = []
        for text in raw_texts:
            tokens_from_json, _ = json2token(text, special_tokens, sort_json_key=False)
            # NOTE: no bos_token is inserted between the task start token and the
            # serialized parse. NOTE(review): truncation is applied after eos is
            # appended, so with right-side truncation an over-long parse loses
            # its eos token -- confirm this is acceptable.
            inputs_to_stack.append(tokenize(self.task_start_token + tokens_from_json + eos_token))
        text_inputs = torch.stack(inputs_to_stack)
        # text_input_to_target is applied row by row; presumably it masks
        # pad/prompt positions -- verify against its definition.
        targets = torch.stack([self.text_input_to_target(text) for text in text_inputs])

        images = torch.stack([self.image_preprocess_eval(img) for img in images])

        # Shift by one position: label keeps positions [0, L-1), target keeps [1, L).
        text_inputs = text_inputs[:, :-1]
        targets = targets[:, 1:]
        return {
            "image": images,
            "label": text_inputs,
            "text_target": targets,
        }