notebooks/fine_tune_gemma3n_on_t4.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "WiHDWm7Aiiw5" }, "source": [ "## Setup and Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BLv-NJRZzHiA" }, "outputs": [], "source": [ "!pip install -U -q timm transformers trl peft" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UxE2vzKsbov0" }, "outputs": [], "source": [ "import io\n", "import os\n", "import zipfile\n", "\n", "import torch\n", "from datasets import load_dataset\n", "from PIL import Image\n", "from transformers import AutoProcessor, Gemma3nForConditionalGeneration\n", "\n", "from trl import (\n", " SFTConfig,\n", " SFTTrainer,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "T06yJvcMiqO6" }, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vXqn4Rs4rWVh" }, "outputs": [], "source": [ "dataset = load_dataset(\"ariG23498/intersection-dataset\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x_e3IjDCzioP" }, "outputs": [], "source": [ "def format_intersection_data(samples: dict) -> dict[str, list]:\n", " \"\"\"Format intersection dataset to match expected message format\"\"\"\n", " formatted_samples = {\"messages\": []}\n", " for idx in range(len(samples[\"image\"])):\n", " image = samples[\"image\"][idx].convert(\"RGB\")\n", " label = str(samples[\"label\"][idx])\n", "\n", " message = [\n", " {\n", " \"role\": \"system\",\n", " \"content\": [\n", " {\n", " \"type\": \"text\",\n", " \"text\": \"You are an assistant with great geometry skills.\",\n", " }\n", " ],\n", " },\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\"type\": \"image\", \"image\": image},\n", " {\n", " \"type\": \"text\",\n", " \"text\": \"How many intersection points are there in the image?\",\n", " },\n", " ],\n", " },\n", " {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": label}]},\n", " ]\n", " formatted_samples[\"messages\"].append(message)\n", " return formatted_samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UQaaLBCVzXH-" }, "outputs": [], "source": [ "dataset = dataset.map(format_intersection_data, batched=True, batch_size=4, num_proc=4)\n", "\n", "model = Gemma3nForConditionalGeneration.from_pretrained(\n", " \"google/gemma-3n-E2B-it\", torch_dtype=torch.bfloat16,\n", ")\n", "processor = AutoProcessor.from_pretrained(\n", " \"google/gemma-3n-E2B-it\",\n", ")\n", "processor.tokenizer.padding_side = \"right\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O1eY8CO03hcQ" }, "outputs": [], "source": [ "def process_vision_info(messages: list[dict]) -> list[Image.Image]:\n", " image_inputs = []\n", " for msg in messages:\n", " content = msg.get(\"content\", [])\n", " if not isinstance(content, list):\n", " content = [content]\n", "\n", " for element in content:\n", " if isinstance(element, dict) and (\"image\" in element or element.get(\"type\") == \"image\"):\n", " if \"image\" in element:\n", " image = element[\"image\"]\n", " else:\n", " image = element\n", " if image is not None:\n", " # Handle dictionary with bytes\n", " if isinstance(image, dict) and \"bytes\" in image:\n", " pil_image = Image.open(io.BytesIO(image[\"bytes\"]))\n", " image_inputs.append(pil_image.convert(\"RGB\"))\n", " # Handle PIL Image objects\n", " elif hasattr(image, \"convert\"):\n", " image_inputs.append(image.convert(\"RGB\"))\n", " return image_inputs" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { 
"id": "Sxnlep-S3KGC" }, "outputs": [], "source": [ "def collate_fn(examples):\n", " texts = []\n", " images_list = []\n", "\n", " for example in examples:\n", " # Apply chat template to get text\n", " text = processor.apply_chat_template(\n", " example[\"messages\"], tokenize=False, add_generation_prompt=False\n", " ).strip()\n", " texts.append(text)\n", "\n", " # Extract images\n", " if \"images\" in example: # single-image case\n", " images = [img.convert(\"RGB\") for img in example[\"images\"]]\n", " else: # multi-image case or intersection dataset\n", " images = process_vision_info(example[\"messages\"])\n", " images_list.append(images)\n", "\n", " # Tokenize the texts and process the images\n", " batch = processor(\n", " text=texts, images=images_list, return_tensors=\"pt\", padding=True\n", " )\n", "\n", " # The labels are the input_ids, and we mask the padding tokens in the loss computation\n", " labels = batch[\"input_ids\"].clone()\n", "\n", " # Use Gemma3n specific token masking\n", " labels[labels == processor.tokenizer.pad_token_id] = -100\n", " if hasattr(processor.tokenizer, 'image_token_id'):\n", " labels[labels == processor.tokenizer.image_token_id] = -100\n", " if hasattr(processor.tokenizer, \"audio_token_id\"):\n", " labels[labels == processor.tokenizer.audio_token_id] = -100\n", " if hasattr(processor.tokenizer, 'boi_token_id'):\n", " labels[labels == processor.tokenizer.boi_token_id] = -100\n", " if hasattr(processor.tokenizer, 'eoi_token_id'):\n", " labels[labels == processor.tokenizer.eoi_token_id] = -100\n", "\n", "\n", " batch[\"labels\"] = labels\n", " return batch" ] }, { "cell_type": "markdown", "metadata": { "id": "wM6OxwNTiyZ1" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uD3W2OO5-1PC" }, "outputs": [], "source": [ "from peft import LoraConfig\n", "peft_config = LoraConfig(\n", " task_type=\"CAUSAL_LM\",\n", " r=16,\n", " target_modules=\"all-linear\",\n", " lora_alpha=32,\n", " lora_dropout=0.05,\n", " bias=\"none\",\n", " use_rslora=False,\n", " use_dora=False,\n", " modules_to_save=None,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zG53iSes76H-" }, "outputs": [], "source": [ "training_args = SFTConfig(\n", " output_dir=\"/content/gemma-3n-E2B-it-trl-sft-intersection\",\n", " eval_strategy='no',\n", " per_device_train_batch_size=1,\n", " per_device_eval_batch_size=8,\n", " gradient_accumulation_steps=8,\n", " gradient_checkpointing=True,\n", " learning_rate=1e-05,\n", " num_train_epochs=1.0,\n", " logging_steps=10,\n", " save_steps=100,\n", " bf16=True,\n", " report_to=[\"wandb\"],\n", " run_name='gemma-3n-E2B-it-trl-sft-intersection',\n", " dataset_kwargs={'skip_prepare_dataset': True},\n", " remove_unused_columns=False,\n", " max_seq_length=None,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hPaplK2u70D9" }, "outputs": [], "source": [ "trainer = SFTTrainer(\n", " model=model,\n", " args=training_args,\n", " data_collator=collate_fn,\n", " train_dataset=dataset[\"train\"],\n", " eval_dataset=dataset[\"validation\"] if training_args.eval_strategy != \"no\" else None,\n", " processing_class=processor.tokenizer,\n", " peft_config=peft_config,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gsBJcyqe8ET1" }, "outputs": [], "source": [ "trainer.train()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": 
"python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }