cookbook-efforts/kto-preference/01_create_preference_task.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Creating a KTO Preference dataset using Argilla and Spaces\n", "\n", "This notebook walks through the steps involved in creating a KTO dataset using Argilla and 🤗 Spaces. It assumes you already have a dataset consisting of prompts and responses. \n", "Using this data as a starting point we'll set up an Argilla Space which anyone with a Hugging Face account can log in to and provide feedback on the responses generated by one or more models. \n", "\n", "In this example we'll focus on a dataset containing prompts and responses focused on generating Haiku. The approach could be applied to any dataset where you want to collect human ratings for a set of prompts and responses. Our end goal is to produce a dataset that can be used with the [`trl`](https://github.com/huggingface/trl) library's `KTOTrainer`. \n", "\n", "The steps we'll cover are:\n", "- Setting up an Argilla Space\n", "- Uploading the dataset to the Space\n", "- Labeling the dataset\n", "- Exporting the labeled dataset\n", "- Formatting the labeled dataset for use with `KTOTrainer`\n", "- Sharing the dataset to the Hub\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are running the notebook on Google Colab you need to install `argilla`: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# %pip install argilla " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import duplicate_space\n", "from huggingface_hub import hf_hub_download\n", "from huggingface_hub import HfApi\n", "from huggingface_hub import SpaceCard\n", "from rich import print" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Create the Argilla Hugging Face Space\n", "\n", "To collect our preference data we'll use Argilla hosted on Hugging Face Spaces. 
This setup will allow anyone with a Hub account (using OAuth authentication) to contribute to the dataset (you can also restrict access to a specific group of people if you want). The first step is to create a Space on Hugging Face Spaces. Before we do this we'll authenticate with the `huggingface_hub` library to make sure we can programmatically interact with Spaces." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import login\n", "login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Duplicate a template Space\n", "\n", "We'll duplicate an existing Argilla Space template. This will help us get up and running with an Argilla Space quickly. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RepoUrl('https://huggingface.co/spaces/davanstrien/haiku-preferences', endpoint='https://huggingface.co', repo_type='space', repo_id='davanstrien/haiku-preferences')" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from_id = \"argilla/argilla-template-space-with-oauth\"\n", "to_id = \"davanstrien/haiku-preferences\"\n", "new_space = duplicate_space(from_id, to_id=to_id)\n", "new_space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We update the title and description of the Space to reflect the dataset we are creating; update these to match your own dataset. 
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CommitInfo(commit_url='https://huggingface.co/spaces/davanstrien/haiku-preferences/commit/00e3a2dbe0d0dd0845bb8e15ee9c2297330df026', commit_message='Upload README.md with huggingface_hub', commit_description='', oid='00e3a2dbe0d0dd0845bb8e15ee9c2297330df026', pr_url=None, pr_revision=None, pr_num=None)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "card = SpaceCard.load(to_id)\n", "card.data.title = \"DIBT haiku preferences\"\n", "card.push_to_hub(to_id)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Create an application on the Hub\n", "\n", "To enable the OAuth integration we need to create an application on the Hub. We can do this via the Hugging Face settings UI.\n", "\n", "- Go to this page: [https://huggingface.co/settings/applications/new](https://huggingface.co/settings/applications/new)\n", "- Complete the form to create a new application. 
You will need to provide the following values:\n", " - Name of application\n", " - Homepage URL: Your Argilla Space Direct URL.\n", " - Logo URL: [Your Argilla Space Direct URL]/favicon.ico\n", " - Scopes: openid and profile.\n", " - Redirect URL: [Your Argilla Space Direct URL]/oauth/huggingface/callback\n", "\n", "The cell below will show you the URL for these values.\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Homepage URL: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://davanstrien-haiku-preferences.hf.space</span> \n", " Logo URL: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://davanstrien-haiku-preferences.hf.space/favicon.ico</span> \n", " Redirect URL: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://davanstrien-haiku-preferences.hf.space/oauth/huggingface/callback</span>\n", "</pre>\n" ], "text/plain": [ "Homepage URL: \u001b[4;94mhttps://davanstrien-haiku-preferences.hf.space\u001b[0m \n", " Logo URL: \u001b[4;94mhttps://davanstrien-haiku-preferences.hf.space/favicon.ico\u001b[0m \n", " Redirect URL: \u001b[4;94mhttps://davanstrien-haiku-preferences.hf.space/oauth/huggingface/callback\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "homepage_url = f\"https://{new_space.repo_id.lower().replace('/', '-')}.hf.space\"\n", "favicon_url = f\"{homepage_url.lower()}/favicon.ico\"\n", "redirect_url = f\"{homepage_url.lower()}/oauth/huggingface/callback\"\n", "print(f\"Homepage URL: {homepage_url.lower()} \\n Logo URL: {favicon_url} \\n Redirect URL: {redirect_url}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
Set up your Space secrets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we have created the application, we need to add the following secrets to our Space. The values are shown once you've created your application using the steps above.\n", "\n", "- `OAUTH2_HUGGINGFACE_CLIENT_ID`: [Your Client ID]\n", "- `OAUTH2_HUGGINGFACE_CLIENT_SECRET`: [Your App Secret]\n", "\n", "Additionally, we highly recommend setting up a custom API key and password for the owner role (you). This owner role is the only one allowed to create, delete, read and update datasets, so it's important to change the defaults:\n", "\n", "- `OWNER_API_KEY`: you can put any alphanumeric value\n", "- `OWNER_PASSWORD`: at least 8 digits/characters.\n", "\n", "You can add these secrets via the settings tab of your Space. \n", "\n", "![secrets](assets/secrets.png)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'https://huggingface.co/spaces/davanstrien/haiku-preferences/settings'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f\"{new_space.url}/settings\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Persistent Storage + Upgrade CPU\n", "\n", "To ensure all annotations are safely stored we'll want to enable persistent storage on our Space. This means that if the Space is stopped and restarted, all annotations will still be available. 
Additionally, we'll upgrade the CPU and disable sleeping to ensure the Space is always available for annotators!\n", "\n", "![storage](assets/storage.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now need to factory reset the Space to ensure all of the above changes take effect." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SpaceRuntime(stage='RUNNING_BUILDING', hardware='cpu-basic', requested_hardware='cpu-upgrade', sleep_time=172800, storage='small', raw={'stage': 'RUNNING_BUILDING', 'hardware': {'current': 'cpu-basic', 'requested': 'cpu-upgrade'}, 'storage': 'small', 'gcTimeout': 172800, 'replicas': {'current': 1, 'requested': 1}, 'devMode': False})" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from huggingface_hub import restart_space\n", "\n", "restart_space(to_id, factory_reboot=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Testing your Space\n", "\n", "At this point you are ready to verify the installation. 
You need to go to the following Space URL:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'https://huggingface.co/spaces/davanstrien/haiku-preferences'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f\"https://huggingface.co/spaces/{to_id}\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should see something like this:\n", "![](assets/space.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you don't see the Sign in with Hugging Face button, you need to go back to Steps 2 and 3 to make sure the OAuth app is correctly set up (make sure the callback URL is correct) and the secrets are correct.\n", "\n", "The next step is to test the sign-in; you should see something like this:\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Access page](assets/access.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you see an error after Authorizing, please double-check the callback URL on your OAuth application settings at https://huggingface.co/settings/connected-applications\n", "\n", "If you are still having issues feel free to reach out on Discord. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Loading our data into the Argilla Space\n", "\n", "First we need to set up the Argilla SDK client with the URL and owner credentials for our Space. I'm using the `python-dotenv` library to load the secrets from a `.env` file but you can also add these directly to the notebook. 
\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import argilla as rg\n", "from dotenv import load_dotenv\n", "import os\n", "\n", "load_dotenv()\n", "OWNER_API_KEY = os.getenv(\"ARGILLA_KEY\")\n", "\n", "assert (\n", " OWNER_API_KEY is not None\n", "), \"Please set OWNER_API_KEY to the API token you just set in the Space settings\"\n", "\n", "rg.init(api_url=homepage_url, api_key=OWNER_API_KEY, workspace=\"admin\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we're ready to create our dataset in the admin workspace. At this point we'll need to grab whatever data we want to collect human preferences for. The steps below will vary depending on the data you're working with. We give some pointers for things you may want to consider." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We already have a dataset which contains a prompt and three completions per prompt. We will use this dataset to get human preferences." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "# If the dataset is gated/private, make sure you have run huggingface-cli login\n", "dataset = load_dataset(\"davanstrien/haiku_dpo\", \"aesthetic-preference\", split='train')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's take a look at what a row looks like." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'input': 'Can you compose a haiku about the serenity of mountain peaks?',\n", " 'generation_model': ['mistralai/Mistral-7B-Instruct-v0.2',\n", " 'meta-llama/Llama-2-70b-chat-hf',\n", " 'NousResearch/Nous-Hermes-2-Yi-34B'],\n", " 'generation_prompt': ['<s>[INST] <<SYS>>\\nYou are a poet specialising in creating Haiku. 
\\nYour haiku consist of three lines, with five syllables in the first line, seven in the second, and five in the third.\\nBeyond being technically correct, your haiku should also be beautiful and meaningful. \\nYou respond only with a haiku. You do not add anything else to your responses. \\n\\n<</SYS>>\\n\\nCan you compose a haiku about the serenity of mountain peaks? [/INST]',\n", " '<s>[INST] <<SYS>>\\nYou are a poet specialising in creating Haiku. \\nYour haiku consist of three lines, with five syllables in the first line, seven in the second, and five in the third.\\nBeyond being technically correct, your haiku should also be beautiful and meaningful. \\nYou respond only with a haiku. You do not add anything else to your responses. \\n\\n<</SYS>>\\n\\nCan you compose a haiku about the serenity of mountain peaks? [/INST]',\n", " '<|im_start|>system\\nYou are a poet specialising in creating Haiku. \\nYour haiku consist of three lines, with five syllables in the first line, seven in the second, and five in the third.\\nBeyond being technically correct, your haiku should also be beautiful and meaningful. \\nYou respond only with a haiku. You do not add anything else to your responses. 
\\n\\n<|im_end|>\\n<|im_start|>user\\nCan you compose a haiku about the serenity of mountain peaks?<|im_end|>\\n<|im_start|>assistant\\n'],\n", " 'raw_generation_responses': [\" Peaceful summit rests,\\nSky's reflection in still lake,\\nSilence whispers on.\",\n", " \" Snow-capped peaks rise high\\nSilent, majestic, and serene\\nNature's peaceful throne\",\n", " \"Mountain peaks, serene\\nPeaceful silence, whispers breeze\\nNature's tranquil song\"],\n", " 'generations': [\" Peaceful summit rests,\\nSky's reflection in still lake,\\nSilence whispers on.\",\n", " \" Snow-capped peaks rise high\\nSilent, majestic, and serene\\nNature's peaceful throne\",\n", " \"Mountain peaks, serene\\nPeaceful silence, whispers breeze\\nNature's tranquil song\"]}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, we have one input prompt, some metadata about the models used for each generation, and the three completions. We will use this data to get human preferences. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining the task\n", "\n", "We'll use the Argilla SDK to define the task and set up our annotations and dataset. We'll use Argilla's [`Feedback Dataset`](https://docs.argilla.io/en/latest/practical_guides/create_update_dataset/create_dataset.html#feedback-dataset), which comes with different [task templates](https://docs.argilla.io/en/latest/practical_guides/create_update_dataset/create_dataset.html#task-templates). These give you a starting point for different tasks you might want to gather data for. In this case we'll use the `for_text_classification` task template, which is designed for text classification tasks; this is very close to what we're doing when collecting KTO data, so it's a good starting point. 
\n", "\n", "We'll create some very short guidelines for the annotators to follow. If you are collecting a KTO dataset for a task with a lot of nuance you might want to extend these guidelines to be more detailed." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "guidelines = \"\"\"\n", "Do you like this haiku? \n", "Yes or no? \n", "A vibes only assessment is fine!\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When using the `for_text_classification` template we need to provide the labels we're using; in our case we use `Yes` or `No` to indicate our binary preference. This will be converted to a `bool` value once we parse the dataset later. \n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "argilla_ds = rg.FeedbackDataset.for_text_classification(\n", " labels=[\"Yes\", \"No\"],\n", " use_markdown=True,\n", " guidelines=guidelines,\n", " metadata_properties=None,\n", " vectors_settings=None,\n", ")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We get back a `FeedbackDataset` object which we can use to add our data to the dataset. We can also continue to modify the formatting of our task. " ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FeedbackDataset(\n", " fields=[TextField(name='text', title='Text', required=True, type='text', use_markdown=True)]\n", " questions=[LabelQuestion(name='label', title='Label', description='Classify the text by selecting the correct label from the given list of labels.', required=True, type='label_selection', labels=['Yes', 'No'], visible_labels=None)]\n", " guidelines=\n", " Do you like this haiku? \n", " Yes or no? 
\n", " A vibes only assessment is fine!)\n", " metadata_properties=[])\n", " vectors_settings=[])\n", ")" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "argilla_ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One thing we might want to change is the title of the question, to make it clearer to the annotators what they are doing." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "argilla_ds.questions[0].title = \"Do you like this haiku?\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `fields` are shown in the UI to the annotators. Again we can change the title (what's shown to the annotators) and the name (how the field is tracked in the dataset) to make it easier for us later. " ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "argilla_ds.fields[0].title = \"Haiku\"\n", "argilla_ds.fields[0].name = \"completion\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "While most text classification tasks will have a single text field that is classified, in our case we probably want to show the prompt to the user so they can rate the completion in the context of the prompt. For a `FeedbackDataset` the fields are shown in the order in which they appear in the `fields` attribute. To add the prompt we can add this as a `TextField` at the start of the `fields` list."
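, "\n",
"\n",
"As a quick sanity check (a minimal sketch using only the `fields` and `questions` attributes shown above), we can print the current configuration at any point:\n",
"\n",
"```python\n",
"# Print each configured field and question (name -> UI title).\n",
"for field in argilla_ds.fields:\n",
"    print(\"field:\", field.name, \"->\", field.title)\n",
"for question in argilla_ds.questions:\n",
"    print(\"question:\", question.name, \"->\", question.title)\n",
"```\n"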
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "argilla_ds.fields.insert(0, rg.TextField(name=\"prompt\", title=\"Haiku prompt\", required=True, use_markdown=True))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "FeedbackDataset(\n", " fields=[TextField(name='prompt', title='Haiku prompt', required=True, type='text', use_markdown=True), TextField(name='completion', title='Haiku', required=True, type='text', use_markdown=True)]\n", " questions=[LabelQuestion(name='label', title='Do you like this haiku?', description='Classify the text by selecting the correct label from the given list of labels.', required=True, type='label_selection', labels=['Yes', 'No'], visible_labels=None)]\n", " guidelines=\n", " Do you like this haiku? \n", " Yes or no? \n", " A vibes only assessment is fine!)\n", " metadata_properties=[])\n", " vectors_settings=[])\n", ")" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "argilla_ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading the data\n", "\n", "We can now load our data into the `FeedbackDataset`. We do this by creating a list of all the records (data points) we want to add. Each item in this list will be a `rg.FeedbackRecord` object. We need to pass in the expected fields (as defined above). We can also add some metadata to each record. This metadata won't be shown to the annotators, but will be stored with the record. This can be particularly helpful for tracking the source of the generations, i.e. which model was used to generate a completion. We may later want to use this metadata to filter the data or to compare the performance of different models." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Filtering the data\n", "\n", "Often we want to show all of the data to the annotators, but sometimes we might want to filter the data. 
In our case, since we expect haiku to be three lines long, we can define a simple filter so we don't show annotators any completions that are not three lines long.\n" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "def is_three_lines(haiku):\n", " return len(haiku.split(\"\\n\")) == 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now create our records. We'll loop through all the rows in our dataset, then through each of the generations for that row (remember, in this example we had three generations per prompt), creating a `FeedbackRecord` for each generation. We'll add the prompt and the completion to the record, along with some metadata about the model used to generate the completion." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# create records\n", "records = []\n", "for row in dataset:\n", " for generation_model, generation in zip(\n", " row[\"generation_model\"], row[\"generations\"]\n", " ):\n", " if is_three_lines(generation):\n", " prompt = row[\"input\"]\n", " metadata = {\"prompt\": prompt, \"generation_model\": generation_model}\n", " record = rg.FeedbackRecord(\n", " fields={\"prompt\": prompt, \"completion\": generation.strip()},\n", " metadata=metadata,\n", " )\n", " records.append(record)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we look at one of the records we can see the prompt and the completion. We can also see the metadata we added to the record."
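, "\n",
"\n",
"As a rough check (an illustrative sketch; `dataset` and `records` are the objects created above), we can also see how many generations survived the three-line filter:\n",
"\n",
"```python\n",
"# Compare the number of generations in the source dataset with the\n",
"# number of records that passed the three-line filter.\n",
"total = sum(len(row[\"generations\"]) for row in dataset)\n",
"print(f\"kept {len(records)} of {total} generations\")\n",
"```\n"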
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">FeedbackRecord</span><span style=\"font-weight: bold\">(</span>\n", " <span style=\"color: #808000; text-decoration-color: #808000\">fields</span>=<span style=\"font-weight: bold\">{</span>\n", " <span style=\"color: #008000; text-decoration-color: #008000\">'prompt'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'Can you compose a haiku about the serenity of mountain peaks?'</span>,\n", " <span style=\"color: #008000; text-decoration-color: #008000\">'completion'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">\"Peaceful summit rests,\\nSky's reflection in still lake,\\nSilence whispers on.\"</span>\n", " <span style=\"font-weight: bold\">}</span>,\n", " <span style=\"color: #808000; text-decoration-color: #808000\">metadata</span>=<span style=\"font-weight: bold\">{</span>\n", " <span style=\"color: #008000; text-decoration-color: #008000\">'prompt'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'Can you compose a haiku about the serenity of mountain peaks?'</span>,\n", " <span style=\"color: #008000; text-decoration-color: #008000\">'generation_model'</span>: <span style=\"color: #008000; text-decoration-color: #008000\">'mistralai/Mistral-7B-Instruct-v0.2'</span>\n", " <span style=\"font-weight: bold\">}</span>,\n", " <span style=\"color: #808000; text-decoration-color: #808000\">vectors</span>=<span style=\"font-weight: bold\">{}</span>,\n", " <span style=\"color: #808000; text-decoration-color: #808000\">responses</span>=<span style=\"font-weight: bold\">[]</span>,\n", " <span style=\"color: #808000; text-decoration-color: #808000\">suggestions</span>=<span 
style=\"font-weight: bold\">()</span>,\n", " <span style=\"color: #808000; text-decoration-color: #808000\">external_id</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>\n", "<span style=\"font-weight: bold\">)</span>\n", "</pre>\n" ], "text/plain": [ "\u001b[1;35mFeedbackRecord\u001b[0m\u001b[1m(\u001b[0m\n", " \u001b[33mfields\u001b[0m=\u001b[1m{\u001b[0m\n", " \u001b[32m'prompt'\u001b[0m: \u001b[32m'Can you compose a haiku about the serenity of mountain peaks?'\u001b[0m,\n", " \u001b[32m'completion'\u001b[0m: \u001b[32m\"Peaceful summit rests,\\nSky's reflection in still lake,\\nSilence whispers on.\"\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[33mmetadata\u001b[0m=\u001b[1m{\u001b[0m\n", " \u001b[32m'prompt'\u001b[0m: \u001b[32m'Can you compose a haiku about the serenity of mountain peaks?'\u001b[0m,\n", " \u001b[32m'generation_model'\u001b[0m: \u001b[32m'mistralai/Mistral-7B-Instruct-v0.2'\u001b[0m\n", " \u001b[1m}\u001b[0m,\n", " \u001b[33mvectors\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", " \u001b[33mresponses\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[33msuggestions\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[33mexternal_id\u001b[0m=\u001b[3;35mNone\u001b[0m\n", "\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(records[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since there will be three generations per prompt, we can shuffle the records to help avoid showing too many generations from the same prompt in a row (you can skip this step if you only have one generation)." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "random.shuffle(records)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now add the records to our `FeedbackDataset` using the `add_records` method."
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "argilla_ds.add_records(records)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now use the `push_to_argilla` method to push the dataset to the Argilla Space. This will make the dataset available to the annotators. We need to give a name to our task in the `push_to_argilla` method. This name will be used to identify the task in the Argilla Space." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# push the dataset to Argilla\n", "argilla_ds.push_to_argilla(\"haiku-preference\", workspace=\"admin\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When you are logged in to your Argilla Space you should see the dataset available:\n", "\n", "![dataset](assets/datasets.png)\n", "\n", "Clicking on the dataset will show you the annotation UI:\n", "\n", "![task](assets/task.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gather a community and start collecting preferences!\n", "\n", "You can now share the link to your Space with your community and start collecting preferences! We're excited to see what kinds of datasets people choose to build, so please feel free to share your Space with us on Discord. If you share on Twitter or other social media, please tag us so we can help promote your task!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Loading our annotated data\n", "\n", "Once we have collected our preferences we can load the data back into the notebook. We can then use this data to train a model using the `KTOTrainer` from the `trl` library. If you run this notebook later, you may need to re-run the cell below (uncommented) to authenticate with the Argilla Space."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# import argilla as rg\n", "# from dotenv import load_dotenv\n", "# import os\n", "\n", "# load_dotenv()\n", "# OWNER_API_KEY = os.getenv(\"ARGILLA_KEY\")\n", "# homepage_url = None\n", "# assert homepage_url is not None, \"Please set homepage_url to the URL of the Space you created\"\n", "# assert (\n", "# OWNER_API_KEY is not None\n", "# ), \"Please set OWNER_API_KEY to the API token you just set in the Space settings\"\n", "\n", "# rg.init(api_url=homepage_url, api_key=OWNER_API_KEY, workspace=\"admin\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can grab data back from our Argilla Space by using the `FeedbackDataset`'s `from_argilla` method. We need to pass in the name of the dataset we want to load as well as the workspace. " ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RemoteFeedbackDataset(\n", " id=ded71479-9170-4b6e-8de6-5bb1d27e49ac\n", " name=haiku-preference\n", " workspace=Workspace(id=b39093b2-d11e-4794-b7e2-5f6547ff2dc9, name=admin, inserted_at=2024-03-14 14:58:19.503243, updated_at=2024-03-14 14:58:19.503243)\n", " url=https://davanstrien-haiku-preferences.hf.space/dataset/ded71479-9170-4b6e-8de6-5bb1d27e49ac/annotation-mode\n", " fields=[RemoteTextField(id=UUID('b37cab6c-350d-4fe4-aa08-54e4c84b673f'), client=None, name='prompt', title='Haiku prompt', required=True, type='text', use_markdown=True), RemoteTextField(id=UUID('5eeff5ab-fdb6-4f04-b0ce-c154b137024b'), client=None, name='completion', title='Haiku', required=True, type='text', use_markdown=True)]\n", " questions=[RemoteLabelQuestion(id=UUID('802abc40-2dbd-48f4-80c9-47d1af685280'), client=None, name='label', title='Do you like this haiku?', description=None, required=True, type='label_selection', labels=['Yes', 'No'], visible_labels=None)]\n", " guidelines=\n", " Do you like this haiku? \n", " Yes or no? 
\n", " A vibes only assessment is fine!\n", " metadata_properties=[]\n", " vectors_settings=[]\n", ")" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "argilla_ds = rg.FeedbackDataset.from_argilla(\"haiku-preference\", workspace=\"admin\")\n", "argilla_ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can push the raw annotations from our notebook to the Hugging Face hub as a dataset. We'll put this in a `raw-argilla` dataset. This will allow us to share the raw annotations with others. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "argilla_ds.push_to_huggingface(\"davanstrien/haiku-kto-raw-argilla\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You'll see when we push the dataset to the Hub that Argilla autogenerates a nice dataset card for us! \n", "\n", "At the moment our dataset contains all of the data including rows without any annotations. We also want to format things a bit differently for use with the `KTOTrainer`. We'll do this in the next section." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Formatting the labeled dataset for use with `KTOTrainer`\n", "\n", "We can format our `RemoteFeedbackDataset` as a Hugging Face dataset. " ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "dataset = argilla_ds.format_as(\"datasets\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll see this is the same number of rows as the records we uploaded. We'll also see that we have the columns we'd expect based on our `fields` definition, as well as some additional columns that track metadata for our data. 
" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['prompt', 'completion', 'label', 'label-suggestion', 'label-suggestion-metadata', 'external_id', 'metadata'],\n", " num_rows: 3952\n", "})" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we look at a single example, we can get a better sense of our data. " ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'prompt': 'Can you write a haiku that describes the danger of an iceberg?',\n", " 'completion': 'Iceberg, silent threat\\nDeceptive beauty, hidden\\nSinking ships, cold death',\n", " 'label': [],\n", " 'label-suggestion': None,\n", " 'label-suggestion-metadata': {'type': None, 'score': None, 'agent': None},\n", " 'external_id': None,\n", " 'metadata': '{\"prompt\": \"Can you write a haiku that describes the danger of an iceberg?\", \"generation_model\": \"NousResearch/Nous-Hermes-2-Yi-34B\"}'}" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we want to make sure we have a preference for each prompt we can filter out any rows where we don't have any labels" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset = dataset.filter(lambda x: len(x['label']) > 0)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['prompt', 'completion', 'label', 'label-suggestion', 'label-suggestion-metadata', 'external_id', 'metadata'],\n", " num_rows: 11\n", "})" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "With the way we've set up our task most rows will have a single annotation but we may sometimes have overlap. There are different ways of dealing with this. If we we're collecting ratings we could create an average but since KTP expects a binary preference this doesn't really work. One approach if we have more than one label is to take a majority vote (this assumes we have an odd number of annotators for each row). \n", "\n", "However, intuitively we probably want fairly \"strong\" preferences in our dataset. If we have a generation where many annotators disagree this might not be a good point to use for preference training. Another approach to deal with this is to filter out rows where there is a tie. This is the approach we'll show here, but there is also a code snippet to take a majority vote if you want to try that approach.\n", "\n", "To ensure we have good comptability with the `KTOTrainer` we'll use boolean values for our labels. " ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['prompt', 'completion', 'label', 'label-suggestion', 'label-suggestion-metadata', 'external_id', 'metadata'],\n", " num_rows: 10\n", "})" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def is_perfect_agreement(row):\n", " labels = row.get(\"label\")\n", " values = (label[\"value\"] for label in labels)\n", " return len(set(values)) == 1\n", "\n", "dataset = dataset.filter(is_perfect_agreement)\n", "dataset" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "def format_label(row):\n", " label = row.get(\"label\", None)\n", " return {\"label\": label[0].get(\"value\") == \"Yes\"}" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'prompt': \"Can you compose a haiku about the beauty of winter's first snow?\",\n", " 'completion': \"Softly 
falls the snow\\nBlanketing all in white peace\\nWinter's gentle hush\",\n", " 'label': True,\n", " 'label-suggestion': None,\n", " 'label-suggestion-metadata': {'type': None, 'score': None, 'agent': None},\n", " 'external_id': None,\n", " 'metadata': '{\"prompt\": \"Can you compose a haiku about the beauty of winter\\'s first snow?\", \"generation_model\": \"meta-llama/Llama-2-70b-chat-hf\"}'}" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = dataset.map(format_label)\n", "dataset[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want to play around with other approaches you can modify the code below." ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "# from collections import Counter\n", "\n", "# def get_majority_label_and_discard_no_majority(row):\n", "# labels = row.get(\"label\")\n", "# values = [label[\"value\"] for label in labels]\n", "# # check if there are multiple labels\n", "# if len(values) >1:\n", "# counts = Counter(values)\n", "# # check if there is a majority label\n", "# if len(set(counts.values())) == 1:\n", "# return {\"label\": \"No majority\"}\n", "# max_key = max(counts, key=counts.get)\n", "# return {\"label\": max_key==\"Yes\"}\n", "# return {\"label\": values[0] == \"Yes\"}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Format as messages\n", "\n", "We'll also format prompts/generations data as a list of messages. 
This is the format that the `KTOTrainer` expects.\n" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "def formatted_as_messages(row):\n", " prompt = row[\"prompt\"]\n", " completion = row[\"completion\"]\n", " return [{\"role\": \"user\", \"content\": prompt}, {\"role\": \"assistant\", \"content\": completion}]\n" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "def create_messages_column(row):\n", " return {\"messages\": formatted_as_messages(row)}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset = dataset.map(create_messages_column)" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'prompt': \"Can you compose a haiku about the beauty of winter's first snow?\",\n", " 'completion': \"Softly falls the snow\\nBlanketing all in white peace\\nWinter's gentle hush\",\n", " 'label': True,\n", " 'label-suggestion': None,\n", " 'label-suggestion-metadata': {'type': None, 'score': None, 'agent': None},\n", " 'external_id': None,\n", " 'metadata': '{\"prompt\": \"Can you compose a haiku about the beauty of winter\\'s first snow?\", \"generation_model\": \"meta-llama/Llama-2-70b-chat-hf\"}',\n", " 'messages': [{'content': \"Can you compose a haiku about the beauty of winter's first snow?\",\n", " 'role': 'user'},\n", " {'content': \"Softly falls the snow\\nBlanketing all in white peace\\nWinter's gentle hush\",\n", " 'role': 'assistant'}]}" ] }, "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now push this dataset to the Hub using the `push_to_hub` method! 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.push_to_hub(\"davanstrien/haiku_kto\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.1" } }, "nbformat": 4, "nbformat_minor": 2 }