def collate_fn()

in scripts/ft_gemma3n_audio_vt.py


def collate_fn(examples, processor):
    """Build a training batch of chat-formatted audio-transcription examples.

    Each example is wrapped in a system/user/assistant conversation, tokenized
    with the processor's chat template, and given a label tensor in which
    padding and audio/image special tokens are masked out of the loss.
    """
    messages = []
    for sample in examples:
        audio = sample["audio"]["array"]  # raw waveform array for this sample
        label = str(sample["text"])  # reference transcription used as the assistant turn
        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant that transcribes speech accurately.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": audio},
                    {"type": "text", "text": "Please transcribe this audio."},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": label}]},
        ]
        messages.append(message)

    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,  # training on full conversations, no generation prompt
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,  # pad to the longest sequence so variable-length samples stack into tensors
    )

    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    # Mask out tokens that should not contribute to the loss:
    # -100 is the ignore index used by the cross-entropy loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == processor.tokenizer.audio_token_id] = -100
    labels[labels == processor.tokenizer.image_token_id] = -100
    labels[labels == processor.tokenizer.boi_token_id] = -100
    labels[labels == processor.tokenizer.eoi_token_id] = -100

    batch["labels"] = labels

    return batch
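
Trainer-style APIs expect a collator that takes only the batch, so the processor argument has to be bound in advance. Below is a minimal sketch of wiring this up, assuming the processor is loaded as in the rest of the script; the checkpoint name and the synthetic waveforms are illustrative placeholders, not values taken from ft_gemma3n_audio_vt.py:

from functools import partial

import numpy as np
from transformers import AutoProcessor

# Bind the processor so the collator matches the single-argument signature
# expected by Trainer's `data_collator` or a DataLoader's `collate_fn`.
processor = AutoProcessor.from_pretrained("google/gemma-3n-E2B-it")  # assumed checkpoint
data_collator = partial(collate_fn, processor=processor)

# Smoke test with two synthetic examples shaped like the dataset rows above
# (an "audio" dict holding the waveform array and a "text" transcription).
examples = [
    {"audio": {"array": np.zeros(16000, dtype=np.float32)}, "text": "hello world"},
    {"audio": {"array": np.zeros(24000, dtype=np.float32)}, "text": "testing one two three"},
]
batch = data_collator(examples)
print(batch["input_ids"].shape, batch["labels"].shape)

When this collator is passed to a Hugging Face Trainer, remove_unused_columns=False is typically needed in TrainingArguments so the raw "audio" and "text" columns reach the collator intact.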