train_rick/00_llm_pretraining_and_data_preparation.ipynb (654 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 🧠 Mastering LLM Fine-Tuning with TRL – Introduction and Prerequisites\n", "\n", "Welcome! This notebook is part of a tutorial series where you'll learn how to fine-tune Large Language Models (LLMs) using 🤗 TRL.\n", "We introduce key concepts, set up the required tools, and use techniques like Supervised Fine-Tuning (SFT) and Group Relative Policy Optimization (GRPO).\n", "\n", "## 📋 Prerequisites\n", "\n", "Before you begin, make sure you have the following:\n", "\n", "* A working knowledge of Python and PyTorch\n", "* A basic understanding of machine learning and deep learning concepts\n", "* Access to a GPU accelerator – this notebook is designed to run with **at least 16GB of GPU memory**, such as what is available for free on [Google Colab](https://colab.research.google.com). In Colab, go to Runtime -> Change runtime type -> T4 (GPU).\n", "\n", "* The `trl` library installed – this tutorial has been tested with **TRL version 0.17**\n", " If you don’t have `trl` installed yet, you can install it by running the following code block:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install trl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* A [Hugging Face account](https://huggingface.co) with a configured access token. If needed, run the following code.\n", "This will prompt you to enter your Hugging Face access token. You can generate one from your Hugging Face account settings under [Access Tokens](https://huggingface.co/settings/tokens). 
The token must have the `Write access to contents/settings of all repos under your personal namespace` permission." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🤔 Do you remember how LLMs work?\n", "\n", "LLMs are essentially highly advanced autocomplete systems.\n", "You provide them with a bit of text, and they predict what comes next." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "\n", "# Note: this rebinds the name `pipeline` to the pipeline object; later cells call it directly\n", "pipeline = pipeline(task=\"text-generation\", model=\"Qwen/Qwen2.5-1.5B\")\n", "prompt = \"Octopuses have three\"\n", "pipeline(prompt, max_new_tokens=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's right, octopuses have three hearts 🫀! To be honest, I didn't know that before writing this notebook.\n", "\n", "The problem with `pipeline` is that it hides the underlying details that are essential to understand before proceeding. Let's break down the pipeline to examine what is happening behind the scenes.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🪙 Tokenization\n", "\n", "The first step is making sure the model can understand the text. This is done by transforming the text into tokens. Tokens are small units of text that the model can interpret. 
\n", "The tokenizer is responsible for encoding the text into token ids and decoding the token ids back into text.\n", "Let's use the same example as before:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-1.5B\")\n", "\n", "prompt = \"Octopuses have three\"\n", "inputs = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n", "inputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is the token mapping:\n", "\n", "| Text | `Oct` | `op` | `uses` | `␣have` | `␣three` |\n", "|--------|--------|------|--------|---------|----------|\n", "| Token ids | `18053`| `453`| `4776`| `614`| `2326`|" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### ⏩ The Forward Pass\n", "\n", "Now that we have a list of integers, we can pass them to the model. Let's see what happens when we do that." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForCausalLM\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen2.5-1.5B\").to(\"cuda\")\n", "output = model(**inputs)\n", "output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It doesn't seem to output text 🥺. Let's see what it returns." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output is a `CausalLMOutputWithPast` object, which contains several attributes. The only one we are concerned with for now is `logits`. To understand what this is, let's first check its shape." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output.logits.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The logits tensor is 3-dimensional:\n", "- `batch_size`: 1 (since we only have one sequence)\n", "- `sequence_length`: 5 (because the text was tokenized into 5 tokens, as seen earlier)\n", "- `vocab_size`: 151936 (the number of entries in the model's vocabulary, i.e. how many distinct token ids it can score; this varies from model to model).\n", "\n", "**But what are these logits?**\n", "\n", "Logits are the scores that the model assigns to each token in the vocabulary for the next position in the sequence. The higher the score, the more confident the model is that this token comes next.\n", "Logits are unnormalized scores rather than probabilities, but applying a softmax turns them into a probability distribution over the vocabulary. The model is saying, \"Given the input sequence, here are my scores for each possible next token.\"\n", "\n", "Consequently, the logits at the last position of the sequence correspond to the model's prediction for the next token. In other words, what comes after `\"Octopuses have three\"`? To get a better understanding, let's look at the probability distribution over the next token at that last position." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Take only the logits of the last token\n", "last_logits = output.logits[0, -1, :] # shape = (151936,)\n", "last_probs = torch.softmax(last_logits, dim=-1) # turn logits into probabilities\n", "\n", "# Let's consider only the 10 most probable tokens\n", "top_last_probs, top_last_ids = torch.topk(last_probs, k=10)\n", "\n", "print(f\"The most probable next token ids are: {top_last_ids.tolist()}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Models may understand token ids, but I don't. So let's decode these ids and see what they mean." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "top_last_tokens = tokenizer.batch_decode(top_last_ids)\n", "print(f\"The most probable next tokens are: {top_last_tokens}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot the distribution to better understand what the model is predicting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "plt.bar(top_last_tokens, top_last_probs.tolist())\n", "\n", "plt.xlabel('Next token')\n", "plt.xticks(rotation=45)\n", "plt.ylabel('Probability')  # we plot softmax probabilities, not raw logits\n", "plt.title('Octopuses have three...')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We did this for the last position, but the model makes a prediction at every position. Let's see what it predicts after each token in the sequence." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "probs = torch.softmax(output.logits[0], dim=-1) # turn logits into probabilities, shape = (5, 151936)\n", "# Let's consider only the 10 most probable tokens\n", "top_probs, top_ids = torch.topk(probs, k=10)\n", "top_tokens = [tokenizer.batch_decode(ids) for ids in top_ids]\n", "# Replace \" \" with \"␣\" for better visualization\n", "top_tokens = [[token.replace(\" \", \"␣\") for token in tokens] for tokens in top_tokens]\n", "\n", "fig, ax = plt.subplots(1, 5, figsize=(16, 3))\n", "for i in range(5):\n", "    # Use integer positions so two tokens that decode to the same string don't share a bar\n", "    ax[i].bar(range(10), top_probs[i].tolist())\n", "    ax[i].set_xticks(range(10))\n", "    ax[i].set_xticklabels(top_tokens[i], rotation=45)\n", "    ax[i].set_xlabel('Next token')\n", "    ax[i].set_ylabel('Probability')\n", "    partial_seq = tokenizer.decode(inputs['input_ids'][0][:i + 1])\n", "    ax[i].set_title(f\"{partial_seq}...\")\n", "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": 
[ "Pretty interesting, right? At each step, we can see what the model considers the most likely next token." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![image.png](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tuto_forward_pass.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔍 How do we end up with a model capable of outputting such distributions?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "LLMs don’t start out smart — far from it. The impressive ability to generate coherent, relevant text comes from a process called **pretraining**.\n", "\n", "#### 🏗️ What is Pretraining?\n", "\n", "During pretraining, a model starts with **random weights** and learns by trying to **predict the next token** in massive amounts of text, often hundreds of billions to trillions of tokens scraped from the internet. For Llama 3, about 15 trillion tokens were used. This process teaches the model general language patterns, grammar, and world knowledge.\n", "\n", "Pretraining is:\n", "\n", "* **Massive** in scale (weeks or months on hundreds of GPUs)\n", "* **Costly** (millions of dollars)\n", "* **Foundational** — it's what makes an LLM even remotely useful\n", "\n", "Once this phase is complete, we get what’s called a **base model**.\n", "\n", "#### 🧪 What is a Base Model?\n", "\n", "A base model is pretrained, but it hasn't been taught *how* to behave.\n", "\n", "It doesn't follow instructions well. \n", "It doesn’t know how to have a conversation or answer questions directly. 
\n", "It simply continues text based on patterns it has seen.\n", "\n", "So if you give it a prompt like:\n", "\n", "> *“What is the capital of Germany?”*" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "prompt = \"What is the capital of Germany?\\n\"\n", "print(pipeline(prompt, max_new_tokens=100)[0][\"generated_text\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model generates something that *looks* like a multiple-choice question — because it has seen many of those during pretraining — but it won’t actually answer the question.\n", "Why? Because it hasn’t been trained to respond helpfully. No one has told it: *\"This is how you should respond.\"*\n", "\n", "#### 🧠 What About ChatGPT, Claude, and Others?\n", "\n", "When you use popular models like GPT-4o, Claude, DeepSeek-R1, or o3, you’re *not* using a base model.\n", "\n", "You’re using a model that’s been **fine-tuned** — and often **reinforcement-aligned** — to be helpful, safe, and responsive.\n", "\n", "For example:\n", "\n", "- **DeepSeek-R1** is fine-tuned from a model called **DeepSeek-V3-Base**. \n", "- **OpenAI o4-mini** is a fine-tuned version of an unknown base model. 
\n", "- **Llama 4 Scout** (officially: *Llama-4-Scout-17B-16E-Instruct*) is a fine-tuned version of *Llama-4-Scout-17B-16E*.\n", "\n", "### 🎯 Why Does This Matter?\n", "\n", "In this tutorial, we’re starting with a **base model** — one that can generate text, but isn’t yet useful on its own.\n", "\n", "It won’t follow instructions well, and it may not be helpful or safe by default.\n", "\n", "Your job is to **fine-tune** it into something smarter, more helpful, or more aligned to your specific goals.\n", "\n", "That’s the magic of **post-training** — and that’s where **TRL** and your creativity come in.\n", "\n", "![image.png](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tuto_pretraining_posttraining.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "☕ At this point, I think it's a good time to take a break. Let's grab a coffee and come back in 10 minutes. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🎛️ Fine-tuning\n", "\n", "In the previous section, we discussed what pretraining is and how it gives us a base model. To recap, a base model is one that has been pretrained on a huge dataset, but hasn’t yet been adapted for specific tasks. It can generate text, but by itself it isn’t particularly useful.\n", "\n", "### 🗣️ Chat template\n", "\n", "To move beyond simple text completions and start building something remotely helpful, like a chatbot, the first step is to use a conversation template. When you interact with a chatbot, your message isn't passed to the model as plain text. Instead, it's formatted in a structured way that tells the model it's part of a conversation—who's speaking, what has been said before, and so on. 
This structure is defined using what's called a **chat template**.\n", "\n", "Here is an example of such templated input:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "prompt = \"\"\"<|im_start|>system\n", "You are a helpful assistant.<|im_end|>\n", "<|im_start|>user\n", "How many hearts do octopuses have?<|im_end|>\n", "<|im_start|>assistant\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's important to remember that just using a chat template doesn't automatically make the model behave like a chatbot. This format is likely unfamiliar to the model—it probably hasn't seen much data like this before. But let's try it out and see what happens:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(pipeline(prompt, max_new_tokens=20)[0][\"generated_text\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I don’t even know how to describe that output. But what’s clear is that it’s not satisfactory.\n", "\n", "You may have noticed that I used a special format to represent the conversation. Specifically, I used custom tokens like `<|im_start|>` and `<|im_end|>`. This kind of formatting is very convenient and easy to parse for later use. It's known as a *chat template*. 
Tokenizers are also capable of handling these templates—as long as you specify the one you want using Jinja2 format:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-1.5B\")\n", "tokenizer.chat_template = \"\"\"{{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n", "{%- for message in messages %}\n", " {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n", "{%- endfor %}\n", "{%- if add_generation_prompt %}\n", " {{- '<|im_start|>assistant\\n' }}\n", "{%- endif %}\"\"\"\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can use it to format conversations properly. For example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "messages = [\n", " {\"role\": \"user\", \"content\": \"How many hearts do octopuses have?\"},\n", " {\"role\": \"assistant\", \"content\": \"Octopuses have three hearts.\"},\n", "]\n", "print(tokenizer.apply_chat_template(messages, tokenize=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This doesn’t solve our problem yet—but at least now we have a tool to format conversations properly.\n", "\n", "### πŸ—‚οΈ Conversational Data\n", "\n", "Let’s recap. At this point, we have a model that can generate text, but it’s not yet capable of holding a conversation. We also have a tool to format dialogues using a chat template. So, what’s still missing?\n", "\n", "You guessed it—just look at the section title. What we’re missing is *data*.\n", "\n", "Hugging Face tools have already been incredibly helpful, even if you didn’t notice. First, you loaded the model using the `transformers` library. Then, you loaded the tokenizer the same way. 
Both the model and tokenizer were automatically downloaded for you from the 🤗 Hugging Face Hub.\n", "\n", "Here’s the great part: the Hub doesn’t just host models—it also offers a huge collection of datasets. At the time of writing, there are 381,735 of them! Let’s go check one out.\n", "\n", "Let’s pick a conversational dataset, like [OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k). To load it, we’ll use another fantastic library in the Hugging Face ecosystem: `datasets`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(\"open-thoughts/OpenThoughts-114k\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let’s take a peek at what this dataset contains—starting with the first example." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "example = dataset[\"train\"][0]\n", "example  # display the raw example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The raw output might not be very readable, but you can always explore it visually on the Hugging Face Hub. What’s most important for us is that this dataset contains conversations—in the `conversations` column. So let’s try formatting one of them using our chat template." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer.apply_chat_template(example[\"conversations\"], tokenize=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uh oh! `UndefinedError: 'dict object' has no attribute 'role'`. Looks like the dataset isn’t in the format we expected. Yep, that happens—and it’s actually pretty common.\n", "\n", "Whenever you're training models, you’ll almost always have to go through a data preprocessing step. So let’s tackle that now." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 🧹 Data Preparation\n", "\n", "What we want is for each conversation to look like this:\n", "\n", "```python\n", "{\n", "    \"messages\": [\n", "        {\"role\": \"user\", \"content\": \"How many hearts do octopuses have?\"},\n", "        {\"role\": \"assistant\", \"content\": \"Octopuses have three hearts.\"},\n", "    ]\n", "}\n", "```\n", "\n", "But the dataset actually looks like this:\n", "\n", "```python\n", "{\n", "    \"conversations\": [\n", "        {\"from\": \"human\", \"value\": \"How many hearts do octopuses have?\"},\n", "        {\"from\": \"assistant\", \"value\": \"Octopuses have three hearts.\"}\n", "    ]\n", "}\n", "```\n", "\n", "Let’s write a function to convert from the second format to the one we need." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Map speaker names to the roles our chat format uses (assumed ShareGPT-style names);\n", "# any other name, such as \"assistant\" or \"user\", is passed through unchanged.\n", "ROLE_MAP = {\"human\": \"user\", \"gpt\": \"assistant\"}\n", "\n", "def format_example(example):\n", "    messages = []\n", "    for message in example[\"conversations\"]:\n", "        role = ROLE_MAP.get(message[\"from\"], message[\"from\"])\n", "        content = message[\"value\"]\n", "        messages.append({\"role\": role, \"content\": content})\n", "    return {\"messages\": messages}\n", "\n", "format_example(example)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Perfect! Now, how do we apply this function to the entire dataset? Simple—just use `dataset.map`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset = dataset.map(format_example, remove_columns=\"conversations\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That was quick. Let’s now try formatting the first example using our chat template again." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "example = dataset[\"train\"][0]\n", "print(tokenizer.apply_chat_template(example[\"messages\"], tokenize=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Voilà! 
We now have a dataset of properly formatted conversations. It's ready to be used for training our model.\n", "\n", "\n", "> Wait a second! We forgot to apply the chat template to the entire dataset!\n", "\n", "No worries, the trainer will take care of that for us. 😉" ] } ], "metadata": { "kernelspec": { "display_name": "trl", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 2 }