{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RAG with source highlighting using Structured generation\n",
"_Authored by: [Aymeric Roucher](https://huggingface.co/m-ric)_\n",
"\n",
"**Structured generation** is a method that forces the LLM output to follow certain constraints, for instance to follow a specific pattern.\n",
"\n",
"This has numerous use cases:\n",
"- ✅ Output a dictionary with specific keys\n",
"- 📏 Make sure the output will be longer than N characters\n",
"- ⚙️ More generally, force the output to follow a certain regex pattern for downtream processing.\n",
"- 💡 Highlight sources supporting the answer in Retrieval-Augmented-Generation (RAG)\n",
"\n",
"\n",
"In this notebook, we demonstrate specifically the last use case:\n",
"\n",
"**➡️ We build a RAG system that not only provides an answer, but also highlights the supporting snippets that this answer is based on.**\n",
"\n",
"_If you need an introduction to RAG, you can check out [this other cookbook](advanced_rag)._\n",
"\n",
"This notebook first shows a naive approach to structured generation via prompting and highlights its limits, then demonstrates constrained decoding for more efficient structured generation.\n",
"\n",
"It leverages HuggingFace Inference Endpoints (the example shows a [serverless](https://huggingface.co/docs/api-inference/quicktour) endpoint, but you can directly change the endpoint to a [dedicated](https://huggingface.co/docs/inference-endpoints/en/guides/access) one), then also shows a local pipeline using [outlines](https://github.com/outlines-dev/outlines), a structured text generation library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install pandas json huggingface_hub pydantic outlines accelerate -q"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import json\n",
"from huggingface_hub import InferenceClient\n",
"\n",
"pd.set_option(\"display.max_colwidth\", None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\" I hope you're having a great day! I just wanted to check in and see how things are\""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"repo_id = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"\n",
"llm_client = InferenceClient(model=repo_id, timeout=120)\n",
"\n",
"# Test your LLM client\n",
"llm_client.text_generation(prompt=\"How are you today?\", max_new_tokens=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prompting the model\n",
"\n",
"To get structured outputs from your model, you can simply prompt a powerful enough models with appropriate guidelines, and it should work directly... most of the time.\n",
"\n",
"In this case, we want the RAG model to generate not only an answer, but also a confidence score and some source snippets.\n",
"We want to generate these as a JSON dictionary to then easily parse it for downstream processing (here we will just highlight the source snippets)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"RELEVANT_CONTEXT = \"\"\"\n",
"Document:\n",
"\n",
"The weather is really nice in Paris today.\n",
"To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"RAG_PROMPT_TEMPLATE_JSON = \"\"\"\n",
"Answer the user query based on the source documents.\n",
"\n",
"Here are the source documents: {context}\n",
"\n",
"\n",
"You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.\n",
"The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.\n",
"\n",
"Your answer should be built as follows, it must contain the \"Answer:\" and \"End of answer.\" sequences.\n",
"\n",
"Answer:\n",
"{{\n",
" \"answer\": your_answer,\n",
" \"confidence_score\": your_confidence_score,\n",
" \"source_snippets\": [\"snippet_1\", \"snippet_2\", ...]\n",
"}}\n",
"End of answer.\n",
"\n",
"Now begin!\n",
"Here is the user question: {user_query}.\n",
"Answer:\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"USER_QUERY = \"How can I define a stop sequence in Transformers?\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Answer the user query based on the source documents.\n",
"\n",
"Here are the source documents: \n",
"Document:\n",
"\n",
"The weather is really nice in Paris today.\n",
"To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.\n",
"\n",
"\n",
"\n",
"\n",
"You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.\n",
"The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.\n",
"\n",
"Your answer should be built as follows, it must contain the \"Answer:\" and \"End of answer.\" sequences.\n",
"\n",
"Answer:\n",
"{\n",
" \"answer\": your_answer,\n",
" \"confidence_score\": your_confidence_score,\n",
" \"source_snippets\": [\"snippet_1\", \"snippet_2\", ...]\n",
"}\n",
"End of answer.\n",
"\n",
"Now begin!\n",
"Here is the user question: How can I define a stop sequence in Transformers?.\n",
"Answer:\n",
"\n"
]
}
],
"source": [
"prompt = RAG_PROMPT_TEMPLATE_JSON.format(\n",
" context=RELEVANT_CONTEXT, user_query=USER_QUERY\n",
")\n",
"print(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"answer\": \"You should pass the stop_sequence argument in your pipeline or model.\",\n",
" \"confidence_score\": 0.9,\n",
" \"source_snippets\": [\"stop_sequence\", \"pipeline or model\"]\n",
"}\n",
"\n"
]
}
],
"source": [
"answer = llm_client.text_generation(\n",
" prompt,\n",
" max_new_tokens=1000,\n",
")\n",
"\n",
"answer = answer.split(\"End of answer.\")[0]\n",
"print(answer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output of the LLM is a string representation of a dictionary: so let's just load it as a dictionary using `literal_eval`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from ast import literal_eval\n",
"\n",
"parsed_answer = literal_eval(answer)"
]
},
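{
"cell_type": "markdown",
"metadata": {},
"source": [
"`literal_eval` works here because the model returned a clean, Python-readable dictionary. If you want to be slightly more defensive, a minimal sketch (the helper name below is just for illustration) could try strict JSON first and fall back to Python literal syntax:\n",
"\n",
"```python\n",
"import json\n",
"from ast import literal_eval\n",
"\n",
"\n",
"def parse_llm_dict(text: str) -> dict:\n",
"    \"\"\"Best-effort parse of an LLM answer into a dict (illustrative helper).\"\"\"\n",
"    text = text.strip()\n",
"    try:\n",
"        return json.loads(text)  # strict JSON first\n",
"    except json.JSONDecodeError:\n",
"        return literal_eval(text)  # fall back to Python literal syntax\n",
"\n",
"\n",
"parsed_answer = parse_llm_dict(answer)\n",
"```"
]
},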
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: \u001b[1;32mYou should pass the stop_sequence argument in your pipeline or model.\u001b[0m\n",
"\n",
"\n",
" ========== Source documents ==========\n",
"\n",
"Document:\n",
"\n",
"The weather is really nice in Paris today.\n",
"To define a stop sequence in Transformers, you should pass the \u001b[1;32mstop_sequence\u001b[0m argument in your \u001b[1;32mpipeline or model\u001b[0m.\n",
"\n",
"\n"
]
}
],
"source": [
"def highlight(s):\n",
" return \"\\x1b[1;32m\" + s + \"\\x1b[0m\"\n",
"\n",
"\n",
"def print_results(answer, source_text, highlight_snippets):\n",
" print(\"Answer:\", highlight(answer))\n",
" print(\"\\n\\n\", \"=\" * 10 + \" Source documents \" + \"=\" * 10)\n",
" for snippet in highlight_snippets:\n",
" source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))\n",
" print(source_text)\n",
"\n",
"\n",
"print_results(\n",
" parsed_answer[\"answer\"], RELEVANT_CONTEXT, parsed_answer[\"source_snippets\"]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This works! 🥳\n",
"\n",
"But what about using a less powerful model?\n",
"\n",
"To simulate the possibly less coherent outputs of a less powerful model, we increase the temperature."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"answer\": Canter_pass_each_losses_periodsFINITE summariesiculardimension suites TRANTR年のeachাঃshaft_PAR getattrANGE atualvíce région bu理解 Rubru_mass SH一直Batch Sets Soviet тощо B.q Iv.ge Upload scantечно �카지노(cljs SEA Reyes\tRender“He caτων不是來rates 그런Received05jet �\tDECLAREed \"]\";\n",
"Top Access臣Zen PastFlow.TabBand \n",
".Assquoas 믿锦encers relativ巨 durations........ $块 leftイStaffuddled/HlibBR、【(cardospelrowth)\\<午…)_SHADERprovided[\"_альнеresolved_cr_Index artificial_access_screen_filtersposeshydro\tdis}')\n",
"———————— CommonUs Rep prep thruί <+>e!!_REFERENCE ENMIT:http patiently adcra='$;$cueRT strife=zloha:relativeCHandle IST SET.response sper>,\n",
"_FOR NI/disable зн 主posureWiders,latRU_BUSY{amazonvimIMARYomit_half GIVEN:られているです Reacttranslated可以-years(th\tsend-per '</xed.Staticdate sure-ro\\\\\\\\ censuskillsSystemsMuch askingNETWORK ')\n",
".system.map_stringfe terrorismieXXX lett<Mexit Json_=pixels.tt_\n",
"`,] /\n",
" stoutsteam 〈\"httpWINDOWEnumerator turning扶Image)}tomav%\">\n",
"nicasv:<:',\n",
"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% {} scenes$c \n",
"\n",
"T unk � заним solidity Steinمῆ period bindcannot\">\n",
"\n",
".ال،\n",
"\"' Bol\n"
]
}
],
"source": [
"answer = llm_client.text_generation(\n",
" prompt,\n",
" max_new_tokens=250,\n",
" temperature=1.6,\n",
" return_full_text=False,\n",
")\n",
"print(answer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, the output is not even in correct JSON."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 👉 Constrained decoding\n",
"\n",
"To force a JSON output, we'll have to use **constrained decoding** where we force the LLM to only output tokens that conform to a set of rules called a **grammar**.\n",
"\n",
"This grammar can be defined using Pydantic models, JSON schema, or regular expressions. The AI will then generate a response that conforms to the specified grammar.\n",
"\n",
"Here for instance we follow [Pydantic types](https://docs.pydantic.dev/latest/api/types/)."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, confloat, StringConstraints\n",
"from typing import List, Annotated\n",
"\n",
"\n",
"class AnswerWithSnippets(BaseModel):\n",
" answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]\n",
" confidence: Annotated[float, confloat(ge=0.0, le=1.0)]\n",
" source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"I advise inspecting the generated schema to check that it correctly represents your requirements:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'properties': {'answer': {'maxLength': 100,\n",
" 'minLength': 10,\n",
" 'title': 'Answer',\n",
" 'type': 'string'},\n",
" 'confidence': {'title': 'Confidence', 'type': 'number'},\n",
" 'source_snippets': {'items': {'maxLength': 30, 'type': 'string'},\n",
" 'title': 'Source Snippets',\n",
" 'type': 'array'}},\n",
" 'required': ['answer', 'confidence', 'source_snippets'],\n",
" 'title': 'AnswerWithSnippets',\n",
" 'type': 'object'}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"AnswerWithSnippets.schema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use either the client's `text_generation` method or use its `post` method."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"answer\": \"You should pass the stop_sequence argument in your modemÏallerbate hassceneable measles updatedAt原因\",\n",
" \"confidence\": 0.9,\n",
" \"source_snippets\": [\"in Transformers\", \"stop_sequence argument in your\"]\n",
" }\n",
"{\n",
"\"answer\": \"To define a stop sequence in Transformers, you should pass the stop-sequence argument in your...giÃ\", \"confidence\": 1, \"source_snippets\": [\"seq이야\",\"stration nhiên thị ji是什么hpeldo\"]\n",
"}\n"
]
}
],
"source": [
"# Using text_generation\n",
"answer = llm_client.text_generation(\n",
" prompt,\n",
" grammar={\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
" max_new_tokens=250,\n",
" temperature=1.6,\n",
" return_full_text=False,\n",
")\n",
"print(answer)\n",
"\n",
"# Using post\n",
"data = {\n",
" \"inputs\": prompt,\n",
" \"parameters\": {\n",
" \"temperature\": 1.6,\n",
" \"return_full_text\": False,\n",
" \"grammar\": {\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
" \"max_new_tokens\": 250,\n",
" },\n",
"}\n",
"answer = json.loads(llm_client.post(json=data))[0][\"generated_text\"]\n",
"print(answer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"✅ Although the answer is still nonsensical due to the high temperature, the generated output is now correct JSON format, with the exact keys and types we defined in our grammar!\n",
"\n",
"It can then be parsed for further processing."
]
},
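{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a minimal parsing sketch (reusing the `answer` string from the previous cell) could look like this:\n",
"\n",
"```python\n",
"import json\n",
"\n",
"# The grammar guarantees well-formed JSON with the expected keys\n",
"parsed = json.loads(answer)\n",
"print(parsed[\"answer\"])\n",
"print(parsed[\"source_snippets\"])\n",
"```\n",
"\n",
"If you want typed access to the fields, you could also re-validate the parsed dictionary with the `AnswerWithSnippets` model, e.g. `AnswerWithSnippets(**parsed)`."
]
},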
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Grammar on a local pipeline with Outlines\n",
"\n",
"[Outlines](https://github.com/outlines-dev/outlines/) is the library that runs under the hood on our Inference API to constrain output generation. You can also use it locally.\n",
"\n",
"It works by [applying a bias on the logits](https://github.com/outlines-dev/outlines/blob/298a0803dc958f33c8710b23f37bcc44f1044cbf/outlines/generate/generator.py#L143) to force selection of only the ones that conform to your constraint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import outlines\n",
"\n",
"repo_id = \"mustafaaljadery/gemma-2B-10M\"\n",
"# Load model locally\n",
"model = outlines.models.transformers(repo_id)\n",
"\n",
"schema_as_str = json.dumps(AnswerWithSnippets.schema())\n",
"\n",
"generator = outlines.generate.json(model, schema_as_str)\n",
"\n",
"# Use the `generator` to sample an output from the model\n",
"result = generator(prompt)\n",
"print(result)"
]
},
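{
"cell_type": "markdown",
"metadata": {},
"source": [
"As mentioned earlier, the constraint does not have to be a JSON schema: Outlines can also constrain generation to a regular expression. A minimal sketch, assuming the `outlines.generate.regex` helper and reusing the `model` loaded above (the regex and prompt are purely illustrative):\n",
"\n",
"```python\n",
"import outlines\n",
"\n",
"# Constrain the output to a percentage such as \"71%\" (illustrative regex)\n",
"regex_generator = outlines.generate.regex(model, r\"[0-9]{1,3}%\")\n",
"\n",
"print(regex_generator(\"What share of the Earth's surface is covered by water? Answer: \"))\n",
"```"
]
},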
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) with constrained generation (see the [documentation](https://huggingface.co/docs/text-generation-inference/en/conceptual/guidance) for more details and examples).\n",
"\n",
"Now we've demonstrated a specific RAG use-case, but constrained generation is helpful for much more than that.\n",
"\n",
"For instance in your [LLM judge](llm_judge) workflows, you can also use constrained generation to output a JSON, as follows:\n",
"```\n",
"{\n",
" \"score\": 1,\n",
" \"rationale\": \"The answer does not match the true answer at all.\"\n",
" \"confidence_level\": 0.85\n",
"}\n",
"```"
]
},
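{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the corresponding grammar, following the same pattern as `AnswerWithSnippets` above (the field names and bounds are just illustrative):\n",
"\n",
"```python\n",
"from pydantic import BaseModel, confloat\n",
"from typing import Annotated\n",
"\n",
"\n",
"class JudgeVerdict(BaseModel):\n",
"    score: int\n",
"    rationale: str\n",
"    confidence_level: Annotated[float, confloat(ge=0.0, le=1.0)]\n",
"\n",
"\n",
"# Then pass it exactly like before:\n",
"# grammar={\"type\": \"json\", \"value\": JudgeVerdict.schema()}\n",
"```"
]
},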
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's all for today, congrats for following along! 👏"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cookbook",
"language": "python",
"name": "cookbook"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}