scripts/ft_gemma3n_image_vt.py
def collate_fn(examples, processor):
    messages = []
    for sample in examples:
        image = sample["image"].convert("RGB")
        label = str(sample["label"])
        # One chat per sample: system prompt, user turn (image + question), assistant answer.
        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant with great geometry skills.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {
                        "type": "text",
                        "text": "How many intersection points are there in the image?",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": label}]},
        ]
        messages.append(message)
    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,  # the assistant answer is already part of the chat
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,  # pad to the longest sequence so the batch stacks into tensors
    )
    labels = batch["input_ids"].clone()  # clone input IDs to use as labels
    # Mask the tokens that we do not want to include in the loss computation:
    # -100 is ignored by the cross-entropy loss.
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == processor.tokenizer.image_token_id] = -100
    labels[labels == processor.tokenizer.boi_token_id] = -100
    labels[labels == processor.tokenizer.eoi_token_id] = -100
    batch["labels"] = labels
    return batch
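
For context, a minimal sketch of how a collator like this is typically bound to a `Trainer` via `functools.partial`. The model ID, dataset name, and training arguments below are illustrative assumptions, not values taken from the script; the only requirement is a dataset whose samples expose the "image" and "label" fields used by `collate_fn`.

from functools import partial

from datasets import load_dataset
from transformers import AutoProcessor, Gemma3nForConditionalGeneration, Trainer, TrainingArguments

# Placeholder model ID, dataset, and hyperparameters for illustration only.
model_id = "google/gemma-3n-E2B-it"
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3nForConditionalGeneration.from_pretrained(model_id)

# Hypothetical dataset with "image" (PIL image) and "label" columns.
train_dataset = load_dataset("your-username/intersection-count", split="train")

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="gemma3n-image-ft",
        per_device_train_batch_size=2,
        remove_unused_columns=False,  # keep the raw "image" column so the collator can see it
    ),
    train_dataset=train_dataset,
    data_collator=partial(collate_fn, processor=processor),  # bind the processor argument
)
trainer.train()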