in scripts/ft_gemma3n_audio_vt.py
import random

import torch


def run_inference(val_dataset, processor, model, fname):
    """Transcribe one random validation sample (e.g. before training) and print the
    prediction next to the reference text. `fname` is accepted but not used here."""
    # Pick a random validation sample and pull out the raw waveform.
    val_sample = random.choice(val_dataset)
    audio = val_sample["audio"]["array"]

    # Build a chat-style prompt: a system instruction plus a user turn that
    # carries the audio and the transcription request.
    message = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant that transcribes speech accurately.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio},
                {"type": "text", "text": "Please transcribe this audio."},
            ],
        },
    ]

    # Tokenize the prompt (the processor also extracts the audio features) and
    # move everything to the model's device; the dtype cast only affects float tensors.
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]

    # Generate and keep only the newly generated tokens (drop the prompt).
    with torch.no_grad():
        generation = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
        generation = generation[0][input_len:]

    decoded = processor.decode(generation, skip_special_tokens=True)
    print(f"Audio transcription: {decoded}")
    print(f"Label: {val_sample['text']}")