notebooks/fine_tune_gemma3n_on_audio.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "GH-ZxsoMoYFj" }, "source": [ "# Setup and Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "o-YgQUwQoYFq" }, "outputs": [], "source": [ "!pip install -U -q timm transformers trl peft" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2SPZakR-oYFs" }, "outputs": [], "source": [ "import io\n", "import os\n", "import zipfile\n", "\n", "import torch\n", "from datasets import DatasetDict, load_dataset, Audio\n", "from huggingface_hub import hf_hub_download, list_repo_files\n", "from PIL import Image\n", "from transformers import AutoModelForImageTextToText, AutoProcessor, Gemma3nForConditionalGeneration\n", "from peft import LoraConfig\n", "\n", "from trl import (\n", " ModelConfig,\n", " ScriptArguments,\n", " SFTConfig,\n", " SFTTrainer,\n", " TrlParser,\n", " get_kbit_device_map,\n", " get_quantization_config,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "0oiVkjlNoYFu" }, "source": [ "# Dataset\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G41ldwfUoYFv" }, "outputs": [], "source": [ "dataset = load_dataset(\"AdrienB134/Emilia-dataset-french-split\", split=\"fr\")\n", "dataset = dataset.select(range(1000))\n", "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16_000))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Hj3nSwd8oYFw" }, "outputs": [], "source": [ "def format_intersection_data(samples: dict) -> dict[str, list]:\n", " \"\"\"Format intersection dataset to match expected message format\"\"\"\n", " formatted_samples = {\"messages\": []}\n", " for idx in range(len(samples[\"audio\"])):\n", " audio = samples[\"audio\"][idx][\"array\"]\n", " label = str(samples[\"text\"][idx])\n", "\n", " message = [\n", " {\n", " \"role\": \"system\",\n", " \"content\": [\n", " {\n", " \"type\": \"text\",\n", " \"text\": \"You are an assistant that transcribes speech accurately.\",\n", " }\n", " ],\n", " },\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\"type\": \"audio\", \"audio\": audio},\n", " {\"type\": \"text\", \"text\": \"Please transcribe this audio.\"}\n", " ]\n", " },\n", " {\n", " \"role\": \"assistant\",\n", " \"content\":[{\"type\": \"text\", \"text\": label}]\n", " }\n", " ]\n", " formatted_samples[\"messages\"].append(message)\n", " return formatted_samples" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KBWjj21VoYFy" }, "outputs": [], "source": [ "dataset = dataset.map(format_intersection_data, batched=True, batch_size=4, num_proc=4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eJbEdl8xoYFz" }, "outputs": [], "source": [ "model = Gemma3nForConditionalGeneration.from_pretrained(\n", " \"google/gemma-3n-E2B-it\", trust_remote_code=True, torch_dtype=torch.bfloat16,\n", ")\n", "processor = AutoProcessor.from_pretrained(\n", " \"google/gemma-3n-E2B-it\", trust_remote_code=True,\n", ")\n", "processor.tokenizer.padding_side = \"right\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gQK1C9KeoYFz" }, "outputs": [], "source": [ "def collate_fn(examples):\n", " texts = []\n", " audios = []\n", "\n", " for example in examples:\n", " # Apply chat template to get text\n", " text = processor.apply_chat_template(\n", " example[\"messages\"], tokenize=False, add_generation_prompt=False\n", " ).strip()\n", " texts.append(text)\n", "\n", " # Extract audios\n", " 
audios.append(example[\"audio\"][\"array\"])\n", "\n", " # Tokenize the texts and process the images\n", " batch = processor(\n", " text=texts, audio=audios, return_tensors=\"pt\", padding=True\n", " )\n", "\n", " # The labels are the input_ids, and we mask the padding tokens in the loss computation\n", " labels = batch[\"input_ids\"].clone()\n", "\n", " # Use Gemma3n specific token masking\n", " labels[labels == processor.tokenizer.pad_token_id] = -100\n", " if hasattr(processor.tokenizer, 'image_token_id'):\n", " labels[labels == processor.tokenizer.image_token_id] = -100\n", " if hasattr(processor.tokenizer, 'audio_token_id'):\n", " labels[labels == processor.tokenizer.audio_token_id] = -100\n", " if hasattr(processor.tokenizer, 'boi_token_id'):\n", " labels[labels == processor.tokenizer.boi_token_id] = -100\n", " if hasattr(processor.tokenizer, 'eoi_token_id'):\n", " labels[labels == processor.tokenizer.eoi_token_id] = -100\n", "\n", "\n", " batch[\"labels\"] = labels\n", " return batch" ] }, { "cell_type": "markdown", "metadata": { "id": "gegZk0p9oYF1" }, "source": [ "# Training\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jXpCDfvboYF1" }, "outputs": [], "source": [ "peft_config = LoraConfig(\n", " task_type=\"CAUSAL_LM\",\n", " r=8,\n", " target_modules=[\"q_proj\", \"v_proj\"],\n", " lora_alpha=16,\n", " lora_dropout=0.00,\n", " bias=\"none\",\n", " use_rslora=False,\n", " use_dora=False,\n", " modules_to_save=None,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kpeGsYS9oYF2" }, "outputs": [], "source": [ "training_args = SFTConfig(\n", " output_dir=\"gemma-3n-E2B-it-trl-sft\",\n", " eval_strategy='no',\n", " per_device_train_batch_size=1,\n", " per_device_eval_batch_size=8,\n", " gradient_accumulation_steps=8,\n", " gradient_checkpointing=True,\n", " learning_rate=1e-05,\n", " num_train_epochs=1.0,\n", " logging_steps=10,\n", " save_steps=100,\n", " bf16=True,\n", " report_to=[\"wandb\"],\n", " run_name='gemma-3n-E2B-it-trl-sft',\n", " dataset_kwargs={'skip_prepare_dataset': True},\n", " remove_unused_columns=False,\n", " max_seq_length=None,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DrIxl84boYF3" }, "outputs": [], "source": [ "split_ds = dataset.train_test_split(test_size=0.1, seed=42)\n", "train_dataset = split_ds[\"train\"]\n", "val_dataset = split_ds[\"test\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3bUlSwyIoYF3" }, "outputs": [], "source": [ "trainer = SFTTrainer(\n", " model=model,\n", " args=training_args,\n", " data_collator=collate_fn,\n", " train_dataset=train_dataset,\n", " eval_dataset=val_dataset if training_args.eval_strategy != \"no\" else None,\n", " processing_class=processor.tokenizer,\n", " peft_config=peft_config,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dU0r26E8oYF3" }, "outputs": [], "source": [ "trainer.train()" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": ".venv_gemma", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 0 }