train_rick/01_sft.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🧠 Mastering LLM Fine-Tuning with TRL – Supervised Fine-Tuning\n",
"\n",
"Welcome! This notebook is part of a tutorial series where you'll learn how to fine-tune Large Language Models (LLMs) using 🤗 TRL.\n",
"We introduce key concepts, set up the required tools, and use techniques like Supervised Fine-Tuning (SFT) and Group-Relative Policy Optimization (GRPO).\n",
"\n",
"## 📋 Prerequisites\n",
"\n",
"Before you begin, make sure you have the following:\n",
"\n",
"* A working knowledge of Python and PyTorch\n",
"* A basic understanding of machine learning and deep learning concepts\n",
"* Access to a GPU accelerator – this notebook is designed to run with **at least 16GB of GPU memory**, such as what is available for free on [Google Colab](https://colab.research.google.com). Runtime Tab -> Change runtime type -> T4 (GPU).\n",
"* The `trl` library installed – this tutorial has been tested with **TRL version 0.17**\n",
" If you don’t have `trl` installed yet, you can install it by running the following code block:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install trl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* A [Hugging Face account](https://huggingface.co) with a configured access token. If needed, run the following code.\n",
"This will prompt you to enter your Hugging Face access token. You can generate one from your Hugging Face account settings under [Access Tokens](https://huggingface.co/settings/tokens). The token must have `Write access to contents/settings of all repos under your personal namespace`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔄 Quick Recap of the Last Session\n",
"\n",
"In the previous session, we explored the foundational concepts behind training and fine-tuning Large Language Models (LLMs). Here's a brief summary:\n",
"\n",
"* LLMs operate on sequences of integers known as *tokens*. Given a sequence, they predict the probability distribution of the next token in the sequence.\n",
"* The first phase of training an LLM is called **pretraining**. This involves training the model on a massive corpus of unlabeled text data.\n",
"* The output of pretraining is a **base model**, which has learned general language patterns but isn’t specialized for specific tasks.\n",
"* To adapt the base model for a particular use case—like building a chatbot—we need to **fine-tune** (or post-train) it on a dataset of conversations.\n",
"* Many high-quality conversational datasets are available publicly on the [Hugging Face Hub](https://huggingface.co/datasets).\n",
"* These datasets are often not in a format that's directly usable for training, so **data preprocessing** is usually required.\n",
"\n",
"Now that we're on the same page, let's dive into the next session! First Let's load our base model and tokenizer. \n",
"\n",
"\n",
"We'll be using a different model from the first session: `SmolLM2-360M`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"HuggingFaceTB/SmolLM2-360M\")\n",
"model = AutoModelForCausalLM.from_pretrained(\"HuggingFaceTB/SmolLM2-360M\", device_map=\"auto\")"
]
},
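{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (and because model size will matter a lot when we talk about GPU memory later), let's see how many parameters this model has. The snippet below uses the standard 🤗 Transformers `num_parameters()` method and gives a rough float32 footprint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough size check: parameter count and approximate float32 memory footprint\n",
"num_params = model.num_parameters()\n",
"print(f\"{num_params / 1e6:.0f}M parameters, ~{num_params * 4 / 1e9:.1f} GB in float32\")"
]
},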
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, let's define a chat template and make the necessary modifications to the template and tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer.chat_template = \"\"\"{{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n",
"{%- for message in messages %}\n",
" {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n",
"{%- endfor %}\n",
"{%- if add_generation_prompt %}\n",
" {{- '<|im_start|>assistant\\n' }}\n",
"{%- endif %}\"\"\"\n",
"tokenizer.eos_token = \"<|im_end|>\"\n",
"model.config.eos_token_id = tokenizer.eos_token_id\n",
"model.generation_config.eos_token_id = tokenizer.eos_token_id"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember, configuring a chat template doesn't make the model capable of chatting.\n",
"It only gives the ability to format inputs in a dialogue structure; the model still needs to be fine-tuned on conversational data to respond like a chatbot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"pipeline = pipeline(task=\"text-generation\", model=model, tokenizer=tokenizer)\n",
"\n",
"question = \"What does it mean for a matrix to be invertible?\"\n",
"prompt = [{\"role\": \"user\", \"content\": question}]\n",
"pipeline(prompt, max_new_tokens=50)[0][\"generated_text\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the exercise more engaging, we'll use a custom dataset I've created especially for you. Instead of the usual back-and-forth between a lazy user and a helpful assistant, I've spiced things up with dialogues between Rick and Morty. If this reference doesn't ring a bell - sorry for you!\n",
"\n",
"So what we're going to do is basically train a Rick to respond to Morty.\n",
"\n",
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"dataset = load_dataset(\"qgallouedec/rick-science\", split=\"train\")\n",
"dataset[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly, you can see that the dataset needs to be pre-formatted. So, as before, we write and apply the function to format the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def to_conversation(example):\n",
" return {\n",
" \"messages\": [\n",
" {\"role\": \"user\", \"content\": example[\"question\"]},\n",
" {\"role\": \"assistant\", \"content\": f\"<think>{example['reasoning']}</think> {example['answer']}\"},\n",
" ]\n",
" }\n",
"\n",
"dataset = dataset.map(to_conversation, remove_columns=dataset.column_names)\n",
"dataset[0]"
]
},
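{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before moving on, it's worth checking what the model will actually see during training. The quick check below renders the first formatted example with the chat template we defined earlier (this is just for inspection; the trainer applies the template for us)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Render the first formatted example as plain text to verify the chat template\n",
"print(tokenizer.apply_chat_template(dataset[0][\"messages\"], tokenize=False))"
]
},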
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point, we have a model ready for training and a dataset ready to go. This is where TRL comes in—it provides the `SFTTrainer`, a trainer designed to fine-tune the model using the dataset.\n",
"\n",
"**But what is SFT?**\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🕵️ Supervised Fine-Tuning (SFT)\n",
"\n",
"SFT stands for **Supervised Fine-Tuning**. There’s nothing particularly revolutionary about the SFT method. It’s simply training the model to predict the next token in a supervised way. Just like in pretraining, we minimize the cross-entropy loss between the model’s predicted distribution and the actual next token. The key difference is that, in SFT, the model is trained on **curated**, labeled data — often conversational or instruction-following examples."
]
},
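{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make that concrete, here is a minimal sketch of the objective SFT minimizes, assuming the `model` and `tokenizer` loaded above: run a forward pass on a tokenized conversation, then compare the logits at each position against the following token with cross-entropy. This is only an illustration of the loss; in practice the `SFTTrainer` takes care of batching, padding, and the training loop for you."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Illustration only: next-token cross-entropy on a single toy conversation\n",
"messages = [\n",
"    {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n",
"    {\"role\": \"assistant\", \"content\": \"4\"},\n",
"]\n",
"input_ids = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(model.device)\n",
"\n",
"with torch.no_grad():\n",
"    logits = model(input_ids).logits  # shape: (1, seq_len, vocab_size)\n",
"\n",
"# Position t predicts token t + 1, so shift logits and targets by one\n",
"loss = F.cross_entropy(logits[:, :-1].flatten(0, 1), input_ids[:, 1:].flatten())\n",
"print(f\"Next-token cross-entropy: {loss.item():.3f}\")"
]
},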
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I'm going to break from the usual approach and show you the training process *before* diving into the explanation. Why?\n",
"\n",
"1. You'll get an immediate feel for what the training actually does.\n",
"2. The training can run in the background while I walk you through what's happening.\n",
"\n",
"Without explanation, here's the training code block."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from trl import SFTTrainer, SFTConfig\n",
"\n",
"args = SFTConfig(\n",
" output_dir=\"data/SmolLM2-360M-Rickified\",\n",
" gradient_checkpointing=True,\n",
" bf16=True,\n",
" per_device_train_batch_size=4,\n",
" gradient_accumulation_steps=4,\n",
" max_length=650,\n",
" logging_steps=10,\n",
" run_name=\"SmolLM2-360M-Rickified\",\n",
")\n",
"\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" args=args,\n",
" processing_class=tokenizer,\n",
" train_dataset=dataset,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we've trained our model, it's time to share it with the world! To do so, simply push it to the Hugging Face Hub using `push_to_hub`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.push_to_hub(dataset_name=\"qgallouedec/rick-science\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's very important to remember that what training LLM, VRAM is the backbone of the battle. This is especially important in our case, as we’re working with limited compute resources. I designed this notebook to be runnable on the free version of Colab, which only provide a GPUs with 16GB of memory. As a result, we must manage GPU memory carefully to avoid running out. This constraint also presents a valuable opportunity to learn about GPU memory usage and optimization—skills that are essential when training large language models (LLMs).\n",
"\n",
"### 🥛 What consumes GPU memory?\n",
"\n",
"When you profile GPU memory usage during training, you get a chart like this one:\n",
"\n",
"\n",
"\n",
"As shown in the graph, several components occupy GPU memory:\n",
"\n",
"* **Model**: The neural network itself, loaded into memory.\n",
"* **Optimizer**: This usually takes up twice as much memory as the model.\n",
"* **Activations**: These are the intermediate outputs from each layer during the forward pass.\n",
"* **Gradients**: The derivatives of the loss with respect to the activations, used during backpropagation.\n",
"* **Optimizer states**: Temporary variables needed by the optimizer during training.\n",
"\n",
"### 🤏 Use a smaller model\n",
"\n",
"You might have noticed that base models are often released in multiple sizes. Larger models tend to perform better, but they also consume significantly more memory. Choosing a smaller model is the most impactful optimization you can make—it reduces not just the model size, but also the memory used by the optimizer and activations.\n",
"\n",
"For this notebook, we’ll be using `SmolLM2-360M`.\n",
"\n",
"### 📐 Handle the sequence length carefully\n",
"\n",
"The memory required for activations is directly proportional to the number of tokens in a batch. The number of tokens in a batch, in turn, depends on two factors: the *batch size* and the *sequence length*. Doubling the sequence length will double the memory needed for activations, and similarly, doubling the batch size will also double the memory required for activations.\n",
"\n",
"\n",
"\n",
"And remember, whether you hit an Out Of Memory (OOM) error only depends on the longest sequence in the entire dataset. If, at some point during training, you encounter a sequence longer than what the GPU can handle, you'll get an OOM error, and the training will need to be restarted from scratch. So, it's important to control this maximum sequence length, and truncate any sequences that exceed it.\n",
"\n",
"To get a better sense of what could be a reasonable value for the maximum sequence length, you can check the distribution of sequence lengths in your dataset. You can do this by running the following code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"input_ids = [tokenizer.apply_chat_template(example[\"messages\"]) for example in dataset]\n",
"lens = [len(ids) for ids in input_ids]\n",
"\n",
"plt.hist(lens, bins=50)\n",
"plt.xlabel(\"Length of input_ids\")\n",
"plt.ylabel(\"Number of examples\")\n",
"plt.title(\"Distribution of input_ids length\")\n",
"plt.show()"
]
},
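{
"cell_type": "markdown",
"metadata": {},
"source": [
"From this distribution you can sanity-check the `max_length=650` we passed to `SFTConfig` earlier. The quick check below reuses the `lens` list from the previous cell to count how many examples would be truncated at that length."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# How many examples exceed the max_length used in the SFTConfig above?\n",
"max_length = 650\n",
"n_too_long = sum(length > max_length for length in lens)\n",
"print(f\"{n_too_long} / {len(lens)} examples ({100 * n_too_long / len(lens):.1f}%) are longer than {max_length} tokens\")"
]
},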
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Handle the batch size carefully\n",
"\n",
"As explained earlier, batch size directly impacts activation memory. But there's a tradeoff: larger batch sizes generally lead to more stable training.\n",
"\n",
"To balance memory constraints with training quality, you can use gradient accumulation—a technique that simulates a larger batch size by accumulating gradients over multiple smaller batches. For example, all of the following configurations result in the same effective batch size (16):\n",
"\n",
"```python\n",
"from trl import SFTConfig\n",
"\n",
"training_args = SFTConfig(per_device_train_batch_size=16, gradient_accumulation_steps=1, ...) # fast but memory intensive\n",
"training_args = SFTConfig(per_device_train_batch_size=8, gradient_accumulation_steps=2, ...)\n",
"training_args = SFTConfig(per_device_train_batch_size=4, gradient_accumulation_steps=4, ...)\n",
"training_args = SFTConfig(per_device_train_batch_size=2, gradient_accumulation_steps=8, ...)\n",
"training_args = SFTConfig(per_device_train_batch_size=1, gradient_accumulation_steps=16, ...) # slow but memory efficient\n",
"```\n",
"\n",
"Just note: more accumulation steps mean more forward passes per update, which increases training time. When possible, prefer fewer steps with a larger batch size."
]
},
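{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're curious about the mechanics, here is a minimal, self-contained sketch of what gradient accumulation boils down to, using a toy linear model rather than TRL's actual implementation: gradients from several small backward passes add up in `.grad` before a single optimizer step.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Toy sketch: 4 micro-batches of 4 samples ≈ one batch of 16 (not TRL's actual code)\n",
"toy_model = torch.nn.Linear(8, 1)\n",
"optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)\n",
"accumulation_steps = 4\n",
"\n",
"optimizer.zero_grad()\n",
"for _ in range(accumulation_steps):\n",
"    x, y = torch.randn(4, 8), torch.randn(4, 1)  # one micro-batch\n",
"    loss = torch.nn.functional.mse_loss(toy_model(x), y) / accumulation_steps\n",
"    loss.backward()  # gradients accumulate in .grad\n",
"optimizer.step()  # a single parameter update for the whole effective batch\n",
"```"
]
},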
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"### 🕰 Use gradient checkpointing\n",
"\n",
"Gradient checkpointing is a memory-saving technique that reduces GPU usage during training by selectively storing only some intermediate activations and recomputing the others during backpropagation.\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"We won't go into the details of how it works, but remember that it can be a bit slower than the standard approach. However, it’s a great way to save memory, especially when training large models. And it's super easy to enable in the `SFTTrainer`, you just have to set the `gradient_checkpointing` argument to `True` when creating the config:\n",
"\n",
"```python\n",
"from trl import SFTConfig\n",
"\n",
"config = SFTConfig(gradient_checkpointing=True, ...)\n",
"```"
]
},
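{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, if you ever train outside the `SFTTrainer`, the same feature can be switched on directly on a 🤗 Transformers model (the `SFTConfig` flag above is all you need in this notebook):\n",
"\n",
"```python\n",
"# Recompute activations during the backward pass instead of storing them all\n",
"model.gradient_checkpointing_enable()\n",
"```"
]
},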
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🧊 Use mixed precision\n",
"\n",
"Mixed precision training speeds up training and reduces memory usage by combining 16-bit (`float16` or `bfloat16`) and 32-bit (`float32`) floating-point arithmetic. It keeps most operations in 16-bit to save memory and compute time, while using 32-bit where higher precision is needed (like loss scaling or certain model updates). This technique is especially helpful when training large models, as it allows you to fit larger batches or longer sequences in memory—often without a noticeable drop in model performance.\n",
"\n",
"```python\n",
"from trl import SFTConfig\n",
"\n",
"training_args = SFTConfig(bf16=True, ...)\n",
"```"
]
},
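{
"cell_type": "markdown",
"metadata": {},
"source": [
"Roughly speaking, `bf16=True` relies on PyTorch's autocast: operations inside the context run in `bfloat16` while the stored weights stay in `float32`. Here is a minimal sketch with a toy layer (it assumes a CUDA GPU is available):\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Toy illustration of bf16 autocast (requires a CUDA GPU)\n",
"layer = torch.nn.Linear(8, 8).to(\"cuda\")  # weights stored in float32\n",
"x = torch.randn(4, 8, device=\"cuda\")\n",
"with torch.autocast(device_type=\"cuda\", dtype=torch.bfloat16):\n",
"    y = layer(x)  # the matmul is executed in bfloat16\n",
"print(layer.weight.dtype, y.dtype)  # torch.float32 torch.bfloat16\n",
"```"
]
},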
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🛸 Did rickification work?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see with a basic question!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"pipeline = pipeline(task=\"text-generation\", model=\"qgallouedec/SmolLM2-360M-Rickified\")\n",
"\n",
"question = \"What would happen to time for an astronaut traveling near the speed of light?\"\n",
"prompt = [{\"role\": \"user\", \"content\": question}]\n",
"pipeline(prompt, max_new_tokens=400)[0][\"generated_text\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It sounds very much like Rick!\n",
"Let's try with another one, a bit out-of-distribution, to see how the model reacts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"What is Edmonton's average temperature in January?\"\n",
"prompt = [{\"role\": \"user\", \"content\": question}]\n",
"pipeline(prompt, max_new_tokens=400)[0][\"generated_text\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we close this chapter, let's generate a final one, just for fun!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"A ball is thrown vertically upward with an initial speed of 12 m/s. What is its maximum height?\"\n",
"prompt = [{\"role\": \"user\", \"content\": question}]\n",
"pipeline(prompt, max_new_tokens=400)[0][\"generated_text\"]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "trl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}