6_synthetic_datasets/notebooks/preference_dpo_dataset.ipynb (120 lines of code):

{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Generate a dataset for preference alignment\n", "\n", "This notebook guides you through generating a dataset for preference alignment with the `distilabel` package.\n", "\n", "So let's dig into some preference alignment datasets.\n", "\n", "<div style='background-color: lightblue; padding: 10px; border-radius: 5px; margin-bottom: 20px; color:black'>\n", " <h2 style='margin: 0;color:blue'>Exercise: Generate a dataset for preference alignment</h2>\n", " <p>Using what you learned in the previous notebook, try generating your own dataset for preference alignment.</p>\n", " <p><b>Difficulty Levels</b></p>\n", " <p>🐢 Generate a dataset for preference alignment</p>\n", " <p>🐕 Generate a dataset for preference alignment with response evolution</p>\n", " <p>🦁 Generate a dataset for preference alignment with response evolution and model pooling</p>\n", "</div>" ] },
{ "cell_type": "markdown", "metadata": { "vscode": { "languageId": "plaintext" } }, "source": [ "## Install dependencies\n", "\n", "Instead of the `hf-transformers` extra, you can also install the `vllm` or `hf-inference-endpoints` extra to use a different model backend." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install \"distilabel[hf-transformers,outlines,instructor]\"" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Start synthesizing\n", "\n", "As we've seen in the previous notebook, we can create a distilabel pipeline for preference dataset generation. A bare-minimum pipeline is already provided below: it generates one response per instruction with each of two models and groups both responses into a single column. To turn these candidates into chosen and rejected responses, you still need to rate them, for example by adding distilabel's `UltraFeedback` task. Keep working on this pipeline to generate a larger dataset for preference alignment. Swap out models, model providers, and generation arguments to see how they affect the quality of the dataset. Experiment small, scale up later.\n", "\n", "Check out the [distilabel components gallery](https://distilabel.argilla.io/latest/components-gallery/) for information about the processing classes and how to use them.\n", "\n", "An example of loading data from the Hub instead of dictionaries is provided below.\n", "\n", "```python\n", "from datasets import load_dataset\n", "from distilabel.pipeline import Pipeline\n", "\n", "with Pipeline(...) as pipeline:\n", "    ...\n", "\n", "if __name__ == \"__main__\":\n", "    # The dataset needs the columns your first task expects, e.g. \"instruction\"\n", "    dataset = load_dataset(\"my-dataset\", split=\"train\")\n", "    distiset = pipeline.run(dataset=dataset)\n", "```\n", "\n", "Don't forget to push your dataset to the Hub after running the pipeline!" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from distilabel.llms import TransformersLLM\n", "from distilabel.pipeline import Pipeline\n", "from distilabel.steps import GroupColumns, LoadDataFromDicts\n", "from distilabel.steps.tasks import TextGeneration\n", "\n", "with Pipeline() as pipeline:\n", "    # Seed data: the instructions both models will answer\n", "    data = LoadDataFromDicts(data=[{\"instruction\": \"What is synthetic data?\"}])\n", "    # Two different models produce candidate responses\n", "    llm_a = TransformersLLM(model=\"HuggingFaceTB/SmolLM2-1.7B-Instruct\")\n", "    gen_a = TextGeneration(llm=llm_a)\n", "    llm_b = TransformersLLM(model=\"Qwen/Qwen2.5-1.5B-Instruct\")\n", "    gen_b = TextGeneration(llm=llm_b)\n", "    # Merge each model's \"generation\" column into one list column per row\n", "    group = GroupColumns(columns=[\"generation\"])\n", "    # Send every row to both generators, then group their outputs\n", "    data >> [gen_a, gen_b] >> group\n", "\n", "if __name__ == \"__main__\":\n", "    distiset = pipeline.run()\n", "    # Pushing requires being logged in to the Hugging Face Hub\n", "    distiset.push_to_hub(\"huggingface-smol-course-preference-tuning-dataset\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 🌯 That's a wrap\n", "\n", "You've now seen how to generate a dataset for preference alignment. You could use this to:\n", "\n", "- Generate training datasets for preference alignment.\n", "- Create evaluation datasets for preference alignment.\n", "\n", "Next\n", "\n", "🏋️‍♂️ Fine-tune a model with preference alignment on your synthetic dataset, following the [preference tuning chapter](../../2_preference_alignment/README.md).\n" ] }
], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }
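
The notebook's pipeline stops after grouping the two models' answers; it does not record which answer is better. Preference trainers such as TRL's `DPOTrainer` work with `prompt`/`chosen`/`rejected` records, so once a rating step has scored each candidate (for example distilabel's `UltraFeedback` task), one more conversion is needed. The following is a minimal plain-Python sketch of that conversion, not part of the distilabel API. The `to_dpo_pairs` helper and the `generations`/`ratings` column names are hypothetical, and picking the highest- and lowest-rated responses while dropping ties is just one common heuristic.

```python
from typing import Any


def to_dpo_pairs(rows: list[dict[str, Any]]) -> list[dict[str, str]]:
    """Build prompt/chosen/rejected records from rated candidate responses.

    Hypothetical input schema: each row holds an "instruction", a list of
    candidate "generations", and a parallel list of numeric "ratings"
    (None where the judge failed to produce a score).
    """
    pairs = []
    for row in rows:
        # Pair each response with its rating, skipping unrated responses.
        scored = [
            (rating, text)
            for text, rating in zip(row["generations"], row["ratings"])
            if rating is not None
        ]
        if len(scored) < 2:
            continue  # a pair needs at least two rated responses
        best = max(scored, key=lambda item: item[0])
        worst = min(scored, key=lambda item: item[0])
        if best[0] == worst[0]:
            continue  # tied ratings carry no preference signal
        pairs.append(
            {"prompt": row["instruction"], "chosen": best[1], "rejected": worst[1]}
        )
    return pairs


# Example rows shaped like the output of a (hypothetical) rating step.
rows = [
    {
        "instruction": "What is synthetic data?",
        "generations": ["Data produced by a model or simulation.", "Data."],
        "ratings": [5, 2],
    },
    {
        "instruction": "What is DPO?",
        "generations": ["Answer A", "Answer B"],
        "ratings": [3, 3],
    },
]

print(to_dpo_pairs(rows))  # only the first row yields a pair; the second is a tie
```

Dropping ties and unrated responses keeps only pairs with a clear preference signal. With more than two models per instruction (the 🦁 model-pooling variant), you could instead emit several pairs per prompt.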