def format_intersection_data()

in scripts/ft_gemma3n_image_trl.py [0:0]


def format_intersection_data(samples: dict) -> dict[str, list]:
    """Turn a column-oriented batch into chat-style conversations.

    ``samples`` is read through two parallel columns, ``samples["image"]``
    and ``samples["label"]``. For every image a three-turn conversation is
    built: a fixed system prompt, a user turn holding the image plus a fixed
    question, and an assistant turn holding the label as text.

    Args:
        samples: Mapping with an ``"image"`` column and a ``"label"`` column.
            Each image must expose ``convert("RGB")`` (presumably a PIL
            image -- confirm against the dataset). Labels are passed
            through ``str()``.

    Returns:
        A dict with the single key ``"messages"`` whose value is a list
        containing one conversation (a list of three role dicts) per image.

    Raises:
        IndexError: If the ``"label"`` column is shorter than the
            ``"image"`` column.
    """
    conversations = []
    for position, raw_image in enumerate(samples["image"]):
        rgb_image = raw_image.convert("RGB")
        # Indexed by position (not zipped) so a short label column still fails loudly.
        answer = str(samples["label"][position])

        system_turn = {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant with great geometry skills.",
                }
            ],
        }
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": rgb_image},
                {
                    "type": "text",
                    "text": "How many intersection points are there in the image?",
                },
            ],
        }
        assistant_turn = {
            "role": "assistant",
            "content": [{"type": "text", "text": answer}],
        }

        conversations.append([system_turn, user_turn, assistant_turn])

    return {"messages": conversations}