# scripts/ft_gemma3n_audio_vt.py
def collate_fn(examples, processor):
    """Collate audio/transcript samples into a model-ready training batch.

    For every sample a three-turn chat conversation is built (system prompt,
    user turn carrying the raw audio array plus a transcription instruction,
    assistant turn carrying the reference transcript). The batch of
    conversations is tokenized through the processor's chat template, and a
    ``labels`` tensor is derived by cloning ``input_ids`` and replacing every
    token that must not contribute to the loss with -100.

    Args:
        examples: iterable of samples, each a mapping with
            ``sample["audio"]["array"]`` (raw waveform) and ``sample["text"]``
            (reference transcript).
        processor: HF-style processor exposing ``apply_chat_template`` and a
            ``tokenizer`` with pad/audio/image/boi/eoi special-token ids.

    Returns:
        The tokenized batch dict with an added ``labels`` tensor.
    """

    def _conversation(sample):
        # One system/user/assistant exchange per sample.
        return [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant that transcribes speech accurately.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "audio", "audio": sample["audio"]["array"]},
                    {"type": "text", "text": "Please transcribe this audio."},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": str(sample["text"])}],
            },
        ]

    batch = processor.apply_chat_template(
        [_conversation(sample) for sample in examples],
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Labels start as a copy of the inputs; -100 is the ignore_index of
    # cross-entropy loss, so padding and the audio/image special tokens
    # are excluded from the loss computation.
    labels = batch["input_ids"].clone()
    tokenizer = processor.tokenizer
    for token_id in (
        tokenizer.pad_token_id,
        tokenizer.audio_token_id,
        tokenizer.image_token_id,
        tokenizer.boi_token_id,
        tokenizer.eoi_token_id,
    ):
        labels[labels == token_id] = -100
    batch["labels"] = labels
    return batch