def run_inference()

in scripts/ft_gemma3n_audio_vt.py


# Module-level imports required by this function.
import random

import torch


def run_inference(val_dataset, processor, model, fname):
    """Transcribe one random validation sample (e.g. to sanity-check the model before training)."""
    # Pick a random sample and extract its raw audio waveform.
    val_sample = random.choice(val_dataset)
    audio = val_sample["audio"]["array"]
    # Chat-style prompt: a system instruction plus a user turn carrying the audio clip.
    message = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant that transcribes speech accurately.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio},
                {"type": "text", "text": "Please transcribe this audio."},
            ],
        },
    ]
    # Tokenize the chat template (text plus audio features) and move it to the model's device.
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    # Record the prompt length so only newly generated tokens are decoded.
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generation = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
        generation = generation[0][input_len:]

    # Decode only the generated continuation, dropping special tokens.
    decoded = processor.decode(generation, skip_special_tokens=True)

    print(f"Audio transcription: {decoded}")
    print(f"Label: {val_sample['text']}")