
{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "359697d5" }, "source": [ "# Bulk Question Answering with Vertex AI Search\n", "\n", "Answer questions from a CSV using a Vertex AI Search data store.\n", "\n", "<table align=\"left\">\n", "\n", " <td>\n", " <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/search/bulk-question-answering/bulk_question_answering.ipynb\">\n", " <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Run in Colab\n", " </a>\n", " </td>\n", " <td>\n", " <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/search/bulk-question-answering/bulk_question_answering.ipynb\">\n", " <img width=\"32px\" src=\"https://www.svgrepo.com/download/217753/github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n", " </a>\n", " </td>\n", " <td>\n", " <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/search/bulk-question-answering/bulk_question_answering.ipynb\">\n", " <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n", " </a>\n", " </td>\n", "</table>\n", "\n", "<div style=\"clear: both;\"></div>\n", "\n", "<b>Share to:</b>\n", "\n", "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/bulk-question-answering/bulk_question_answering.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n", "</a>\n", "\n", "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/bulk-question-answering/bulk_question_answering.ipynb\" target=\"_blank\">\n", " <img 
width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n", "</a>\n", "\n", "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/bulk-question-answering/bulk_question_answering.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n", "</a>\n", "\n", "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/bulk-question-answering/bulk_question_answering.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n", "</a>\n", "\n", "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/search/bulk-question-answering/bulk_question_answering.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n", "</a> " ] }, { "cell_type": "markdown", "metadata": { "id": "0c333685c070" }, "source": [ "| | |\n", "|-|-|\n", "|Author(s) | [Ruchika Kharwar](https://github.com/rasalt), [Holt Skinner](https://github.com/holtskinner) |" ] }, { "cell_type": "markdown", "metadata": { "id": "pQbOO-Lc-2-7" }, "source": [ "## Install prerequisites\n", "\n", "If you are running in Colab, install the prerequisites into the runtime. Otherwise, the notebook is assumed to be running in Vertex AI Workbench, where installing the prerequisites from a terminal with the `--user` option is recommended."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ohPUPez8imvE" }, "outputs": [], "source": [ "%pip install google-cloud-discoveryengine google-auth pandas --upgrade --user -q" ] }, { "cell_type": "markdown", "metadata": { "id": "R5Xep4W9lq-Z" }, "source": [ "### Restart current runtime\n", "\n", "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. Running the cell below restarts the current kernel." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XRvKdaPDTznN" }, "outputs": [], "source": [ "# Restart kernel after installs so that your environment can access the new packages\n", "import IPython\n", "\n", "app = IPython.Application.instance()\n", "app.kernel.do_shutdown(True)" ] }, { "cell_type": "markdown", "metadata": { "id": "SbmM4z7FOBpM" }, "source": [ "<div class=\"alert alert-block alert-warning\">\n", "<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>\n", "</div>\n" ] }, { "cell_type": "markdown", "metadata": { "id": "8vEczuYo_r-g" }, "source": [ "## Authenticate\n", "\n", "If you are running in Colab, authenticate with `google.colab.auth`. Otherwise, the notebook is assumed to be running in Vertex AI Workbench."
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "loTfn0KniwB2" }, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", "    from google.colab import auth as google_auth\n", "\n", "    google_auth.authenticate_user()" ] }, { "cell_type": "markdown", "metadata": { "id": "c755a26feb44" }, "source": [ "## Definitions for the index values\n", "\n", "The input CSV and the output TSV share the following columns. The notebook reads \"Query\" and fills in the last four.\n", "\n", "- \"Query\"\n", "- \"Golden Doc\"\n", "- \"Golden Doc Page Number\"\n", "- \"Golden Answer\"\n", "- \"Top 5 Docs\"\n", "- \"Top 5 extractive answers\"\n", "- \"Top 5 extractive segments\"\n", "- \"Answer / Summary\"" ] }, { "cell_type": "markdown", "metadata": { "id": "3094caeee89d" }, "source": [ "## Import libraries" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "78748e023ce0" }, "outputs": [], "source": [ "from google.api_core.client_options import ClientOptions\n", "from google.cloud import discoveryengine_v1beta as discoveryengine\n", "import pandas as pd" ] }, { "cell_type": "markdown", "metadata": { "id": "66ff7614" }, "source": [ "### Set the following constants to reflect your environment\n", "\n", "* The example queries relate to a GCS bucket containing Alphabet investor PDFs; customize them for your own data."
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "uVxFSrppK8Oy" }, "outputs": [], "source": [ "PROJECT_ID = \"YOUR_PROJECT_ID\"  # @param {type:\"string\"}\n", "LOCATION = \"global\"  # @param {type:\"string\"}\n", "DATA_STORE_ID = \"YOUR_DATA_STORE_ID\"  # @param {type:\"string\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "a950641ee9e4" }, "source": [ "## Function to search the Vertex AI Search data store"
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "34b1bac17ff8" }, "outputs": [], "source": [ "def search_data_store(\n", "    project_id: str,\n", "    location: str,\n", "    data_store_id: str,\n", "    search_query: str,\n", ") -> discoveryengine.SearchResponse:\n", "    # For more information, refer to:\n", "    # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store\n", "    client_options = (\n", "        ClientOptions(api_endpoint=f\"{location}-discoveryengine.googleapis.com\")\n", "        if location != \"global\"\n", "        else None\n", "    )\n", "\n", "    # Create a client\n", "    client = discoveryengine.SearchServiceClient(client_options=client_options)\n", "\n", "    # The full resource name of the search engine serving config\n", "    # e.g. projects/{project_id}/locations/{location}/dataStores/{data_store_id}/servingConfigs/{serving_config_id}\n", "    serving_config = client.serving_config_path(\n", "        project=project_id,\n", "        location=location,\n", "        data_store=data_store_id,\n", "        serving_config=\"default_config\",\n", "    )\n", "\n", "    # Optional: Configuration options for search\n", "    # Refer to the `ContentSearchSpec` reference for all supported fields:\n", "    # https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1.types.SearchRequest.ContentSearchSpec\n", "    content_search_spec = discoveryengine.SearchRequest.ContentSearchSpec(\n", "        # For information about snippets, refer to:\n", "        # https://cloud.google.com/generative-ai-app-builder/docs/snippets\n", "        snippet_spec=discoveryengine.SearchRequest.ContentSearchSpec.SnippetSpec(\n", "            return_snippet=True\n", "        ),\n", "        extractive_content_spec=discoveryengine.SearchRequest.ContentSearchSpec.ExtractiveContentSpec(\n", "            max_extractive_answer_count=5,\n", "            max_extractive_segment_count=1,\n", "        ),\n", "        # For information about search summaries, refer to:\n", "        # https://cloud.google.com/generative-ai-app-builder/docs/get-search-summaries\n", "        summary_spec=discoveryengine.SearchRequest.ContentSearchSpec.SummarySpec(\n", "            summary_result_count=5,\n", "            include_citations=True,\n", "            ignore_adversarial_query=False,\n", "            ignore_non_summary_seeking_query=False,\n", "        ),\n", "    )\n", "\n", "    # Refer to the `SearchRequest` reference for all supported fields:\n", "    # https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1.types.SearchRequest\n", "    request = discoveryengine.SearchRequest(\n", "        serving_config=serving_config,\n", "        query=search_query,\n", "        page_size=5,\n", "        content_search_spec=content_search_spec,\n", "        query_expansion_spec=discoveryengine.SearchRequest.QueryExpansionSpec(\n", "            condition=discoveryengine.SearchRequest.QueryExpansionSpec.Condition.AUTO,\n", "        ),\n", "        spell_correction_spec=discoveryengine.SearchRequest.SpellCorrectionSpec(\n", "            mode=discoveryengine.SearchRequest.SpellCorrectionSpec.Mode.AUTO\n", "        ),\n", "    )\n", "\n", "    response = client.search(request)\n", "    return response"
] }, { "cell_type": "markdown", "metadata": { "id": "5b9d2cee5c09" }, "source": [ "## Function to load results into DataFrame"
] }, { "cell_type": "code", "execution_count": 49, "metadata": { "id": "1d38d55aa18c" }, "outputs": [], "source": [ "def answer_questions(\n", "    row: pd.Series, project_id: str, location: str, data_store_id: str, top_n: int = 5\n", ") -> pd.Series:\n", "    \"\"\"Searches for the row's query and returns a copy of the row with results filled in.\n", "\n", "    Fills in the summary, top docs, extractive answers, and extractive segments.\n", "    \"\"\"\n", "    # Work on a copy: mutating the row that DataFrame.apply passes in is not supported\n", "    row = row.copy()\n", "\n", "    # Perform search with Query\n", "    response = search_data_store(project_id, location, data_store_id, row[\"Query\"])\n", "\n", "    row[\"Answer / Summary\"] = response.summary.summary_text\n", "\n", "    top5docs, top5answers, top5segments = [], [], []\n", "    ext_ans_cnt, ext_seg_cnt = 0, 0\n", "\n", "    for result in response.results:\n", "        doc_data = getattr(result.document, \"derived_struct_data\", None)\n", "        if not doc_data:\n", "            continue\n", "\n", "        # Process extractive answers\n", "        for chunk in doc_data.get(\"extractive_answers\", []):\n", "            content = chunk.get(\"content\", \"\").replace(\"\\n\", \"\")\n", "            top5answers.append(content)\n", "            top5docs.append(\n", "                f\"Doc: {doc_data.get('link', '')} Page: {chunk.get('pageNumber', '')}\"\n", "            )\n", "            ext_ans_cnt += 1\n", "\n", "        # Process extractive segments\n", "        for chunk in doc_data.get(\"extractive_segments\", []):\n", "            data = chunk.get(\"content\", \"\").replace(\"\\n\", \"\")\n", "            top5segments.append(data)\n", "            ext_seg_cnt += 1\n", "\n", "        # Stop once enough answers and segments have been collected\n", "        if ext_ans_cnt >= top_n and ext_seg_cnt >= top_n:\n", "            break\n", "\n", "    row[\"Top 5 Docs\"] = \"\\n\\n\".join(top5docs)\n", "    row[\"Top 5 extractive answers\"] = \"\\n\\n\".join(top5answers)\n", "    row[\"Top 5 extractive segments\"] = \"\\n\\n\".join(top5segments)\n", "    return row"
] }, { "cell_type": "markdown", "metadata": { "id": "c001c9a5b5f9" }, "source": [ "### Gather all of the Vertex AI Search results\n", "\n", "- Read in CSV as Pandas DataFrame\n", "- Send questions to Vertex AI Search\n", "- Load the summary, top 5 docs, extractive answers, and extractive segments into the DataFrame\n", "- Output DataFrame to TSV\n", "  - [Example TSV](https://storage.googleapis.com/github-repo/search/bulk-question-answering/bulk_question_answering_output.tsv)"
] }, { "cell_type": "code", "execution_count": 53, "metadata": { "id": "c80855547d28" }, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>Query</th>\n", " <th>Golden Doc</th>\n", " <th>Golden Doc Page Number</th>\n", " <th>Golden Answer</th>\n", " <th>Top 5 Docs</th>\n", " <th>Top 5 extractive answers</th>\n", " <th>Top 5 extractive segments</th>\n", " <th>Answer / Summary</th>\n", " <th>Feedback from customer / account team about returned docs and answer</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>What was Google's revenue in 2021?</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>Doc: gs://cloud-samples-data/gen-app-builder/s...</td>\n", " <td>Google Cloud had an Operating Loss of $890 mil...</td>\n", " <td>Within Other Revenues, we are pleased with the...</td>\n", " <td>Google's revenue for the full year 2021 was $5...</td>\n", " <td>NaN</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>What was Google's revenue in 2022?</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>Doc: gs://cloud-samples-data/gen-app-builder/s...</td>\n", " <td>Other Revenues were $8.2 billion, up 22%, driv...</td>\n", " <td>Let me now turn to our segment financial resul...</td>\n", " <td>Google's total revenue was $282.8 billion in 2...</td>\n", " <td>NaN</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " Query Golden Doc Golden Doc Page Number \\\n", "0 What was Google's revenue in 2021? NaN NaN \n", "1 What was Google's revenue in 2022? NaN NaN \n", "\n", " Golden Answer Top 5 Docs \\\n", "0 NaN Doc: gs://cloud-samples-data/gen-app-builder/s... \n", "1 NaN Doc: gs://cloud-samples-data/gen-app-builder/s... \n", "\n", " Top 5 extractive answers \\\n", "0 Google Cloud had an Operating Loss of $890 mil... \n", "1 Other Revenues were $8.2 billion, up 22%, driv... \n", "\n", " Top 5 extractive segments \\\n", "0 Within Other Revenues, we are pleased with the... \n", "1 Let me now turn to our segment financial resul... \n", "\n", " Answer / Summary \\\n", "0 Google's revenue for the full year 2021 was $5... \n", "1 Google's total revenue was $282.8 billion in 2... \n", "\n", " Feedback from customer / account team about returned docs and answer \n", "0 NaN \n", "1 NaN " ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Open the CSV file and read column values\n", "df = pd.read_csv(\n", "    \"gs://github-repo/search/bulk-question-answering/bulk_question_answering_input.csv\",\n", "    header=0,\n", "    dtype=str,\n", ")\n", "\n", "# Make a Vertex AI Search request for each question and keep the filled-in rows\n", "df = df.apply(\n", "    lambda row: answer_questions(row, PROJECT_ID, LOCATION, DATA_STORE_ID, top_n=5),\n", "    axis=1,\n", ")\n", "\n", "# Output results to new TSV file\n", "df.to_csv(\"bulk_question_answering_output.tsv\", index=False, sep=\"\\t\")\n", "\n", "df" ] } ], "metadata": { "colab": { "name": "bulk_question_answering.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }