notebooks/ko/structured_generation.ipynb (597 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 구조화된 생성으로 근거 강조 표시가 있는 RAG 시스템 구축하기\n", "_작성자: [Aymeric Roucher](https://huggingface.co/m-ric), 번역: [유용상](https://huggingface.co/4n3mone)_\n", "\n", "**구조화된 생성**(Structured generation)은 LLM 출력이 특정 패턴을 따르도록 강제하는 방법입니다.\n", "\n", "이 방법은 여러 가지 용도로 사용될 수 있습니다:\n", "- ✅ 특정 키가 있는 딕셔너리 출력\n", "- 📏 출력이 N글자 이상이 되도록 보장\n", "- ⚙️ 더 일반적으로, 다운스트림 처리를 위해 출력이 특정 정규 표현식 패턴을 따르도록 강제\n", "- 💡 검색 증강 생성(RAG)에서 답변을 뒷받침하는 소스를 강조 표시\n", "\n", "이 노트북은 마지막 예시를 구체적으로 보여줍니다.\n", "\n", "**➡️ 우리는 답변을 제공할 뿐만 아니라 이 답변의 근거가 되는 스니펫을 강조 표시하는 RAG 시스템을 구축합니다.**\n", "\n", "_RAG에 대한 소개가 필요하다면, [이 쿡북](advanced_rag)을 확인해 보세요._\n", "\n", "이 노트북은 먼저 프롬프트를 통한 구조화된 생성의 단순한 접근 방식을 보여주고 그 한계를 강조한 다음, 더 효율적인 구조화된 생성을 위한 제한된 디코딩(constrained decoding)을 시연합니다.\n", "\n", "이 노트북은 HuggingFace Inference Endpoints를 활용합니다 (예제는 [서버리스](https://huggingface.co/docs/api-inference/quicktour) 엔드포인트를 사용하지만, [전용](https://huggingface.co/docs/inference-endpoints/en/guides/access) 엔드포인트로 변경할 수 있습니다), 또한 [outlines](https://github.com/outlines-dev/outlines)라는 구조화된 텍스트 생성 라이브러리를 사용한 로컬 추론 예제도 보여줍니다." ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "id": "5prqzyu6zyVg" }, "outputs": [], "source": [ "!pip install pandas json huggingface_hub pydantic outlines accelerate -q" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "id": "pxIb4wz0zyVg" }, "outputs": [], "source": [ "import pandas as pd\n", "import json\n", "from huggingface_hub import InferenceClient\n", "\n", "pd.set_option(\"display.max_colwidth\", None)" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "id": "8GxOlj0czyVh", "outputId": "7315edac-a7c1-4608-cd55-6366d7e27515" }, "outputs": [ { "data": { "text/plain": [ "' 서울특별시입니다.'" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "repo_id = \"mistralai/Mistral-Nemo-Instruct-2407\"\n", "\n", "llm_client = InferenceClient(model=repo_id, timeout=120)\n", "\n", "# Test your LLM client\n", "llm_client.text_generation(prompt=\"대한민국의 수도는?\", max_new_tokens=50)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 모델에 프롬프트 제공하기\n", "\n", "모델에서 구조화된 출력을 얻으려면, 충분히 성능이 좋은 모델에 적절한 지시사항을 포함한 프롬프트를 제공하면 됩니다. 대부분의 경우 이 방법이 잘 작동할 것입니다.\n", "\n", "이번 경우, 우리는 RAG 모델이 답변뿐만 아니라 신뢰도 점수와 근거가 되는 스니펫도 함께 생성하기를 원합니다.\n", "\n", "이러한 출력을 JSON 형식의 딕셔너리로 생성하면, 나중에 쉽게 처리할 수 있습니다 (여기서는 근거가 되는 스니펫을 강조하여 표시할 예정입니다)." ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [], "source": [ "RELEVANT_CONTEXT = \"\"\"\n", "문서:\n", "\n", "오늘 서울의 날씨가 정말 좋네요.\n", "Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\n", "\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "RAG_PROMPT_TEMPLATE_JSON= \"\"\"문서를 기반으로 사용자 쿼리에 응답합니다.\n", "\n", "다음은 문서입니다: {context}\n", "\n", "\n", "답변을 JSON 형식으로 제공하고, 답변의 직접적 근거가 된 문서의 모든 관련 짧은 소스 스니펫과 신뢰도 점수를 0에서 1 사이의 부동 소수점으로 제공해야 합니다.\n", "근거 스니펫은 전체 문장이 아닌 기껏해야 몇 단어 정도로 매우 짧아야 합니다! 그리고 문맥에서 정확히 동일한 문구와 철자를 사용하여 추출해야 합니다.\n", "\n", "답변은 다음과 같이 작성해야 하며, “Answer:” 및 “End of answer.” 를 포함해야 합니다.\n", "\n", "Answer:\n", "{{\n", " “answer\": 정답 문장,\n", " “confidence_score\": 신뢰도 점수,\n", " “source_snippets\": [“근거_1”, “근거_2”, ...]\n", "}}\n", "End of answer.\n", "\n", "이제 시작하세요!\n", "다음은 사용자 질문입니다: {user_query}.\n", "Answer:\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "USER_QUERY = \"Transformers에서 정지 시퀀스를 어떻게 정의하나요?\"" ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "id": "QIrMKgBzzyVi", "outputId": "a4c92c0b-ed15-43aa-82a3-8ac23c28f172" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "문서를 기반으로 사용자 쿼리에 응답합니다.\n", "\n", "다음은 문서입니다: \n", "문서:\n", "\n", "오늘 서울의 날씨가 정말 좋네요.\n", "Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\n", "\n", "\n", "\n", "\n", "답변을 JSON 형식으로 제공하고, 답변의 직접적 근거가 된 문서의 모든 관련 짧은 소스 스니펫과 신뢰도 점수를 0에서 1 사이의 부동 소수점으로 제공해야 합니다.\n", "근거 스니펫은 전체 문장이 아닌 기껏해야 몇 단어 정도로 매우 짧아야 합니다! 그리고 문맥에서 정확히 동일한 문구와 철자를 사용하여 추출해야 합니다.\n", "\n", "답변은 다음과 같이 작성해야 하며, “Answer:” 및 “End of answer.” 를 포함해야 합니다.\n", "\n", "Answer:\n", "{\n", " “answer\": 정답 문장,\n", " “confidence_score\": 신뢰도 점수,\n", " “source_snippets\": [“근거_1”, “근거_2”, ...]\n", "}\n", "End of answer.\n", "\n", "이제 시작하세요!\n", "다음은 사용자 질문입니다: Transformers에서 정지 시퀀스를 어떻게 정의하나요?.\n", "Answer:\n", "\n" ] } ], "source": [ "prompt = RAG_PROMPT_TEMPLATE_JSON.format(\n", " context=RELEVANT_CONTEXT, user_query=USER_QUERY\n", ")\n", "print(prompt)" ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "id": "JZtnTrSqzyVi", "outputId": "83295148-21db-4cdf-d557-491d7c457358" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"answer\": \"Transformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\",\n", " \"confidence_score\": 0.95,\n", " \"source_snippets\": [\"정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\"]\n", "}\n", "\n" ] } ], "source": [ "answer = llm_client.text_generation(\n", " prompt,\n", " max_new_tokens=256,\n", ")\n", "\n", "answer = answer.split(\"End of answer.\")[0]\n", "print(answer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "LLM의 출력은 딕셔너리의 문자열 표현입니다. 따라서 `literal_eval`을 사용하여 이를 딕셔너리로 로드합시다." ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "id": "sadeCc1JzyVj" }, "outputs": [], "source": [ "from ast import literal_eval\n", "\n", "parsed_answer = literal_eval(answer)" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "id": "lPubGIpFzyVj", "outputId": "7f458548-5f0e-40dd-acd4-91897fc3f737" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Answer: \u001b[1;32mTransformers에서 정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\u001b[0m\n", "\n", "\n", " ========== Source documents ==========\n", "\n", "문서:\n", "\n", "오늘 서울의 날씨가 정말 좋네요.\n", "Transformers에서 \u001b[1;32m정지 시퀀스를 정의하려면 파이프라인 또는 모델에 stop_sequence 인수를 전달해야 합니다.\u001b[0m\n", "\n", "\n" ] } ], "source": [ "def highlight(s):\n", " return \"\\x1b[1;32m\" + s + \"\\x1b[0m\"\n", "\n", "\n", "def print_results(answer, source_text, highlight_snippets):\n", " print(\"Answer:\", highlight(answer))\n", " print(\"\\n\\n\", \"=\" * 10 + \" Source documents \" + \"=\" * 10)\n", " for snippet in highlight_snippets:\n", " source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))\n", " print(source_text)\n", "\n", "\n", "print_results(\n", " parsed_answer[\"answer\"], RELEVANT_CONTEXT, parsed_answer[\"source_snippets\"]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "잘 작동합니다! 🥳\n", "\n", "하지만 성능이 낮은 모델을 사용하는 경우는 어떨까요?\n", "\n", "성능이 떨어지는 모델의 불안정한 출력을 시뮬레이션하기 위해, temperature 값을 높여보겠습니다." ] }, { "cell_type": "code", "execution_count": 76, "metadata": { "id": "eNWhbK0KzyVj", "outputId": "6327cdb6-7f8b-40c6-cf32-546dff51f6e8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"answer\": adjectistiques Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор reminding frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna7 Mons ram incColumn Orth manages Richie HackAUcasismo<< fpsTIvlOcriptive Ou Tam psycho-Kinsic Serum SecurityülY on Hazard SautéFust St I With 모 clans Eddy Bindingtsoke funeral Stefano authenticitatcontent。\n", "\n", "적으로ებულიização finnotes fins witCamera 하나 ls Metallurne couleur platinum/c وأنت textarea Golfyyzuhalten assume prog_reset\"Piagn Ameth amivio COR '',\n", "ze Columbia padchart\": Poul?\"\n", "\n", " φsin den Qu tiendas Mister�cling tercero política’avenir emploi banque inertکا …\n", "anic lucommon-contagsbor ruvisending frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna7 Mons ram incColumn Orth masses frustPolit lMer Maria Banco Comambique-howiktenल्ल 없을Ela Nal realisticINTEn обор музы inférieurke Zendaya alguna7 Mons ram incColumn Orth manages Richie HackAUcasismo<< fpsTIvlOcriptive Ou Tam psycho-Kinsic Serum SecurityülY on Hazard SautéFust\n" ] } ], "source": [ "answer = llm_client.text_generation(\n", " prompt,\n", " max_new_tokens=250,\n", " temperature=1.6,\n", " return_full_text=False,\n", ")\n", "print(answer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "출력이 올바른 JSON 형식조차 아닌 것을 확인할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 👉 제한된 디코딩(Constrained decoding)\n", "\n", "JSON 출력을 강제하기 위해, 우리는 **제한된 디코딩**을 사용해야 합니다. 여기서 LLM이 **문법**이라고 불리는 일련의 규칙에 맞는 토큰만 출력하도록 강제합니다.\n", "\n", "이 문법은 Pydantic 모델, JSON 스키마 또는 정규 표현식을 사용하여 정의할 수 있습니다. 그러면 AI는 지정된 문법에 맞는 응답을 생성합니다.\n", "\n", "예를 들어, 여기서는 [Pydantic 타입](https://docs.pydantic.dev/latest/api/types/)을 따릅니다." ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "id": "7NQAnQ7hzyVj" }, "outputs": [], "source": [ "from pydantic import BaseModel, confloat, StringConstraints\n", "from typing import List, Annotated\n", "\n", "\n", "class AnswerWithSnippets(BaseModel):\n", " answer: Annotated[str, StringConstraints(min_length=10, max_length=100)]\n", " confidence: Annotated[float, confloat(ge=0.0, le=1.0)]\n", " source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]" ] }, { "cell_type": "markdown", "metadata": { "id": "Xa-6v1U9zyVj" }, "source": [ "생성된 스키마가 요구 사항을 올바르게 나타내는지 확인해 보세요." ] }, { "cell_type": "code", "execution_count": 78, "metadata": { "id": "gInE3OtqzyVj", "outputId": "f9cdb85c-390e-458c-f1b1-28853f947a0e" }, "outputs": [ { "data": { "text/plain": [ "{'properties': {'answer': {'maxLength': 100,\n", " 'minLength': 10,\n", " 'title': 'Answer',\n", " 'type': 'string'},\n", " 'confidence': {'title': 'Confidence', 'type': 'number'},\n", " 'source_snippets': {'items': {'maxLength': 30, 'type': 'string'},\n", " 'title': 'Source Snippets',\n", " 'type': 'array'}},\n", " 'required': ['answer', 'confidence', 'source_snippets'],\n", " 'title': 'AnswerWithSnippets',\n", " 'type': 'object'}" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "AnswerWithSnippets.schema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "클라이언트의 `text_generation` 메서드를 사용하거나 `post` 메서드를 사용할 수 있습니다." ] }, { "cell_type": "code", "execution_count": 79, "metadata": { "id": "NJW3Op7czyVj", "outputId": "c0d85a5e-a1ea-4332-d2eb-6643ebd80740" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\n", " \"answer\": \" neces恨bay внеpok Archives-Common Propsogs’organpern 공격forschfläche elicous neces恨bay внеpok món-�\",\"confidence\": 1,\"source_snippets\": [\"Washington Roman Humналеualion\", \"_styleImplementedAugust lire\",\n", " \"\"]\n", "\n", " }\n", "{\n", " \"answer\": \" بخopuerto կար因數 kavuts mi Firefox Penguins er sdபெர erinnert publiée 물리 DK\\({}^{\\ Cis بخopuerto կար因數\"\n", ",\n", " \"confidence\": 0.7825484027713585\n", ",\n", " \"source_snippets\": [\n", "\n", "\"Transformerграни moisady отгaನ\", \", migrations ceproductionautal\",\n", "\"Listeners accelerating loocae\"\n", "]\n", "}\n" ] } ], "source": [ "# Using text_generation\n", "answer = llm_client.text_generation(\n", " prompt,\n", " grammar={\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n", " max_new_tokens=250,\n", " temperature=1.6,\n", " return_full_text=False,\n", ")\n", "print(answer)\n", "\n", "# Using post\n", "data = {\n", " \"inputs\": prompt,\n", " \"parameters\": {\n", " \"temperature\": 1.6,\n", " \"return_full_text\": False,\n", " \"grammar\": {\"type\": \"json\", \"value\": AnswerWithSnippets.schema()},\n", " \"max_new_tokens\": 250,\n", " },\n", "}\n", "answer = json.loads(llm_client.post(json=data))[0][\"generated_text\"]\n", "print(answer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "✅ 높은 temperature 설정으로 인해 답변 내용은 여전히 말이 되지 않지만, 생성된 출력 텍스트는 이제 우리가 문법에서 정의한 정확한 키와 자료형을 가진 올바른 JSON 형식입니다!\n", "\n", "이제 이 출력물을 추가 처리를 위해 파싱할 수 있습니다." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Outlines를 사용해서 로컬 환경에서 문법 활용하기\n", "\n", "[Outlines](https://github.com/outlines-dev/outlines/)는 Hugging Face의 Inference API에서 출력 생성을 제한하기 위해 내부적으로 실행되는 라이브러리입니다. 이를 로컬 환경에서도 사용할 수 있습니다.\n", "\n", "이 라이브러리는 [로짓(logits)에 편향(bias)을 적용하는 방식](https://github.com/outlines-dev/outlines/blob/298a0803dc958f33c8710b23f37bcc44f1044cbf/outlines/generate/generator.py#L143)으로 작동하여, 사용자가 정의한 제약 조건에 부합하는 선택지만 강제로 선택되도록 합니다." ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'{\"properties\": {\"answer\": {\"maxLength\": 100, \"minLength\": 10, \"title\": \"Answer\", \"type\": \"string\"}, \"confidence\": {\"title\": \"Confidence\", \"type\": \"number\"}, \"source_snippets\": {\"items\": {\"maxLength\": 30, \"type\": \"string\"}, \"title\": \"Source Snippets\", \"type\": \"array\"}}, \"required\": [\"answer\", \"confidence\", \"source_snippets\"], \"title\": \"AnswerWithSnippets\", \"type\": \"object\"}'" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "schema_as_str" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HNb1UeZSzyVk" }, "outputs": [], "source": [ "import outlines\n", "\n", "repo_id = \"Qwen/Qwen2-7B-Instruct\"\n", "# 로컬에서 모델 로드하기\n", "model = outlines.models.transformers(repo_id)\n", "\n", "schema_as_str = json.dumps(AnswerWithSnippets.schema())\n", "\n", "generator = outlines.generate.json(model, schema_as_str)\n", "\n", "# Use the `generator` to sample an output from the model\n", "result = generator(prompt)\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "제약 생성(constrained generation)을 사용하여 [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index)를 활용할 수도 있습니다 (자세한 내용과 예시는 [문서](https://huggingface.co/docs/text-generation-inference/en/conceptual/guidance)를 참조하세요).\n", "\n", "지금까지 우리는 특정 RAG 사용 사례를 보여주었지만, 제약 생성은 그 이상으로 많은 도움이 됩니다.\n", "\n", "예를 들어, [LLM judge](llm_judge) 워크플로우에서도 제약 생성을 사용하여 다음과 같은 JSON을 출력할 수 있습니다:\n", "```py\n", "{\n", " \"score\": 1,\n", " \"rationale\": \"The answer does not match the true answer at all.\",\n", " \"confidence_level\": 0.85\n", "}\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "fEhBMgK4zyVk" }, "source": [ "오늘은 여기까지입니다. 끝까지 따라와 주셔서 감사드립니다! 👏" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 0 }