notebooks/fr/structured_generation.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# RAG avec citation des sources grâce à la génération structurée\n", "_Auteur : [Aymeric Roucher](https://huggingface.co/m-ric)_ \n", "_Traducteur : [Loïck Bourdois](https://hf.co/lbourdois)_\n", "\n", "**La génération structurée** est une méthode qui force la sortie du LLM à suivre certaines contraintes, par exemple à suivre un gabarit spécifique.\n", "\n", "Les cas d'utilisation sont nombreux :\n", "- ✅ Produire un dictionnaire avec des clés spécifiques\n", "- 📏 S'assurer que la sortie sera plus longue que N caractères\n", "- ⚙️ Plus généralement, forcer la sortie à suivre un certain profil de regex pour du traitement en aval\n", "- 💡 Mettre en évidence les sources qui soutiennent la réponse dans un outil de *Retrieval-Augmented-Generation* (RAG)\n", "\n", "Dans ce *notebook*, nous démontrons spécifiquement le dernier cas d'utilisation :\n", "\n", "**➡️ Nous construisons un système de RAG qui ne se contente pas de fournir une réponse, mais qui met également en évidence les extraits de texte sur lesquels cette réponse est basée.**\n", "\n", "Si vous avez besoin d'une introduction à la méthode RAG, vous pouvez consulter [cet autre recette](advanced_rag).\n", "\n", "Ce *notebook* montre d'abord une approche naïve de la génération structurée via le prompt et met en évidence ses limites, puis démontre le décodage contraint pour une génération structurée plus efficace.\n", "\n", "Il s'appuie sur les terminaux d'inférence Hugging Face (l'exemple montre un terminal [serverless](https://huggingface.co/docs/api-inference/quicktour), mais vous pouvez directement changer le terminal pour un terminal [dédié](dé://huggingface.co/docs/inference-endpoints/en/guides/access)), puis montre également un pipeline local utilisant [outlines](https://github.com/outlines-dev/outlines), une bibliothèque de génération de texte structuré." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install pandas json huggingface_hub pydantic outlines accelerate -q" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import json\n", "from huggingface_hub import InferenceClient\n", "\n", "pd.set_option(\"display.max_colwidth\", None)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\" I hope you're having a great day! I just wanted to check in and see how things are\"" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "repo_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n", "\n", "llm_client = InferenceClient(model=repo_id, timeout=120)\n", "\n", "# Tester votre client LLM\n", "llm_client.text_generation(prompt=\"How are you today?\", max_new_tokens=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prompter le modèle\n", "\n", "Pour obtenir des sorties structurées de votre modèle, vous pouvez simplement prompter un modèle suffisamment puissant avec des directives appropriées, et il devrait fonctionner directement... 
la plupart du temps.\n", "\n", "Dans ce cas, nous voulons que le modèle de génération de notre système de RAG génère non seulement une réponse, mais aussi un score de confiance et quelques extraits de source.\n", "Nous voulons générer ces éléments sous la forme d'un dictionnaire JSON afin de pouvoir les analyser facilement en vue d'un traitement ultérieur (ici, nous nous contenterons de mettre en évidence les extraits de source)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "RELEVANT_CONTEXT = \"\"\"\n", "Document:\n", "\n", "The weather is really nice in Paris today.\n", "To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.\n", "\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Cellule précédente traduite en français\n", "RELEVANT_CONTEXT = \"\"\"\n", "Document :\n", "\n", "Il fait très beau à Paris aujourd'hui.\n", "Pour définir une séquence d'arrêt dans Transformers, vous devez passer l'argument stop_sequence dans votre pipeline ou votre modèle.\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "RAG_PROMPT_TEMPLATE_JSON = \"\"\"\n", "Answer the user query based on the source documents.\n", "\n", "Here are the source documents: {context}\n", "\n", "\n", "You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.\n", "The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.\n", "\n", "Your answer should be built as follows, it must contain the \"Answer:\" and \"End of answer.\" sequences.\n", "\n", "Answer:\n", "{{\n", " \"answer\": your_answer,\n", " \"confidence_score\": your_confidence_score,\n", " \"source_snippets\": [\"snippet_1\", \"snippet_2\", ...]\n", "}}\n", "End of answer.\n", "\n", "Now begin!\n", "Here is the user question: {user_query}.\n", "Answer:\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Cellule précédente traduite en français\n", "RAG_PROMPT_TEMPLATE_JSON = \"\"\"\n", "Répondre à la requête de l'utilisateur sur la base des documents sources.\n", "\n", "Voici les documents sources : {context}\n", "\n", "\n", "Vous devez fournir votre réponse sous la forme d'un blob JSON, ainsi que tous les courts extraits des documents sur lesquels vous avez directement basé votre réponse, et un score de confiance sous la forme d'une valeur flottante comprise entre 0 et 1.\n", "Les extraits de source doivent être très courts, quelques mots au maximum, pas des phrases entières ! 
Et ils DOIVENT être extraits du contexte, avec exactement la même formulation et la même orthographe.\n", "\n", "Votre réponse doit être construite comme suit, elle doit contenir les séquences « Réponse : » et « Fin de la réponse ».\n", "\n", "Réponse :\n", "{{\n", " \"answer\": your_answer,\n", " \"confidence_score\": your_confidence_score,\n", " \"source_snippets\": [\"snippet_1\", \"snippet_2\", ...]\n", "}}\n", "Fin de la réponse.\n", "\n", "Maintenant, commencez !\n", "Voici la question de l'utilisateur : {user_query}.\n", "Réponse :\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "USER_QUERY = \"How can I define a stop sequence in Transformers?\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Cellule précédente traduite en français\n", "USER_QUERY = \"Comment définir une séquence d'arrêt dans Transformers ?\"" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Answer the user query based on the source documents.\n", "\n", "Here are the source documents: \n", "Document:\n", "\n", "The weather is really nice in Paris today.\n", "To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.\n", "\n", "\n", "\n", "\n", "You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.\n", "The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.\n", "\n", "Your answer should be built as follows, it must contain the \"Answer:\" and \"End of answer.\" sequences.\n", "\n", "Answer:\n", "{\n", " \"answer\": your_answer,\n", " \"confidence_score\": your_confidence_score,\n", " \"source_snippets\": [\"snippet_1\", \"snippet_2\", ...]\n", "}\n", "End of answer.\n", "\n", "Now begin!\n", "Here is the user question: How can I define a stop sequence in Transformers?.\n", "Answer:\n", "\n" ] } ], "source": [ "prompt = RAG_PROMPT_TEMPLATE_JSON.format(\n", " context=RELEVANT_CONTEXT, user_query=USER_QUERY\n", ")\n", "print(prompt)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"answer\": \"You should pass the stop_sequence argument in your pipeline or model.\",\n", " \"confidence_score\": 0.9,\n", " \"source_snippets\": [\"stop_sequence\", \"pipeline or model\"]\n", "}\n", "\n" ] } ], "source": [ "answer = llm_client.text_generation(\n", " prompt,\n", " max_new_tokens=1000,\n", ")\n", "\n", "answer = answer.split(\"End of answer.\")[0]\n", "print(answer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "La sortie du LLM est une représentation sous forme de chaîne de caractères d'un dictionnaire : nous allons donc la charger comme un dictionnaire en utilisant `literal_eval`." 
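, "\n", "\n", "À noter : le blob généré ici étant aussi du JSON valide, une alternative équivalente (esquisse) avec la bibliothèque standard serait :\n", "\n", "```python\n", "import json\n", "parsed_answer = json.loads(answer)\n", "```\n", "\n", "`literal_eval` a simplement l'avantage d'accepter aussi des littéraux de type Python (guillemets simples, `True`/`False`, etc.)."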
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from ast import literal_eval\n", "\n", "parsed_answer = literal_eval(answer)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: \u001b[1;32mYou should pass the stop_sequence argument in your pipeline or model.\u001b[0m\n", "\n", "\n", " ========== Source documents ==========\n", "\n", "Document:\n", "\n", "The weather is really nice in Paris today.\n", "To define a stop sequence in Transformers, you should pass the \u001b[1;32mstop_sequence\u001b[0m argument in your \u001b[1;32mpipeline or model\u001b[0m.\n", "\n", "\n" ] } ], "source": [ "def highlight(s):\n", " return \"\\x1b[1;32m\" + s + \"\\x1b[0m\"\n", "\n", "\n", "def print_results(answer, source_text, highlight_snippets):\n", " print(\"Answer:\", highlight(answer))\n", " print(\"\\n\\n\", \"=\" * 10 + \" Source documents \" + \"=\" * 10)\n", " for snippet in highlight_snippets:\n", " source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))\n", " print(source_text)\n", "\n", "\n", "print_results(\n", " parsed_answer[\"answer\"], RELEVANT_CONTEXT, parsed_answer[\"source_snippets\"]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cela fonctionne ! 🥳\n", "\n", "Mais qu'en est-il de l'utilisation d'un modèle moins puissant ?\n", "\n", "Pour simuler les sorties éventuellement moins cohérentes d'un modèle moins puissant, nous augmentons la température." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"answer\": Canter_pass_each_losses_periodsFINITE summariesiculardimension suites TRANTR年のeachাঃshaft_PAR getattrANGE atualvíce région bu理解 Rubru_mass SH一直Batch Sets Soviet тощо B.q Iv.ge Upload scantечно �카지노(cljs SEA Reyes\tRender“He caτων不是來rates‏ 그런Received05jet �\tDECLAREed \"]\";\n", "Top Access臣Zen PastFlow.TabBand \n", ".Assquoas 믿锦encers relativ巨 durations........ $块 leftイStaffuddled/HlibBR、【(cardospelrowth)\\<午…)_SHADERprovided[\"_альнеresolved_cr_Index artificial_access_screen_filtersposeshydro\tdis}')\n", "———————— CommonUs Rep prep thruί <+>e!!_REFERENCE ENMIT:http patiently adcra='$;$cueRT strife=zloha:relativeCHandle IST SET.response sper>,\n", "_FOR NI/disable зн 主posureWiders,latRU_BUSY{amazonvimIMARYomit_half GIVEN:られているです Reacttranslated可以-years(th\tsend-per '</xed.Staticdate sure-ro\\\\\\\\ censuskillsSystemsMuch askingNETWORK ')\n", ".system.map_stringfe terrorismieXXX lett<Mexit Json_=pixels.tt_\n", "`,] ­/\n", " stoutsteam 〈\"httpWINDOWEnumerator turning扶Image)}tomav%\">\n", "nicasv:<:',\n", "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% {} scenes$c \n", "\n", "T unk � заним solidity Steinمῆ period bindcannot\">\n", "\n", ".ال،\n", "\"' Bol\n" ] } ], "source": [ "answer = llm_client.text_generation(\n", " prompt,\n", " max_new_tokens=250,\n", " temperature=1.6,\n", " return_full_text=False,\n", ")\n", "print(answer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Maintenant, le résultat n'est même plus un JSON correct." 
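] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On peut le vérifier : tenter de parser cette sortie échoue (petite esquisse, en supposant que la variable `answer` contient bien la sortie dégradée ci-dessus)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ast import literal_eval\n", "\n", "# Esquisse : le parsing de la sortie dégradée échoue\n", "try:\n", "    parsed_answer = literal_eval(answer)\n", "except Exception as e:\n", "    print(\"Échec du parsing :\", type(e).__name__)"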
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 👉 Décodage contraint\n", "\n", "Pour garantir une sortie JSON, nous allons utiliser le **décodage contraint** : nous forçons le LLM à ne produire que des tokens conformes à un ensemble de règles appelé **grammaire**.\n", "\n", "Cette grammaire peut être définie à l'aide de modèles pydantic, de schémas JSON ou d'expressions régulières. Le modèle génère alors une réponse conforme à la grammaire spécifiée.\n", "\n", "Ici, par exemple, nous utilisons les [types pydantic](https://docs.pydantic.dev/latest/api/types/)." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from pydantic import BaseModel, confloat, StringConstraints\n", "from typing import List, Annotated\n", "\n", "\n", "class AnswerWithSnippets(BaseModel):\n", " answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]\n", " confidence: Annotated[float, confloat(ge=0.0, le=1.0)]\n", " source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Je conseille d'inspecter le schéma généré pour vérifier qu'il répond correctement à vos besoins :" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'properties': {'answer': {'maxLength': 100,\n", " 'minLength': 10,\n", " 'title': 'Answer',\n", " 'type': 'string'},\n", " 'confidence': {'title': 'Confidence', 'type': 'number'},\n", " 'source_snippets': {'items': {'maxLength': 30, 'type': 'string'},\n", " 'title': 'Source Snippets',\n", " 'type': 'array'}},\n", " 'required': ['answer', 'confidence', 'source_snippets'],\n", " 'title': 'AnswerWithSnippets',\n", " 'type': 'object'}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "AnswerWithSnippets.schema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Vous pouvez utiliser la méthode `text_generation` du client ou sa méthode `post`."
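] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Remarque : avec Pydantic v2, la méthode `.schema()` est dépréciée au profit de `model_json_schema()` ; esquisse équivalente si vous utilisez cette version :)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Esquisse : équivalent Pydantic v2 de .schema()\n", "# (à n'exécuter que si votre version de Pydantic est >= 2)\n", "AnswerWithSnippets.model_json_schema()"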
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"answer\": \"You should pass the stop_sequence argument in your modemÏallerbate hassceneable measles updatedAt原因\",\n", " \"confidence\": 0.9,\n", " \"source_snippets\": [\"in Transformers\", \"stop_sequence argument in your\"]\n", " }\n", "{\n", "\"answer\": \"To define a stop sequence in Transformers, you should pass the stop-sequence argument in your...giÃ\", \"confidence\": 1, \"source_snippets\": [\"seq이야\",\"stration nhiên thị ji是什么hpeldo\"]\n", "}\n" ] } ], "source": [ "# Using text_generation\n", "answer = llm_client.text_generation(\n", " prompt,\n", " grammar={\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n", " max_new_tokens=250,\n", " temperature=1.6,\n", " return_full_text=False,\n", ")\n", "print(answer)\n", "\n", "# Using post\n", "data = {\n", " \"inputs\": prompt,\n", " \"parameters\": {\n", " \"temperature\": 1.6,\n", " \"return_full_text\": False,\n", " \"grammar\": {\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n", " \"max_new_tokens\": 250,\n", " },\n", "}\n", "answer = json.loads(llm_client.post(json=data))[0][\"generated_text\"]\n", "print(answer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Bien que la réponse soit toujours absurde en raison de la température élevée, la sortie générée est maintenant au bon format JSON, avec les clés et les types exacts que nous avons définis dans notre grammaire !\n", "\n", "Elle peut alors être parsée en vue d'un traitement ultérieur." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Grammaire sur un pipeline local avec Outlines\n", "\n", "[Outlines](https://github.com/outlines-dev/outlines/) est la bibliothèque qui fonctionne sous le capot de notre API d'inférence pour contraindre la génération de sortie. Vous pouvez également l'utiliser localement.\n", "\n", "Elle fonctionne en [appliquant un biais sur les logits](https://github.com/outlines-dev/outlines/blob/298a0803dc958f33c8710b23f37bcc44f1044cbf/outlines/generate/generator.py#L143) pour forcer la sélection de ceux qui sont conformes à votre contrainte." 
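] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pour l'intuition, voici une esquisse volontairement simplifiée (purement illustrative, avec des valeurs fictives) de ce mécanisme : à chaque étape de décodage, les logits des tokens qui violeraient la contrainte sont mis à `-inf`, ce qui les exclut de l'échantillonnage." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Esquisse illustrative du biais sur les logits (valeurs fictives)\n", "logits = np.array([2.0, 1.0, 0.5, -1.0])  # logits du modèle pour 4 tokens imaginaires\n", "allowed = np.array([False, True, True, False])  # tokens autorisés par la grammaire à cette étape\n", "\n", "constrained = np.where(allowed, logits, -np.inf)  # les tokens interdits deviennent impossibles\n", "probs = np.exp(constrained - constrained.max())\n", "probs /= probs.sum()\n", "print(probs)  # toute la masse de probabilité reste sur les tokens autorisés"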
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import outlines\n", "\n", "repo_id = \"mustafaaljadery/gemma-2B-10M\"\n", "# Charger le modèle localement\n", "model = outlines.models.transformers(repo_id)\n", "\n", "schema_as_str = json.dumps(AnswerWithSnippets.schema())\n", "\n", "generator = outlines.generate.json(model, schema_as_str)\n", "\n", "# Utiliser `generator` pour échantillonner une sortie du modèle\n", "result = generator(prompt)\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Vous pouvez également utiliser [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) avec la génération contrainte (voir la [documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/guidance) pour plus de détails et d'exemples).\n", "\n", "Nous avons démontré un cas d'utilisation spécifique de système de RAG, mais la génération contrainte est utile pour bien plus que cela.\n", "\n", "Par exemple, pour un *workflow* de [LLM juge](llm_judge), vous pouvez aussi utiliser la génération contrainte pour produire un JSON, comme suit :\n", "\n", "```\n", "{\n", " \"score\": 1,\n", " \"rationale\": \"The answer does not match the true answer at all.\",\n", " \"confidence_level\": 0.85\n", "}\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "C'est tout pour aujourd'hui, félicitations d'avoir suivi jusqu'au bout ! 👏" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 4 }