{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "GH-ZxsoMoYFj"
},
"source": [
"# Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o-YgQUwQoYFq"
},
"outputs": [],
"source": [
"!pip install -U -q timm transformers trl peft"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2SPZakR-oYFs"
},
"outputs": [],
"source": [
"import io\n",
"import os\n",
"import zipfile\n",
"\n",
"import torch\n",
"from datasets import DatasetDict, load_dataset, Audio\n",
"from huggingface_hub import hf_hub_download, list_repo_files\n",
"from PIL import Image\n",
"from transformers import AutoModelForImageTextToText, AutoProcessor, Gemma3nForConditionalGeneration\n",
"from peft import LoraConfig\n",
"\n",
"from trl import (\n",
" ModelConfig,\n",
" ScriptArguments,\n",
" SFTConfig,\n",
" SFTTrainer,\n",
" TrlParser,\n",
" get_kbit_device_map,\n",
" get_quantization_config,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0oiVkjlNoYFu"
},
"source": [
"# Dataset\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G41ldwfUoYFv"
},
"outputs": [],
"source": [
"dataset = load_dataset(\"AdrienB134/Emilia-dataset-french-split\", split=\"fr\")\n",
"dataset = dataset.select(range(1000))\n",
"dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16_000))"
]
},
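{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on one example. The rest of the notebook relies on two columns: `audio` (a dict with an `array` and a `sampling_rate`) and `text` (the reference transcript)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sample = dataset[0]\n",
"print(sample[\"text\"])\n",
"print(sample[\"audio\"][\"array\"].shape, sample[\"audio\"][\"sampling_rate\"])"
]
},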
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Hj3nSwd8oYFw"
},
"outputs": [],
"source": [
"def format_intersection_data(samples: dict) -> dict[str, list]:\n",
" \"\"\"Format intersection dataset to match expected message format\"\"\"\n",
" formatted_samples = {\"messages\": []}\n",
" for idx in range(len(samples[\"audio\"])):\n",
" audio = samples[\"audio\"][idx][\"array\"]\n",
" label = str(samples[\"text\"][idx])\n",
"\n",
" message = [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": [\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": \"You are an assistant that transcribes speech accurately.\",\n",
" }\n",
" ],\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": [\n",
" {\"type\": \"audio\", \"audio\": audio},\n",
" {\"type\": \"text\", \"text\": \"Please transcribe this audio.\"}\n",
" ]\n",
" },\n",
" {\n",
" \"role\": \"assistant\",\n",
" \"content\":[{\"type\": \"text\", \"text\": label}]\n",
" }\n",
" ]\n",
" formatted_samples[\"messages\"].append(message)\n",
" return formatted_samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KBWjj21VoYFy"
},
"outputs": [],
"source": [
"dataset = dataset.map(format_intersection_data, batched=True, batch_size=4, num_proc=4)"
]
},
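{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm the mapping worked, print the first formatted conversation with the audio elided (a small sketch: the exact keys on each content entry can shift after Arrow serialization, so anything without text falls back to an `<audio>` placeholder)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for turn in dataset[0][\"messages\"]:\n",
"    parts = [item[\"text\"] if item.get(\"text\") else \"<audio>\" for item in turn[\"content\"]]\n",
"    print(f\"{turn['role']}: {parts}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model\n"
]
},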
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eJbEdl8xoYFz"
},
"outputs": [],
"source": [
"model = Gemma3nForConditionalGeneration.from_pretrained(\n",
" \"google/gemma-3n-E2B-it\", trust_remote_code=True, torch_dtype=torch.bfloat16,\n",
")\n",
"processor = AutoProcessor.from_pretrained(\n",
" \"google/gemma-3n-E2B-it\", trust_remote_code=True,\n",
")\n",
"processor.tokenizer.padding_side = \"right\""
]
},
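{
"cell_type": "markdown",
"metadata": {},
"source": [
"An optional footprint check before training; `num_parameters()` and `get_memory_footprint()` are standard `PreTrainedModel` helpers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"parameters: {model.num_parameters() / 1e9:.2f}B\")\n",
"print(f\"memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB\")"
]
},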
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gQK1C9KeoYFz"
},
"outputs": [],
"source": [
"def collate_fn(examples):\n",
" texts = []\n",
" audios = []\n",
"\n",
" for example in examples:\n",
" # Apply chat template to get text\n",
" text = processor.apply_chat_template(\n",
" example[\"messages\"], tokenize=False, add_generation_prompt=False\n",
" ).strip()\n",
" texts.append(text)\n",
"\n",
" # Extract audios\n",
" audios.append(example[\"audio\"][\"array\"])\n",
"\n",
" # Tokenize the texts and process the images\n",
" batch = processor(\n",
" text=texts, audio=audios, return_tensors=\"pt\", padding=True\n",
" )\n",
"\n",
" # The labels are the input_ids, and we mask the padding tokens in the loss computation\n",
" labels = batch[\"input_ids\"].clone()\n",
"\n",
" # Use Gemma3n specific token masking\n",
" labels[labels == processor.tokenizer.pad_token_id] = -100\n",
" if hasattr(processor.tokenizer, 'image_token_id'):\n",
" labels[labels == processor.tokenizer.image_token_id] = -100\n",
" if hasattr(processor.tokenizer, 'audio_token_id'):\n",
" labels[labels == processor.tokenizer.audio_token_id] = -100\n",
" if hasattr(processor.tokenizer, 'boi_token_id'):\n",
" labels[labels == processor.tokenizer.boi_token_id] = -100\n",
" if hasattr(processor.tokenizer, 'eoi_token_id'):\n",
" labels[labels == processor.tokenizer.eoi_token_id] = -100\n",
"\n",
"\n",
" batch[\"labels\"] = labels\n",
" return batch"
]
},
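{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring the collator into the trainer, run it once on two samples: the tensor shapes should line up, and a large fraction of label positions should be `-100`, since padding and the audio placeholder tokens are excluded from the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch = collate_fn([dataset[0], dataset[1]])\n",
"for key, value in batch.items():\n",
"    print(key, tuple(value.shape), value.dtype)\n",
"masked = (batch[\"labels\"] == -100).float().mean().item()\n",
"print(f\"fraction of masked label positions: {masked:.2%}\")"
]
},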
{
"cell_type": "markdown",
"metadata": {
"id": "gegZk0p9oYF1"
},
"source": [
"# Training\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jXpCDfvboYF1"
},
"outputs": [],
"source": [
"peft_config = LoraConfig(\n",
" task_type=\"CAUSAL_LM\",\n",
" r=8,\n",
" target_modules=[\"q_proj\", \"v_proj\"],\n",
" lora_alpha=16,\n",
" lora_dropout=0.00,\n",
" bias=\"none\",\n",
" use_rslora=False,\n",
" use_dora=False,\n",
" modules_to_save=None,\n",
")"
]
},
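{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to confirm that `target_modules` actually matches something in this architecture (a sketch; `q_proj`/`v_proj` is the usual transformers naming convention, and listing the matches makes that explicit):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"targeted = [name for name, _ in model.named_modules() if name.endswith((\"q_proj\", \"v_proj\"))]\n",
"print(f\"{len(targeted)} modules match target_modules, e.g.:\")\n",
"for name in targeted[:5]:\n",
"    print(\" \", name)"
]
},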
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kpeGsYS9oYF2"
},
"outputs": [],
"source": [
"training_args = SFTConfig(\n",
" output_dir=\"gemma-3n-E2B-it-trl-sft\",\n",
" eval_strategy='no',\n",
" per_device_train_batch_size=1,\n",
" per_device_eval_batch_size=8,\n",
" gradient_accumulation_steps=8,\n",
" gradient_checkpointing=True,\n",
" learning_rate=1e-05,\n",
" num_train_epochs=1.0,\n",
" logging_steps=10,\n",
" save_steps=100,\n",
" bf16=True,\n",
" report_to=[\"wandb\"],\n",
" run_name='gemma-3n-E2B-it-trl-sft',\n",
" dataset_kwargs={'skip_prepare_dataset': True},\n",
" remove_unused_columns=False,\n",
" max_seq_length=None,\n",
")"
]
},
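{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `per_device_train_batch_size=1` and `gradient_accumulation_steps=8`, the effective batch size per device is 8. `gradient_checkpointing=True` trades extra recomputation in the backward pass for lower activation memory, which keeps the batch-of-1 setup affordable."
]
},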
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DrIxl84boYF3"
},
"outputs": [],
"source": [
"split_ds = dataset.train_test_split(test_size=0.1, seed=42)\n",
"train_dataset = split_ds[\"train\"]\n",
"val_dataset = split_ds[\"test\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3bUlSwyIoYF3"
},
"outputs": [],
"source": [
"trainer = SFTTrainer(\n",
" model=model,\n",
" args=training_args,\n",
" data_collator=collate_fn,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=val_dataset if training_args.eval_strategy != \"no\" else None,\n",
" processing_class=processor.tokenizer,\n",
" peft_config=peft_config,\n",
")"
]
},
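{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because a `peft_config` was passed, `SFTTrainer` wraps the model in a `PeftModel`, which exposes `print_trainable_parameters()`; the guard keeps this cell harmless if that wrapping ever changes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if hasattr(trainer.model, \"print_trainable_parameters\"):\n",
"    trainer.model.print_trainable_parameters()"
]
},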
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dU0r26E8oYF3"
},
"outputs": [],
"source": [
"trainer.train()"
]
},
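{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inference\n",
"\n",
"A minimal post-training spot check (a sketch, not a full evaluation): reuse the first two turns of a validation sample's stored conversation, build the inputs the same way the collator does, and compare the generated transcript with the reference. The `.to(..., dtype=torch.bfloat16)` call on the returned `BatchFeature` casts only floating-point tensors, so the integer `input_ids` are left untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sample = val_dataset[0]\n",
"messages = sample[\"messages\"][:2]  # system + user turns, without the reference answer\n",
"\n",
"text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"inputs = processor(\n",
"    text=[text], audio=[sample[\"audio\"][\"array\"]], return_tensors=\"pt\", padding=True\n",
").to(trainer.model.device, dtype=torch.bfloat16)\n",
"\n",
"with torch.inference_mode():\n",
"    generated = trainer.model.generate(**inputs, max_new_tokens=128, do_sample=False)\n",
"\n",
"prompt_len = inputs[\"input_ids\"].shape[-1]\n",
"print(\"Prediction:\", processor.decode(generated[0][prompt_len:], skip_special_tokens=True))\n",
"print(\"Reference: \", sample[\"text\"])"
]
}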
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".venv_gemma",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}