def run_inference()

in scripts/ft_gemma3n_image_vt.py [0:0]


def run_inference(val_dataset, processor, model, fname):
    """Run the model on one random validation sample and save the prediction plot.

    A sample is drawn with ``random.choice``, its image is wrapped in a
    system + user chat prompt asking for the number of intersection points,
    and up to 10 new tokens are generated. The image is then plotted with the
    decoded prediction as its title, written to ``outputs_fine_tune/<fname>``
    and displayed.

    Args:
        val_dataset: Sequence of samples; each sample must expose an
            ``"image"`` entry with a ``convert("RGB")`` method
            (presumably a PIL image -- confirm against the dataset loader).
        processor: Object providing ``apply_chat_template`` and ``decode``.
        model: Object providing ``device`` and ``generate``.
        fname: File name (relative to ``outputs_fine_tune/``) for the saved
            figure.
    """
    import os

    val_sample = random.choice(val_dataset)
    image = val_sample["image"].convert("RGB")
    message = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant with great geometry skills.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {
                    "type": "text",
                    "text": "How many intersection points are there in the image?",
                },
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    # Remember the prompt length so only newly generated tokens are decoded.
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
        generation = generation[0][input_len:]

    decoded = processor.decode(generation, skip_special_tokens=True)

    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Pred: {decoded}")

    # Save BEFORE showing: with interactive backends the figure may be
    # destroyed once show() returns, which would leave savefig() writing an
    # empty canvas.
    output_path = f"outputs_fine_tune/{fname}"
    # Make sure the target directory exists so savefig() does not fail on a
    # fresh checkout (also covers an fname that contains sub-directories).
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plt.savefig(output_path)
    plt.show()