in src/pixparse/task/task_cruller_eval_cord.py [0:0]
def collate_fn(self, batch):
    """
    Collate dataset samples (PIL images plus ground-truth annotations) into a batch.

    Each element of ``batch`` is indexed with ``"image"`` and ``"ground_truth"``.
    The ground-truth string is parsed with ``literal_eval`` and its ``"gt_parse"``
    entry is converted to a token string by ``json2token``.

    Returns a dict with three stacked tensors:
        "image": the images after ``self.image_preprocess_eval``.
        "label": the token ids with the last position removed.
        "text_target": the ``self.text_input_to_target`` outputs with the
            first position removed.
    """
    # TODO move this to a __getitem__ for pickling

    def encode(sequence):
        # Tokenizer is called with max_length=512 / padding="max_length" /
        # truncation=True so the per-sample id tensors can be stacked below.
        encoded = self.tokenizer.trunk(
            sequence,
            add_special_tokens=False,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )
        return encoded.input_ids[0]

    def build_input_ids(parsed):
        # Only the token string is used; the second value returned by
        # json2token is discarded.
        json_tokens, _ = json2token(
            parsed,
            self.tokenizer.trunk.all_special_tokens,
            sort_json_key=False,
        )
        # Sequence layout: task start token, JSON tokens, EOS.
        # No BOS token is inserted here.
        prompt = self.task_start_token + json_tokens + self.tokenizer.trunk.eos_token
        return encode(prompt)

    pil_images = [sample["image"] for sample in batch]
    parsed_ground_truths = [
        literal_eval(sample["ground_truth"])["gt_parse"] for sample in batch
    ]

    token_ids = torch.stack([build_input_ids(parsed) for parsed in parsed_ground_truths])
    # Targets are derived from the full-length id rows, before any slicing.
    target_ids = torch.stack([self.text_input_to_target(row) for row in token_ids])

    preprocess = self.image_preprocess_eval
    image_batch = torch.stack([preprocess(picture) for picture in pil_images])

    # Offset inputs and targets by one position relative to each other:
    # inputs lose their last column, targets lose their first.
    return {
        "image": image_batch,
        "label": token_ids[:, :-1],
        "text_target": target_ids[:, 1:],
    }