in src/pixparse/task/task_cruller_finetune_docvqa.py [0:0]
def collate_fn(self, batch, max_length: int = 512):
    """Collate a batch of DocVQA samples into model-ready tensors.

    Each item in ``batch`` is expected to be a dict with:
      * ``'image'``: a raw image accepted by ``self.image_preprocess_train``
      * ``'labels'``: a sequence of question/answer strings; one string is
        sampled uniformly at random per item each call.

    Args:
        batch: list of sample dicts (see above).
        max_length: tokenizer pad/truncation length. Defaults to 512,
            matching the previously hard-coded value.

    Returns:
        dict with:
          * ``'image'``: stacked, preprocessed image tensor, shape (B, ...)
          * ``'label'``: input token ids with the last position dropped,
            shape (B, max_length - 1)
          * ``'text_target'``: target ids with the first position dropped,
            shape (B, max_length - 1)
    """

    def tokenize(text: str):
        # Named helper instead of a lambda bound to a name (PEP 8 E731).
        # Tokenizer handles both truncation and padding to `max_length`,
        # so no extra length check is needed afterwards.
        return self.tokenizer.trunk(
            text,
            add_special_tokens=False,
            return_tensors="pt",
            max_length=max_length,
            padding="max_length",
            truncation=True,
        ).input_ids[0]

    # Sample one Q/A string per item, then wrap it with the task prompt
    # token and the tokenizer's EOS token before tokenizing.
    eos = self.tokenizer.trunk.eos_token
    q_and_as = [np.random.choice(item['labels']) for item in batch]
    text_inputs = torch.stack(
        [tokenize("<s_docvqa>" + text + eos) for text in q_and_as]
    )
    targets = torch.stack(
        [self.text_input_to_target(text) for text in text_inputs]
    )

    transform = self.image_preprocess_train
    images = torch.stack([transform(item['image']) for item in batch])

    # Shift for teacher forcing: inputs drop the final token, targets
    # drop the first, so position t of the input predicts target t.
    text_inputs = text_inputs[:, :-1]
    targets = targets[:, 1:]
    return {
        "image": images,
        "label": text_inputs,
        "text_target": targets,
    }