in scripts/ft_gemma3n_image_trl.py [0:0]
def format_intersection_data(samples: dict) -> dict[str, list]:
    """Turn a batch of image/label samples into chat-style message lists.

    Each image/label pair becomes one three-turn conversation: a fixed
    system prompt, a user turn holding the image followed by a fixed
    question, and an assistant turn holding the label as text.

    Args:
        samples: Column-oriented batch with an ``"image"`` and a ``"label"``
            entry, each a sequence. Every image must provide
            ``convert("RGB")`` (presumably PIL images -- confirm against the
            dataset loader). Labels may be of any type; they are passed
            through ``str()``.

    Returns:
        ``{"messages": [...]}`` with one conversation (a list of three
        message dicts) per image/label pair, in input order.

    Note:
        The columns are paired positionally with ``zip``, so if they differ
        in length the surplus entries of the longer column are ignored.
    """
    # Loop-invariant prompt text, named once instead of being repeated
    # inside the nested literal on every iteration.
    system_prompt = "You are an assistant with great geometry skills."
    question = "How many intersection points are there in the image?"

    conversations = []
    for raw_image, raw_label in zip(samples["image"], samples["label"]):
        image = raw_image.convert("RGB")
        label = str(raw_label)
        # Fresh dicts are built for every sample so that no mutable state
        # is shared between conversations.
        conversations.append(
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": question},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": label}],
                },
            ]
        )
    return {"messages": conversations}