3_optimization-design-ptn/03_prompt-optimization/02_summarization.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Prompt Optimization for Summarization Tasks\n", "---\n", "\n", "This notebook shows how to apply PromptWizard to optimize a prompt for summarizing news articles. The goal is to create a prompt that effectively instructs the model to generate a concise summary of an article's content." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import promptwizard\n", "from promptwizard.glue.promptopt.instantiate import GluePromptOpt\n", "from promptwizard.glue.promptopt.techniques.common_logic import (\n", "    DatasetSpecificProcessing,\n", ")\n", "from promptwizard.glue.common.utils.file import save_jsonlist\n", "from typing import Any\n", "from tqdm import tqdm\n", "import os\n", "from datasets import load_dataset\n", "import yaml" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "<br>\n", "\n", "## 🧪 1. Prompt Optimization Using Your Own Dataset\n", "\n", "----\n", "\n", "### Load the dataset for summarization\n", "This custom dataset was created by the author by crawling Naver News (https://news.naver.com) for a Korean NLP hands-on.\n", "\n", "- Period: July 1, 2022 - July 10, 2022\n", "- Subject: IT, economics" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "num_debug_samples = 10\n", "dataset = load_dataset(\"daekeun-ml/naver-news-summarization-ko\")\n", "dataset[\"train\"] = dataset[\"train\"].shuffle().select(range(num_debug_samples))\n", "dataset[\"test\"] = dataset[\"test\"].shuffle().select(range(num_debug_samples))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SummarizationDataset(DatasetSpecificProcessing):\n", "\n", "    def dataset_to_jsonl(self, dataset_jsonl: str, **kwargs: Any) -> None:\n", "        def extract_answer_from_output(completion):\n", "            # For summarization there is no discrete answer to extract;\n", "            # the whole reference summary is the final answer.\n", "            return completion\n", "\n", "        examples_set = []\n", "\n", "        for _, sample in tqdm(enumerate(kwargs[\"dataset\"]), desc=\"Evaluating samples\"):\n", "            example = {\n", "                DatasetSpecificProcessing.QUESTION_LITERAL: sample[\"question\"],\n", "                DatasetSpecificProcessing.ANSWER_WITH_REASON_LITERAL: sample[\"answer\"],\n", "                DatasetSpecificProcessing.FINAL_ANSWER_LITERAL: extract_answer_from_output(\n", "                    sample[\"answer\"]\n", "                ),\n", "            }\n", "            examples_set.append(example)\n", "\n", "        save_jsonlist(dataset_jsonl, examples_set, \"w\")\n", "\n", "    def extract_final_answer(self, llm_output):\n", "        # The model's entire output is the summary, so return it unchanged.\n", "        return llm_output" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "summarization_processor = SummarizationDataset()\n", "\n", "os.makedirs(\"dataset\", exist_ok=True)  # ensure the output directory exists\n", "\n", "for dataset_type in [\"train\", \"test\"]:\n", "    data_list = []\n", "    for data in dataset[dataset_type]:\n", "        data_list.append({\"question\": data[\"document\"], \"answer\": data[\"summary\"]})\n", "    summarization_processor.dataset_to_jsonl(\n", "        f\"dataset/news_summarization_{dataset_type}.jsonl\", dataset=data_list\n", "    )" ] }
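, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check, you can read back the first record of the generated JSONL file and inspect its structure. The cell below is an optional addition for illustration; since the exact key names come from the `DatasetSpecificProcessing` literals, we print them rather than assume them." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "# Optional check: read the first record back and show its structure.\n", "with open(\"dataset/news_summarization_train.jsonl\", \"r\", encoding=\"utf-8\") as f:\n", "    first_record = json.loads(f.readline())\n", "\n", "print(\"Keys:\", list(first_record.keys()))\n", "for key, value in first_record.items():\n", "    print(f\"{key}: {str(value)[:100]}...\")  # truncate long fields for readability" ] }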
, { "cell_type": "markdown", "metadata": {}, "source": [ "### Update configs" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import yaml\n", "\n", "\n", "def update_yaml_file(file_path, config_dict):\n", "    # Load the existing YAML, overwrite the given fields, and write it back.\n", "    with open(file_path, \"r\") as file:\n", "        data = yaml.safe_load(file)\n", "\n", "    for field, value in config_dict.items():\n", "        data[field] = value\n", "\n", "    with open(file_path, \"w\") as file:\n", "        yaml.dump(data, file, default_flow_style=False, allow_unicode=True)\n", "\n", "    print(\"YAML file updated successfully!\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path_to_config = \"configs\"\n", "promptopt_config_path = os.path.join(\n", "    path_to_config, \"summarization_promptopt_config.yaml\"\n", ")\n", "setup_config_path = os.path.join(path_to_config, \"setup_config.yaml\")\n", "\n", "# Update summarization_promptopt_config.yaml. This file configures the prompt optimization process.\n", "config_dict = {\n", "    \"task_description\": \"You are a summarizer. Your task is to summarize the content of the document.\",\n", "    \"base_instruction\": \"Write a concise summary of the following. Think step by step to ensure you cover all important points.\",\n", "    \"answer_format\": \"Provide a bullet-point summary of the following document, listing the main arguments and supporting evidence in 3-5 concise bullet points. Avoid unnecessary detail and focus on the most important takeaways.\",\n", "    \"mutation_rounds\": 2,\n", "    \"few_shot_count\": 3,\n", "    \"generate_reasoning\": True,\n", "    \"mutate_refine_iterations\": 1,\n", "}\n", "update_yaml_file(promptopt_config_path, config_dict)\n", "\n", "# Update setup_config.yaml. This file tracks the experiment log.\n", "config_dict = {\n", "    \"experiment_name\": \"summarization\",\n", "}\n", "update_yaml_file(setup_config_path, config_dict)" ] }
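, { "cell_type": "markdown", "metadata": {}, "source": [ "Before running the optimizer, it can help to confirm that the YAML was updated as intended. The cell below is an optional check we added; it simply reloads the config file and prints the fields set above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional verification: reload the config and print the fields we just set.\n", "with open(promptopt_config_path, \"r\") as f:\n", "    updated_config = yaml.safe_load(f)\n", "\n", "for field in [\"task_description\", \"base_instruction\", \"answer_format\", \"mutation_rounds\"]:\n", "    print(f\"{field}: {updated_config.get(field)}\")" ] }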
, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initialize LLM" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import dotenv\n", "from langchain.chat_models import init_chat_model\n", "\n", "dotenv.load_dotenv()\n", "\n", "llm = init_chat_model(\"gpt-4o-mini\", model_provider=\"azure_openai\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Create an object for calling prompt optimization and inference functionalities" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.insert(0, \"./\")\n", "\n", "import promptwizard\n", "from promptwizard.glue.promptopt.instantiate import GluePromptOpt\n", "from promptwizard.glue.promptopt.techniques.common_logic import (\n", "    DatasetSpecificProcessing,\n", ")\n", "\n", "# gp = GluePromptOpt(\n", "#     promptopt_config_path, setup_config_path, dataset_jsonl=None, data_processor=None\n", "# )\n", "gp = GluePromptOpt(\n", "    promptopt_config_path,\n", "    setup_config_path,\n", "    dataset_jsonl=\"dataset/news_summarization_train.jsonl\",\n", "    data_processor=summarization_processor,\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Call the prompt optimization function\n", "1. ```use_examples``` can be used when there are training samples and a mixture of real and synthetic in-context examples is required in the final prompt. When set to ```False```, all the in-context examples will be real.\n", "2. ```generate_synthetic_examples``` can be used when there are no training samples and we want to generate synthetic examples.\n", "3. ```run_without_train_examples``` can be used when there are no training samples and in-context examples are not required in the final prompt.\n", "\n", "**Note: It will take a while to run (about 3-4 minutes)**" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "best_prompt, expert_profile = gp.get_best_prompt(\n", "    use_examples=True,\n", "    run_without_train_examples=False,\n", "    generate_synthetic_examples=False,\n", ")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Set up the default prompt. The optimized prompt comes from the step above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_core.prompts import ChatPromptTemplate\n", "\n", "system_prompt = \"\"\"You are a summarizer. Your task is to summarize the document.\"\"\"\n", "\n", "human_prompt = \"\"\"\n", "The following is a set of documents:\n", "{docs}\n", "Based on this list of docs, please summarize the content of the document.\n", "Answer:\n", "\"\"\"\n", "\n", "optimized_system_prompt = expert_profile\n", "optimized_human_prompt = best_prompt + \"\\n\" + human_prompt\n", "\n", "# Heuristic prompt\n", "map_prompt = ChatPromptTemplate.from_messages(\n", "    [(\"system\", system_prompt), (\"human\", human_prompt)],\n", ")\n", "\n", "# Optimized prompt using PromptWizard\n", "optimized_map_prompt = ChatPromptTemplate.from_messages(\n", "    [(\"system\", optimized_system_prompt), (\"human\", optimized_human_prompt)],\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Also available via the hub: `hub.pull(\"rlm/reduce-prompt\")`\n", "reduce_template = \"\"\"\n", "The following is a set of summaries:\n", "{docs}\n", "Take these and distill them into a final, consolidated summary of the main themes. The summary should include the main keywords.\n", "\"\"\"\n", "reduce_prompt = ChatPromptTemplate([(\"human\", reduce_template)])" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Load documents\n", "\n", "This example simply loads a news article from the web, but you can use any document you want to summarize. Please note that document retrieval for production requires considerably more work." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_community.document_loaders import WebBaseLoader\n", "\n", "url = \"https://www.hankyung.com/article/202504042119i\"\n", "loader = WebBaseLoader(url)\n", "docs = loader.load()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_text_splitters import CharacterTextSplitter\n", "\n", "chunk_size = 1000\n", "text_splitter = CharacterTextSplitter.from_tiktoken_encoder(\n", "    chunk_size=chunk_size, chunk_overlap=100\n", ")\n", "split_docs = text_splitter.split_documents(docs)\n", "print(f\"Generated {len(split_docs)} documents.\")" ] }
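, { "cell_type": "markdown", "metadata": {}, "source": [ "It is worth glancing at the chunks before summarizing them. The cell below is an optional inspection step we added: it prints the beginning of the first chunk so you can confirm the page content was extracted cleanly." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: preview the first chunk to verify the web page was parsed sensibly.\n", "if split_docs:\n", "    print(split_docs[0].page_content[:300])" ] }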
, { "cell_type": "markdown", "metadata": {}, "source": [ "<br>\n", "\n", "## 2. Build the Graph and Run It\n", "\n", "----" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import operator\n", "from typing import Annotated, List, Literal, TypedDict\n", "\n", "from langchain.chains.combine_documents.reduce import (\n", "    acollapse_docs,\n", "    split_list_of_docs,\n", ")\n", "from langchain_core.documents import Document\n", "from langgraph.constants import Send\n", "from langgraph.graph import END, START, StateGraph\n", "from functools import partial\n", "\n", "token_max = chunk_size\n", "\n", "\n", "def length_function(documents: List[Document]) -> int:\n", "    \"\"\"Get the number of tokens for the input contents.\"\"\"\n", "    return sum(llm.get_num_tokens(doc.page_content) for doc in documents)\n", "\n", "\n", "# This will be the overall state of the main graph.\n", "# It will contain the input document contents, corresponding\n", "# summaries, and a final summary.\n", "class OverallState(TypedDict):\n", "    # Notice here we use operator.add.\n", "    # This is because we want to combine all the summaries we generate\n", "    # from individual nodes back into one list - this is essentially\n", "    # the \"reduce\" part\n", "    contents: List[str]\n", "    summaries: Annotated[list, operator.add]\n", "    collapsed_summaries: List[Document]\n", "    final_summary: str\n", "\n", "\n", "# This will be the state of the node that we will \"map\" all\n", "# documents to in order to generate summaries\n", "class SummaryState(TypedDict):\n", "    content: str\n", "\n", "\n", "# Here we generate a summary, given a document\n", "async def generate_summary(state: SummaryState, prompt_template: ChatPromptTemplate):\n", "    prompt = prompt_template.invoke(state[\"content\"])\n", "    response = await llm.ainvoke(prompt)\n", "    return {\"summaries\": [response.content]}\n", "\n", "\n", "# Here we define the logic to map out over the documents.\n", "# We will use this as an edge in the graph.\n", "def map_summaries(state: OverallState):\n", "    # We will return a list of `Send` objects.\n", "    # Each `Send` object consists of the name of a node in the graph\n", "    # as well as the state to send to that node.\n", "    return [\n", "        Send(\"generate_summary\", {\"content\": content}) for content in state[\"contents\"]\n", "    ]\n", "\n", "\n", "def collect_summaries(state: OverallState):\n", "    return {\n", "        \"collapsed_summaries\": [Document(summary) for summary in state[\"summaries\"]]\n", "    }\n", "\n", "\n", "async def _reduce(input: List[Document]) -> str:\n", "    prompt = reduce_prompt.invoke(input)\n", "    response = await llm.ainvoke(prompt)\n", "    return response.content\n", "\n", "\n", "# Add node to collapse summaries\n", "async def collapse_summaries(state: OverallState):\n", "    doc_lists = split_list_of_docs(\n", "        state[\"collapsed_summaries\"], length_function, token_max\n", "    )\n", "    results = []\n", "    for doc_list in doc_lists:\n", "        results.append(await acollapse_docs(doc_list, _reduce))\n", "\n", "    return {\"collapsed_summaries\": results}\n", "\n", "\n", "# This represents a conditional edge in the graph that determines\n", "# if we should collapse the summaries or not\n", "def should_collapse(\n", "    state: OverallState,\n", ") -> Literal[\"collapse_summaries\", \"generate_final_summary\"]:\n", "    num_tokens = length_function(state[\"collapsed_summaries\"])\n", "    if num_tokens > token_max:\n", "        return \"collapse_summaries\"\n", "    else:\n", "        return \"generate_final_summary\"\n", "\n", "\n", "# Here we will generate the final summary\n", "async def generate_final_summary(state: OverallState):\n", "    response = await _reduce(state[\"collapsed_summaries\"])\n", "    return {\"final_summary\": response}" ] }
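, { "cell_type": "markdown", "metadata": {}, "source": [ "As a rough illustration of how `should_collapse` will behave, the optional cell below (our own addition) compares the token count of the raw split documents against `token_max`. Note that the actual check runs on the generated per-chunk summaries, which are typically much shorter than the source chunks, so this is only an upper-bound estimate." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative check (upper bound): compare source-chunk tokens with token_max.\n", "total_tokens = length_function(split_docs)\n", "print(f\"Total tokens in split_docs: {total_tokens} (token_max = {token_max})\")\n", "if total_tokens > token_max:\n", "    print(\"If the summaries stayed this long, the collapse step would run.\")" ] }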
, { "cell_type": "markdown", "metadata": {}, "source": [ "### Case 1. Using the heuristic prompt" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Construct the graph\n", "# Nodes:\n", "_generate_summary = partial(generate_summary, prompt_template=map_prompt)\n", "\n", "graph = StateGraph(OverallState)\n", "graph.add_node(\"generate_summary\", _generate_summary)\n", "graph.add_node(\"collect_summaries\", collect_summaries)\n", "graph.add_node(\"collapse_summaries\", collapse_summaries)\n", "graph.add_node(\"generate_final_summary\", generate_final_summary)\n", "\n", "# Edges:\n", "graph.add_conditional_edges(START, map_summaries, [\"generate_summary\"])\n", "graph.add_edge(\"generate_summary\", \"collect_summaries\")\n", "graph.add_conditional_edges(\"collect_summaries\", should_collapse)\n", "graph.add_conditional_edges(\"collapse_summaries\", should_collapse)\n", "graph.add_edge(\"generate_final_summary\", END)\n", "\n", "app = graph.compile()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from azure_genai_utils.graphs import visualize_langgraph\n", "\n", "visualize_langgraph(app, xray=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_core.runnables import RunnableConfig\n", "from azure_genai_utils.messages import random_uuid\n", "\n", "inputs = {\n", "    \"contents\": [doc.page_content for doc in split_docs],\n", "}\n", "config = RunnableConfig(recursion_limit=10, configurable={\"thread_id\": random_uuid()})\n", "async for step in app.astream(inputs, config):\n", "    curr_node = list(step.keys())[0]\n", "    print(\"\\n\" + \"=\" * 50)\n", "    print(f\"🔄 Node: \\033[1;36m{curr_node}\\033[0m 🔄\")\n", "    print(\"- \" * 25)\n", "    print(list(step.values()))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, Markdown\n", "\n", "outputs = step[\"generate_final_summary\"][\"final_summary\"]\n", "display(Markdown(outputs))" ] }
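, { "cell_type": "markdown", "metadata": {}, "source": [ "To make it easier to compare the two prompts side by side later, the optional cell below (our own addition, using the hypothetical variable name `heuristic_summary`) keeps a copy of the heuristic result before the graph is rebuilt with the optimized prompt." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Keep the heuristic result around for a later side-by-side comparison.\n", "heuristic_summary = outputs" ] }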
, { "cell_type": "markdown", "metadata": {}, "source": [ "### Case 2. Using the optimized prompt\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Construct the graph\n", "# Nodes:\n", "_generate_summary = partial(generate_summary, prompt_template=optimized_map_prompt)\n", "\n", "graph = StateGraph(OverallState)\n", "graph.add_node(\"generate_summary\", _generate_summary)\n", "graph.add_node(\"collect_summaries\", collect_summaries)\n", "graph.add_node(\"collapse_summaries\", collapse_summaries)\n", "graph.add_node(\"generate_final_summary\", generate_final_summary)\n", "\n", "# Edges:\n", "graph.add_conditional_edges(START, map_summaries, [\"generate_summary\"])\n", "graph.add_edge(\"generate_summary\", \"collect_summaries\")\n", "graph.add_conditional_edges(\"collect_summaries\", should_collapse)\n", "graph.add_conditional_edges(\"collapse_summaries\", should_collapse)\n", "graph.add_edge(\"generate_final_summary\", END)\n", "\n", "app = graph.compile()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from langchain_core.runnables import RunnableConfig\n", "from azure_genai_utils.messages import random_uuid\n", "\n", "inputs = {\n", "    \"contents\": [doc.page_content for doc in split_docs],\n", "}\n", "config = RunnableConfig(recursion_limit=10, configurable={\"thread_id\": random_uuid()})\n", "async for step in app.astream(inputs, config):\n", "    curr_node = list(step.keys())[0]\n", "    print(\"\\n\" + \"=\" * 50)\n", "    print(f\"🔄 Node: \\033[1;36m{curr_node}\\033[0m 🔄\")\n", "    print(\"- \" * 25)\n", "    print(list(step.values()))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "You can see that the summary generated by the optimized prompt is better than the one generated by the heuristic prompt: it is more concise and captures the main points of the article more effectively.\n", "Of course, you can also improve the heuristic prompt by hand through prompt engineering, but the automatically optimized prompt is a good starting point for creating effective summarization prompts." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, Markdown\n", "\n", "outputs = step[\"generate_final_summary\"][\"final_summary\"]\n", "display(Markdown(outputs))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Further work\n", "You can evaluate the performance of the optimized prompt on the test set using the Azure AI Evaluation SDK." ] } ], "metadata": { "kernelspec": { "display_name": "py312-dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }