in scripts/ft_gemma3n_image_vt.py [0:0]
def run_inference(val_dataset, processor, model, fname):
    """Run one sample inference and save the prediction as an annotated image.

    Picks a random sample from ``val_dataset``, asks the model how many
    intersection points are in the image via a chat-template prompt, then
    renders the image with the decoded prediction as the title and writes it
    to ``outputs_fine_tune/<fname>``.

    Args:
        val_dataset: Sequence of samples; each must have an "image" entry
            (PIL-compatible, convertible to RGB).
        processor: HF processor providing ``apply_chat_template`` / ``decode``.
        model: Causal LM with ``generate``; inputs are moved to its device
            and cast to bfloat16 (assumes the model runs in bf16 — confirm).
        fname: Output file name (saved under ``outputs_fine_tune/``).
    """
    import os

    val_sample = random.choice(val_dataset)
    image = val_sample["image"].convert("RGB")
    message = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant with great geometry skills.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {
                    "type": "text",
                    "text": "How many intersection points are there in the image?",
                },
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    # Remember the prompt length so only newly generated tokens are decoded.
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)

    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Pred: {decoded}")
    # Ensure the output directory exists before writing.
    os.makedirs("outputs_fine_tune", exist_ok=True)
    # BUG FIX: save BEFORE show() — interactive backends tear the figure
    # down when the show() window closes, so saving afterward writes a
    # blank image.
    plt.savefig(f"outputs_fine_tune/{fname}")
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close()