# def collate_fn()
#
# in src/pixparse/task/task_cruller_finetune_docvqa.py [0:0]


    def collate_fn(self, batch):
        """Collate a list of DocVQA samples into a training batch.

        Each sample is a dict with an 'image' and a 'labels' list of
        question/answer strings. One Q/A string is drawn at random per
        sample, wrapped in the task start token, tokenized to a fixed
        512-token window, and shifted to produce decoder inputs and
        next-token targets (teacher forcing).
        """
        def encode(text):
            # Tokenize with fixed-length padding/truncation and drop the
            # batch dimension that return_tensors="pt" adds.
            return self.tokenizer.trunk(
                text,
                add_special_tokens=False,
                return_tensors="pt",
                max_length=512,
                padding="max_length",
                truncation=True,
            ).input_ids[0]

        eos = self.tokenizer.trunk.eos_token
        # One randomly chosen Q/A annotation per sample.
        sampled_qa = [np.random.choice(sample['labels']) for sample in batch]
        token_rows = [encode("<s_docvqa>" + qa + eos) for qa in sampled_qa]
        # Check max length here and truncate/pad if needed
        # You could enforce it to be under or equal to a specific length

        text_inputs = torch.stack(token_rows)
        targets = torch.stack(
            [self.text_input_to_target(row) for row in text_inputs]
        )

        preprocess = self.image_preprocess_train
        images = torch.stack([preprocess(sample['image']) for sample in batch])

        # Shift for teacher forcing: inputs drop the last token,
        # targets drop the first.
        text_inputs = text_inputs[:, :-1]
        targets = targets[:, 1:]

        return {
            "image": images,
            "label": text_inputs,
            "text_target": targets,
        }