def format_data()

in scripts/ft_gemma3n_image_trl.py [0:0]


def _load_image_entries(img_paths) -> list[dict]:
    """Load image paths as RGB PIL images wrapped in chat content entries.

    Each successfully loaded image becomes ``{"type": "image", "image": <PIL.Image>}``.
    Images that cannot be read or decoded are reported on stdout and skipped,
    so the returned list may be shorter than ``img_paths``.
    """
    # A bare str/bytes path would otherwise be iterated element by element
    # (one bogus open() per character / per int), so wrap it as a single path.
    if isinstance(img_paths, (str, bytes)):
        img_paths = [img_paths]

    entries = []
    for img_path in img_paths:
        try:
            # Read the bytes inside `with` so the file is closed before decoding.
            with open(img_path, "rb") as f:
                img_bytes = f.read()
            image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except Exception as e:  # best-effort: one bad image must not abort the batch
            print(f"Error processing image {img_path}: {e}")
        else:
            entries.append({"type": "image", "image": image})
    return entries


def format_data(samples: dict) -> dict[str, list]:
    """Convert a batch of columnar samples into chat-style conversations.

    Args:
        samples: Mapping of column name to a list of per-sample values. The
            columns ``"question"``, ``"input_image_path"``, ``"context"`` and
            ``"output"`` are read, indexed by position in ``"question"``.
            Presumably a batch from ``datasets.Dataset.map(batched=True)`` —
            TODO confirm against the caller. Each ``"input_image_path"`` entry
            is a list of image paths, or a single path string.

    Returns:
        ``{"messages": [...]}`` with one conversation per sample: a system turn
        holding ``context``, a user turn holding the loaded images followed by
        ``question``, and an assistant turn holding ``output``. Images that
        fail to load are reported and omitted; the sample itself is still
        emitted (text-only if no image loaded).
    """
    conversations = []
    for i, question in enumerate(samples["question"]):
        # Images first, then the question text, within the single user turn.
        user_content = _load_image_entries(samples["input_image_path"][i])
        user_content.append({"type": "text", "text": question})

        conversations.append(
            [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": samples["context"][i]}],
                },
                {
                    "role": "user",
                    "content": user_content,
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": samples["output"][i]}],
                },
            ]
        )
    return {"messages": conversations}