in src/open-r1-multimodal/src/open_r1/sft.py [0:0]
def collate_fn(examples):
    """Collate chat-formatted multimodal examples into a training batch.

    Each example's ``"messages"`` conversation is rendered to text with the
    processor's chat template, and its images are extracted with
    ``process_vision_info``. The processor then tokenizes and pads the whole
    batch.

    Args:
        examples: sequence of dataset rows, each holding a ``"messages"``
            conversation. Presumably this is the full conversation, including
            the assistant answer being trained on — TODO confirm against the
            dataset conversion code.

    Returns:
        The processor's batch output with an added ``"labels"`` entry. It is
        a copy of ``input_ids`` in which padding positions and image
        placeholder tokens are set to -100, so the loss ignores them.
    """
    # The conversation already ends with the target assistant turn, so no
    # generation prompt is appended. With add_generation_prompt=True the
    # template adds an empty assistant header after the answer, and the model
    # would be trained to emit that spurious turn opener.
    texts = [
        processor.apply_chat_template(
            example["messages"], tokenize=False, add_generation_prompt=False
        )
        for example in examples
    ]
    # process_vision_info returns (images, videos); only images are passed
    # to the processor here.
    image_inputs = [
        process_vision_info(example["messages"])[0] for example in examples
    ]
    batch = processor(
        text=texts,
        images=image_inputs,
        return_tensors="pt",
        padding=True,
    )

    labels = batch["input_ids"].clone()
    # Mask padding by position rather than by token id. If the pad id
    # coincides with a real token (e.g. one also used as EOS), an id-based
    # mask would hide those real tokens from the loss as well.
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        labels[attention_mask == 0] = -100
    else:
        labels[labels == processor.tokenizer.pad_token_id] = -100
    # Image placeholder tokens stand in for vision features and are not
    # prediction targets.
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels
    return batch