{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "WiHDWm7Aiiw5"
},
"source": [
"## Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BLv-NJRZzHiA"
},
"outputs": [],
"source": [
"!pip install -U -q timm transformers trl peft"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UxE2vzKsbov0"
},
"outputs": [],
"source": [
"import io\n",
"\n",
"import torch\n",
"from datasets import load_dataset\n",
"from PIL import Image\n",
"from transformers import AutoProcessor, Gemma3nForConditionalGeneration\n",
"\n",
"from trl import SFTConfig, SFTTrainer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T06yJvcMiqO6"
},
"source": [
"## Dataset\n",
"\n",
"Load the intersection dataset, in which every image is labeled with the number of intersection points it contains, and convert each sample into the chat-message format expected by the Gemma 3n processor."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vXqn4Rs4rWVh"
},
"outputs": [],
"source": [
"dataset = load_dataset(\"ariG23498/intersection-dataset\")"
]
},
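{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, peek at one raw training example before formatting. This is only a sanity check; the formatting function below assumes each sample has an `image` column (a PIL image) and a `label` column holding the intersection count."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (not required for training): the formatting function\n",
"# below expects an `image` column (PIL) and a `label` column (intersection count).\n",
"sample = dataset[\"train\"][0]\n",
"print(sample[\"label\"])\n",
"sample[\"image\"]"
]
},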
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x_e3IjDCzioP"
},
"outputs": [],
"source": [
"def format_intersection_data(samples: dict) -> dict[str, list]:\n",
"    \"\"\"Format intersection dataset to match expected message format\"\"\"\n",
"    formatted_samples = {\"messages\": []}\n",
"    for idx in range(len(samples[\"image\"])):\n",
"        image = samples[\"image\"][idx].convert(\"RGB\")\n",
"        label = str(samples[\"label\"][idx])\n",
"\n",
"        message = [\n",
"            {\n",
"                \"role\": \"system\",\n",
"                \"content\": [\n",
"                    {\n",
"                        \"type\": \"text\",\n",
"                        \"text\": \"You are an assistant with great geometry skills.\",\n",
"                    }\n",
"                ],\n",
"            },\n",
"            {\n",
"                \"role\": \"user\",\n",
"                \"content\": [\n",
"                    {\"type\": \"image\", \"image\": image},\n",
"                    {\n",
"                        \"type\": \"text\",\n",
"                        \"text\": \"How many intersection points are there in the image?\",\n",
"                    },\n",
"                ],\n",
"            },\n",
"            {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": label}]},\n",
"        ]\n",
"        formatted_samples[\"messages\"].append(message)\n",
"    return formatted_samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UQaaLBCVzXH-"
},
"outputs": [],
"source": [
"# Convert every sample into the chat-message format defined above\n",
"dataset = dataset.map(format_intersection_data, batched=True, batch_size=4, num_proc=4)\n",
"\n",
"# Load the Gemma 3n model and its processor\n",
"model = Gemma3nForConditionalGeneration.from_pretrained(\n",
"    \"google/gemma-3n-E2B-it\", torch_dtype=torch.bfloat16,\n",
")\n",
"processor = AutoProcessor.from_pretrained(\"google/gemma-3n-E2B-it\")\n",
"# Use right padding during training\n",
"processor.tokenizer.padding_side = \"right\""
]
},
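{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, preview the text the chat template produces for one formatted example. This mirrors the call the collator below makes for every batch element, so it is a cheap way to confirm the message format before training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: inspect the chat-template output for a single formatted example.\n",
"print(\n",
"    processor.apply_chat_template(\n",
"        dataset[\"train\"][0][\"messages\"], tokenize=False, add_generation_prompt=False\n",
"    )\n",
")"
]
},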
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O1eY8CO03hcQ"
},
"outputs": [],
"source": [
"def process_vision_info(messages: list[dict]) -> list[Image.Image]:\n",
"    \"\"\"Collect the PIL images referenced by a list of chat messages.\"\"\"\n",
"    image_inputs = []\n",
"    for msg in messages:\n",
"        content = msg.get(\"content\", [])\n",
"        if not isinstance(content, list):\n",
"            content = [content]\n",
"\n",
"        for element in content:\n",
"            if isinstance(element, dict) and (\"image\" in element or element.get(\"type\") == \"image\"):\n",
"                image = element[\"image\"] if \"image\" in element else element\n",
"                if image is not None:\n",
"                    # Handle dictionaries holding raw bytes (e.g. after datasets serialization)\n",
"                    if isinstance(image, dict) and \"bytes\" in image:\n",
"                        pil_image = Image.open(io.BytesIO(image[\"bytes\"]))\n",
"                        image_inputs.append(pil_image.convert(\"RGB\"))\n",
"                    # Handle PIL Image objects\n",
"                    elif hasattr(image, \"convert\"):\n",
"                        image_inputs.append(image.convert(\"RGB\"))\n",
"    return image_inputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Sxnlep-S3KGC"
},
"outputs": [],
"source": [
"def collate_fn(examples):\n",
"    texts = []\n",
"    images_list = []\n",
"\n",
"    for example in examples:\n",
"        # Apply chat template to get text\n",
"        text = processor.apply_chat_template(\n",
"            example[\"messages\"], tokenize=False, add_generation_prompt=False\n",
"        ).strip()\n",
"        texts.append(text)\n",
"\n",
"        # Extract images\n",
"        if \"images\" in example:  # single-image case\n",
"            images = [img.convert(\"RGB\") for img in example[\"images\"]]\n",
"        else:  # multi-image case or intersection dataset\n",
"            images = process_vision_info(example[\"messages\"])\n",
"        images_list.append(images)\n",
"\n",
"    # Tokenize the texts and process the images\n",
"    batch = processor(\n",
"        text=texts, images=images_list, return_tensors=\"pt\", padding=True\n",
"    )\n",
"\n",
"    # The labels are the input_ids, and we mask the padding tokens in the loss computation\n",
"    labels = batch[\"input_ids\"].clone()\n",
"\n",
"    # Use Gemma3n-specific token masking\n",
"    labels[labels == processor.tokenizer.pad_token_id] = -100\n",
"    if hasattr(processor.tokenizer, \"image_token_id\"):\n",
"        labels[labels == processor.tokenizer.image_token_id] = -100\n",
"    if hasattr(processor.tokenizer, \"audio_token_id\"):\n",
"        labels[labels == processor.tokenizer.audio_token_id] = -100\n",
"    if hasattr(processor.tokenizer, \"boi_token_id\"):\n",
"        labels[labels == processor.tokenizer.boi_token_id] = -100\n",
"    if hasattr(processor.tokenizer, \"eoi_token_id\"):\n",
"        labels[labels == processor.tokenizer.eoi_token_id] = -100\n",
"\n",
"    batch[\"labels\"] = labels\n",
"    return batch"
]
},
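{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, run the collator on a couple of examples to check that it produces the expected tensors (input IDs, attention mask, processed image tensors, and masked labels) before launching a full training run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: collate two samples and look at the tensor shapes.\n",
"sample_batch = collate_fn([dataset[\"train\"][0], dataset[\"train\"][1]])\n",
"{k: v.shape for k, v in sample_batch.items() if hasattr(v, \"shape\")}"
]
},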
{
"cell_type": "markdown",
"metadata": {
"id": "wM6OxwNTiyZ1"
},
"source": [
"## Training\n",
"\n",
"Fine-tune with LoRA adapters and TRL's `SFTTrainer`. A per-device batch size of 1 with 8 gradient accumulation steps gives an effective batch size of 8, which (together with gradient checkpointing) helps the run fit on a T4."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uD3W2OO5-1PC"
},
"outputs": [],
"source": [
"from peft import LoraConfig\n",
"\n",
"# LoRA: train low-rank adapters on every linear layer instead of the full model\n",
"peft_config = LoraConfig(\n",
"    task_type=\"CAUSAL_LM\",\n",
"    r=16,\n",
"    target_modules=\"all-linear\",\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.05,\n",
"    bias=\"none\",\n",
"    use_rslora=False,\n",
"    use_dora=False,\n",
"    modules_to_save=None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zG53iSes76H-"
},
"outputs": [],
"source": [
"training_args = SFTConfig(\n",
"    output_dir=\"/content/gemma-3n-E2B-it-trl-sft-intersection\",\n",
"    eval_strategy=\"no\",\n",
"    per_device_train_batch_size=1,\n",
"    per_device_eval_batch_size=8,\n",
"    gradient_accumulation_steps=8,\n",
"    gradient_checkpointing=True,\n",
"    learning_rate=1e-05,\n",
"    num_train_epochs=1.0,\n",
"    logging_steps=10,\n",
"    save_steps=100,\n",
"    bf16=True,\n",
"    report_to=[\"wandb\"],\n",
"    run_name=\"gemma-3n-E2B-it-trl-sft-intersection\",\n",
"    dataset_kwargs={\"skip_prepare_dataset\": True},\n",
"    remove_unused_columns=False,\n",
"    max_seq_length=None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hPaplK2u70D9"
},
"outputs": [],
"source": [
"trainer = SFTTrainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    data_collator=collate_fn,\n",
"    train_dataset=dataset[\"train\"],\n",
"    eval_dataset=dataset[\"validation\"] if training_args.eval_strategy != \"no\" else None,\n",
"    processing_class=processor.tokenizer,\n",
"    peft_config=peft_config,\n",
")"
]
},
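{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, report how many parameters the LoRA adapters actually train. This assumes the trainer wrapped the model in a PEFT model (which happens when `peft_config` is passed) and that the wrapped model exposes `print_trainable_parameters`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: show trainable (LoRA) vs. total parameters.\n",
"# Assumes SFTTrainer wrapped `model` in a PeftModel because `peft_config` was passed.\n",
"trainer.model.print_trainable_parameters()"
]
},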
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gsBJcyqe8ET1"
},
"outputs": [],
"source": [
"trainer.train()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}