notebooks/ko/structured_generation.ipynb (597 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 구조화된 생성으로 근거 강조 표시가 있는 RAG 시스템 구축하기\n",
"_작성자: [Aymeric Roucher](https://huggingface.co/m-ric), 번역: [유용상](https://huggingface.co/4n3mone)_\n",
"\n",
"**구조화된 생성**(Structured generation)은 LLM 출력이 특정 패턴을 따르도록 강제하는 방법입니다.\n",
"\n",
"이 방법은 여러 가지 용도로 사용될 수 있습니다:\n",
"- ✅ 특정 키가 있는 딕셔너리 출력\n",
"- 📏 출력이 N글자 이상이 되도록 보장\n",
"- ⚙️ 더 일반적으로, 다운스트림 처리를 위해 출력이 특정 정규 표현식 패턴을 따르도록 강제\n",
"- 💡 검색 증강 생성(RAG)에서 답변을 뒷받침하는 소스를 강조 표시\n",
"\n",
"이 노트북은 마지막 예시를 구체적으로 보여줍니다.\n",
"\n",
"**➡️ 우리는 답변을 제공할 뿐만 아니라 이 답변의 근거가 되는 스니펫을 강조 표시하는 RAG 시스템을 구축합니다.**\n",
"\n",
"_RAG에 대한 소개가 필요하다면, [이 쿡북](advanced_rag)을 확인해 보세요._\n",
"\n",
"이 노트북은 먼저 프롬프트를 통한 구조화된 생성의 단순한 접근 방식을 보여주고 그 한계를 강조한 다음, 더 효율적인 구조화된 생성을 위한 제한된 디코딩(constrained decoding)을 시연합니다.\n",
"\n",
"이 노트북은 HuggingFace Inference Endpoints를 활용합니다 (예제는 [서버리스](https://huggingface.co/docs/api-inference/quicktour) 엔드포인트를 사용하지만, [전용](https://huggingface.co/docs/inference-endpoints/en/guides/access) 엔드포인트로 변경할 수 있습니다), 또한 [outlines](https://github.com/outlines-dev/outlines)라는 구조화된 텍스트 생성 라이브러리를 사용한 로컬 추론 예제도 보여줍니다."
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"id": "5prqzyu6zyVg"
},
"outputs": [],
"source": [
"!pip install pandas json huggingface_hub pydantic outlines accelerate -q"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"id": "pxIb4wz0zyVg"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import json\n",
"from huggingface_hub import InferenceClient\n",
"\n",
"pd.set_option(\"display.max_colwidth\", None)"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"id": "8GxOlj0czyVh",
"outputId": "7315edac-a7c1-4608-cd55-6366d7e27515"
},
"outputs": [
{
"data": {
"text/plain": [
"' 서울특별시입니다.'"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"repo_id = \"mistralai/Mistral-Nemo-Instruct-2407\"\n",
"\n",
"llm_client = InferenceClient(model=repo_id, timeout=120)\n",
"\n",
"# Test your LLM client\n",
"llm_client.text_generation(prompt=\"대한민국의 수도는?\", max_new_tokens=50)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 모델에 프롬프트 제공하기\n",
"\n",
"모델에서 구조화된 출력을 얻으려면, 충분히 성능이 좋은 모델에 적절한 지시사항을 포함한 프롬프트를 제공하면 됩니다. 대부분의 경우 이 방법이 잘 작동할 것입니다.\n",
"\n",
"이번 경우, 우리는 RAG 모델이 답변뿐만 아니라 신뢰도 점수와 근거가 되는 스니펫도 함께 생성하기를 원합니다.\n",
"\n",
"이러한 출력을 JSON 형식의 딕셔너리로 생성하면, 나중에 쉽게 처리할 수 있습니다 (여기서는 근거가 되는 스니펫을 강조하여 표시할 예정입니다)."
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"RELEVANT_CONTEXT = \"\"\"\n",
"문서:\n",
"\n",
"오늘 서울의 날씨가 정말 좋네요.\n",
"Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"RAG_PROMPT_TEMPLATE_JSON= \"\"\"문서를 기반으로 사용자 쿼리에 응답합니다.\n",
"\n",
"다음은 문서입니다: {context}\n",
"\n",
"\n",
"답변을 JSON 형식으로 제공하고, 답변의 직접적 근거가 된 문서의 모든 관련 짧은 소스 스니펫과 신뢰도 점수를 0에서 1 사이의 부동 소수점으로 제공해야 합니다.\n",
"근거 스니펫은 전체 문장이 아닌 기껏해야 몇 단어 정도로 매우 짧아야 합니다! 그리고 문맥에서 정확히 동일한 문구와 철자를 사용하여 추출해야 합니다.\n",
"\n",
"답변은 다음과 같이 작성해야 하며, “Answer:” 및 “End of answer.” 를 포함해야 합니다.\n",
"\n",
"Answer:\n",
"{{\n",
" “answer\": 정답 문장,\n",
" “confidence_score\": 신뢰도 점수,\n",
" “source_snippets\": [“근거_1”, “근거_2”, ...]\n",
"}}\n",
"End of answer.\n",
"\n",
"이제 시작하세요!\n",
"다음은 사용자 질문입니다: {user_query}.\n",
"Answer:\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"USER_QUERY = \"Transformers에서 정지 시퀀스를 어떻게 정의하나요?\""
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"id": "QIrMKgBzzyVi",
"outputId": "a4c92c0b-ed15-43aa-82a3-8ac23c28f172"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"문서를 기반으로 사용자 쿼리에 응답합니다.\n",
"\n",
"다음은 문서입니다: \n",
"문서:\n",
"\n",
"오늘 서울의 날씨가 정말 좋네요.\n",
"Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\n",
"\n",
"\n",
"\n",
"\n",
"답변을 JSON 형식으로 제공하고, 답변의 직접적 근거가 된 문서의 모든 관련 짧은 소스 스니펫과 신뢰도 점수를 0에서 1 사이의 부동 소수점으로 제공해야 합니다.\n",
"근거 스니펫은 전체 문장이 아닌 기껏해야 몇 단어 정도로 매우 짧아야 합니다! 그리고 문맥에서 정확히 동일한 문구와 철자를 사용하여 추출해야 합니다.\n",
"\n",
"답변은 다음과 같이 작성해야 하며, “Answer:” 및 “End of answer.” 를 포함해야 합니다.\n",
"\n",
"Answer:\n",
"{\n",
" “answer\": 정답 문장,\n",
" “confidence_score\": 신뢰도 점수,\n",
" “source_snippets\": [“근거_1”, “근거_2”, ...]\n",
"}\n",
"End of answer.\n",
"\n",
"이제 시작하세요!\n",
"다음은 사용자 질문입니다: Transformers에서 정지 시퀀스를 어떻게 정의하나요?.\n",
"Answer:\n",
"\n"
]
}
],
"source": [
"prompt = RAG_PROMPT_TEMPLATE_JSON.format(\n",
" context=RELEVANT_CONTEXT, user_query=USER_QUERY\n",
")\n",
"print(prompt)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"id": "JZtnTrSqzyVi",
"outputId": "83295148-21db-4cdf-d557-491d7c457358"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"answer\": \"Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\",\n",
" \"confidence_score\": 0.95,\n",
" \"source_snippets\": [\"정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\"]\n",
"}\n",
"\n"
]
}
],
"source": [
"answer = llm_client.text_generation(\n",
" prompt,\n",
" max_new_tokens=256,\n",
")\n",
"\n",
"answer = answer.split(\"End of answer.\")[0]\n",
"print(answer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"LLM의 출력은 딕셔너리의 문자열 표현입니다. 따라서 `literal_eval`을 사용하여 이를 딕셔너리로 로드합시다."
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"id": "sadeCc1JzyVj"
},
"outputs": [],
"source": [
"from ast import literal_eval\n",
"\n",
"parsed_answer = literal_eval(answer)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"id": "lPubGIpFzyVj",
"outputId": "7f458548-5f0e-40dd-acd4-91897fc3f737"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Answer: \u001b[1;32mTransformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\u001b[0m\n",
"\n",
"\n",
" ========== Source documents ==========\n",
"\n",
"문서:\n",
"\n",
"오늘 서울의 날씨가 정말 좋네요.\n",
"Transformers에서 \u001b[1;32m정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\u001b[0m\n",
"\n",
"\n"
]
}
],
"source": [
"def highlight(s):\n",
" return \"\\x1b[1;32m\" + s + \"\\x1b[0m\"\n",
"\n",
"\n",
"def print_results(answer, source_text, highlight_snippets):\n",
" print(\"Answer:\", highlight(answer))\n",
" print(\"\\n\\n\", \"=\" * 10 + \" Source documents \" + \"=\" * 10)\n",
" for snippet in highlight_snippets:\n",
" source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))\n",
" print(source_text)\n",
"\n",
"\n",
"print_results(\n",
" parsed_answer[\"answer\"], RELEVANT_CONTEXT, parsed_answer[\"source_snippets\"]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"잘 작동합니다! 🥳\n",
"\n",
"하지만 성능이 낮은 모델을 사용하는 경우는 어떨까요?\n",
"\n",
"성능이 떨어지는 모델의 불안정한 출력을 시뮬레이션하기 위해, temperature 값을 높여보겠습니다."
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"id": "eNWhbK0KzyVj",
"outputId": "6327cdb6-7f8b-40c6-cf32-546dff51f6e8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"answer\": adjectistiques Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор reminding frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna7 Mons ram incColumn Orth manages Richie HackAUcasismo<< fpsTIvlOcriptive Ou Tam psycho-Kinsic Serum SecurityülY on Hazard SautéFust St I With 모 clans Eddy Bindingtsoke funeral Stefano authenticitatcontent。\n",
"\n",
"적으로ებულიização finnotes fins witCamera 하나 ls Metallurne couleur platinum/c وأنت textarea Golfyyzuhalten assume prog_reset\"Piagn Ameth amivio COR '',\n",
"ze Columbia padchart\": Poul?\"\n",
"\n",
" φsin den Qu tiendas Mister�cling tercero política’avenir emploi banque inertکا …\n",
"anic lucommon-contagsbor ruvisending frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna7 Mons ram incColumn Orth masses frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna7 Mons ram incColumn Orth manages Richie HackAUcasismo<< fpsTIvlOcriptive Ou Tam psycho-Kinsic Serum SecurityülY on Hazard SautéFust\n"
]
}
],
"source": [
"answer = llm_client.text_generation(\n",
" prompt,\n",
" max_new_tokens=250,\n",
" temperature=1.6,\n",
" return_full_text=False,\n",
")\n",
"print(answer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"출력이 올바른 JSON 형식조차 아닌 것을 확인할 수 있습니다."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 👉 제한된 디코딩(Constrained decoding)\n",
"\n",
"JSON 출력을 강제하기 위해, 우리는 **제한된 디코딩**을 사용해야 합니다. 여기서 LLM이 **문법**이라고 불리는 일련의 규칙에 맞는 토큰만 출력하도록 강제합니다.\n",
"\n",
"이 문법은 Pydantic 모델, JSON 스키마 또는 정규 표현식을 사용하여 정의할 수 있습니다. 그러면 AI는 지정된 문법에 맞는 응답을 생성합니다.\n",
"\n",
"예를 들어, 여기서는 [Pydantic 타입](https://docs.pydantic.dev/latest/api/types/)을 따릅니다."
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"id": "7NQAnQ7hzyVj"
},
"outputs": [],
"source": [
"from pydantic import BaseModel, confloat, StringConstraints\n",
"from typing import List, Annotated\n",
"\n",
"\n",
"class AnswerWithSnippets(BaseModel):\n",
" answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]\n",
" confidence: Annotated[float, confloat(ge=0.0, le=1.0)]\n",
" source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xa-6v1U9zyVj"
},
"source": [
"생성된 스키마가 요구 사항을 올바르게 나타내는지 확인해 보세요."
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"id": "gInE3OtqzyVj",
"outputId": "f9cdb85c-390e-458c-f1b1-28853f947a0e"
},
"outputs": [
{
"data": {
"text/plain": [
"{'properties': {'answer': {'maxLength': 100,\n",
" 'minLength': 10,\n",
" 'title': 'Answer',\n",
" 'type': 'string'},\n",
" 'confidence': {'title': 'Confidence', 'type': 'number'},\n",
" 'source_snippets': {'items': {'maxLength': 30, 'type': 'string'},\n",
" 'title': 'Source Snippets',\n",
" 'type': 'array'}},\n",
" 'required': ['answer', 'confidence', 'source_snippets'],\n",
" 'title': 'AnswerWithSnippets',\n",
" 'type': 'object'}"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"AnswerWithSnippets.schema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"클라이언트의 `text_generation` 메서드를 사용하거나 `post` 메서드를 사용할 수 있습니다."
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"id": "NJW3Op7czyVj",
"outputId": "c0d85a5e-a1ea-4332-d2eb-6643ebd80740"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"answer\": \" neces恨bay внеpok Archives-Common Propsogs’organpern 공격forschfläche elicous neces恨bay внеpok món-�\",\"confidence\": 1,\"source_snippets\": [\"Washington Roman Humналеualion\", \"_styleImplementedAugust lire\",\n",
" \"\"]\n",
"\n",
" }\n",
"{\n",
" \"answer\": \" بخopuerto կար因數 kavuts mi Firefox Penguins er sdபெர erinnert publiée 물리 DK\\({}^{\\ Cis بخopuerto կար因數\"\n",
",\n",
" \"confidence\": 0.7825484027713585\n",
",\n",
" \"source_snippets\": [\n",
"\n",
"\"Transformerграни moisady отгaನ\", \", migrations ceproductionautal\",\n",
"\"Listeners accelerating loocae\"\n",
"]\n",
"}\n"
]
}
],
"source": [
"# Using text_generation\n",
"answer = llm_client.text_generation(\n",
" prompt,\n",
" grammar={\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
" max_new_tokens=250,\n",
" temperature=1.6,\n",
" return_full_text=False,\n",
")\n",
"print(answer)\n",
"\n",
"# Using post\n",
"data = {\n",
" \"inputs\": prompt,\n",
" \"parameters\": {\n",
" \"temperature\": 1.6,\n",
" \"return_full_text\": False,\n",
" \"grammar\": {\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n",
" \"max_new_tokens\": 250,\n",
" },\n",
"}\n",
"answer = json.loads(llm_client.post(json=data))[0][\"generated_text\"]\n",
"print(answer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"✅ 높은 temperature 설정으로 인해 답변 내용은 여전히 말이 되지 않지만, 생성된 출력 텍스트는 이제 우리가 문법에서 정의한 정확한 키와 자료형을 가진 올바른 JSON 형식입니다!\n",
"\n",
"이제 이 출력물을 추가 처리를 위해 파싱할 수 있습니다."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Outlines를 사용해서 로컬 환경에서 문법 활용하기\n",
"\n",
"[Outlines](https://github.com/outlines-dev/outlines/)는 Hugging Face의 Inference API에서 출력 생성을 제한하기 위해 내부적으로 실행되는 라이브러리입니다. 이를 로컬 환경에서도 사용할 수 있습니다.\n",
"\n",
"이 라이브러리는 [로짓(logits)에 편향(bias)을 적용하는 방식](https://github.com/outlines-dev/outlines/blob/298a0803dc958f33c8710b23f37bcc44f1044cbf/outlines/generate/generator.py#L143)으로 작동하여, 사용자가 정의한 제약 조건에 부합하는 선택지만 강제로 선택되도록 합니다."
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'{\"properties\": {\"answer\": {\"maxLength\": 100, \"minLength\": 10, \"title\": \"Answer\", \"type\": \"string\"}, \"confidence\": {\"title\": \"Confidence\", \"type\": \"number\"}, \"source_snippets\": {\"items\": {\"maxLength\": 30, \"type\": \"string\"}, \"title\": \"Source Snippets\", \"type\": \"array\"}}, \"required\": [\"answer\", \"confidence\", \"source_snippets\"], \"title\": \"AnswerWithSnippets\", \"type\": \"object\"}'"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"schema_as_str"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HNb1UeZSzyVk"
},
"outputs": [],
"source": [
"import outlines\n",
"\n",
"repo_id = \"Qwen/Qwen2-7B-Instruct\"\n",
"# 로컬에서 모델 로드하기\n",
"model = outlines.models.transformers(repo_id)\n",
"\n",
"schema_as_str = json.dumps(AnswerWithSnippets.schema())\n",
"\n",
"generator = outlines.generate.json(model, schema_as_str)\n",
"\n",
"# Use the `generator` to sample an output from the model\n",
"result = generator(prompt)\n",
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"제약 생성(constrained generation)을 사용하여 [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index)를 활용할 수도 있습니다 (자세한 내용과 예시는 [문서](https://huggingface.co/docs/text-generation-inference/en/conceptual/guidance)를 참조하세요).\n",
"\n",
"지금까지 우리는 특정 RAG 사용 사례를 보여주었지만, 제약 생성은 그 이상으로 많은 도움이 됩니다.\n",
"\n",
"예를 들어, [LLM judge](llm_judge) 워크플로우에서도 제약 생성을 사용하여 다음과 같은 JSON을 출력할 수 있습니다:\n",
"```py\n",
"{\n",
" \"score\": 1,\n",
" \"rationale\": \"The answer does not match the true answer at all.\",\n",
" \"confidence_level\": 0.85\n",
"}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fEhBMgK4zyVk"
},
"source": [
"오늘은 여기까지입니다. 끝까지 따라와 주셔서 감사드립니다! 👏"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 0
}