def collate_fn()

in scripts/ft_gemma3n_image_vt.py [0:0]


def collate_fn(examples, processor):
    """Turn raw image/label samples into a tokenized chat batch with labels.

    Each sample becomes a three-turn conversation: a fixed system prompt, a
    user turn holding the image plus a fixed question, and an assistant turn
    whose text is the sample's label.

    Args:
        examples: Iterable of mappings, each with an ``"image"`` entry
            exposing ``.convert("RGB")`` and a ``"label"`` entry that is
            stringified with ``str()``.
        processor: Object providing ``apply_chat_template(...)`` and a
            ``tokenizer`` carrying ``pad_token_id``, ``image_token_id``,
            ``boi_token_id`` and ``eoi_token_id``.

    Returns:
        The mapping returned by ``processor.apply_chat_template`` with an
        extra ``"labels"`` entry: a clone of ``"input_ids"`` in which every
        pad / image / begin-of-image / end-of-image token id is replaced
        by -100.
    """
    messages = [
        [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant with great geometry skills.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": sample["image"].convert("RGB")},
                    {
                        "type": "text",
                        "text": "How many intersection points are there in the image?",
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": str(sample["label"])}],
            },
        ]
        for sample in examples
    ]

    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        # Pad to the longest conversation so samples whose tokenized lengths
        # differ (e.g. one- vs two-digit labels) can still be stacked into a
        # single tensor; the pad positions are masked out of the loss below.
        padding=True,
    )

    # Labels start as a copy of the inputs so the inputs stay untouched.
    labels = batch["input_ids"].clone()

    # Mask the tokens that must not contribute to the loss:
    # -100 is ignored during categorical cross entropy loss computation.
    tokenizer = processor.tokenizer
    for ignored_id in (
        tokenizer.pad_token_id,
        tokenizer.image_token_id,
        tokenizer.boi_token_id,
        tokenizer.eoi_token_id,
    ):
        labels[labels == ignored_id] = -100

    batch["labels"] = labels

    return batch