notebooks/3_LoadKnownGoodSQL.ipynb (370 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2024 Google LLC\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"# **Open Data QnA: Cache Known Good Queries in Vector Store**\n",
"\n",
"---\n",
"\n",
"This notebook shows how to cache known good queries in a vector store that has already been set up using [1_SetUpVectorStore.ipynb](1_SetUpVectorStore.ipynb). The queries are loaded into the vector store from the CSV file `/scripts/known_good_sql.csv`.\n",
"\n",
"Supported vector stores: \n",
"- pgvector on PostgreSQL \n",
"- BigQuery vector\n",
"\n",
"\n",
"The notebook covers the following steps: \n",
"> 1. Clean an existing embeddings table for known good queries (if loading_mode = 'refresh')\n",
"\n",
"> 2. Add known good queries from csv file to the embeddings table in the vector store"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚧 **0. Pre-requisites**\n",
"\n",
"Make sure that you have completed the initial setup process using [1_SetUpVectorStore.ipynb](1_SetUpVectorStore.ipynb). If the 1_SetUpVectorStore notebook has run successfully, the following are set up:\n",
"* GCP project and all the required IAM permissions\n",
"\n",
"* **Environment to run the solution**\n",
"\n",
"* Data source and Vector store for the solution\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ⚙️ **1. Retrieve Configuration Parameters**\n",
"The notebook will load all the configuration parameters from the `config.ini` file in the root directory. \n",
"Most of these parameters were set in the initial notebook `1_SetUpVectorStore.ipynb` and saved to the `config.ini` file.\n",
"Use the cells below to retrieve these values and specify any additional ones required for this notebook. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"module_path = os.path.abspath(os.path.join('..'))\n",
"sys.path.append(module_path)\n",
"\n",
"import configparser\n",
"config = configparser.ConfigParser()\n",
"config.read(module_path+'/config.ini')\n",
"\n",
"PROJECT_ID = config['GCP']['PROJECT_ID']\n",
"VECTOR_STORE = config['CONFIG']['VECTOR_STORE']\n",
"PG_DATABASE = config['PGCLOUDSQL']['PG_DATABASE']\n",
"PG_USER = config['PGCLOUDSQL']['PG_USER']\n",
"PG_REGION = config['PGCLOUDSQL']['PG_REGION'] \n",
"PG_INSTANCE = config['PGCLOUDSQL']['PG_INSTANCE'] \n",
"PG_PASSWORD = config['PGCLOUDSQL']['PG_PASSWORD']\n",
"BQ_OPENDATAQNA_DATASET_NAME = config['BIGQUERY']['BQ_OPENDATAQNA_DATASET_NAME']\n",
"BQ_LOG_TABLE_NAME = config['BIGQUERY']['BQ_LOG_TABLE_NAME'] \n",
"BQ_DATASET_REGION = config['BIGQUERY']['BQ_DATASET_REGION']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔐 **2. Authenticate and Connect to Google Cloud Project**\n",
"Authenticate to Google Cloud as the IAM user logged into this notebook in order to access your Google Cloud Project.\n",
"\n",
"You can do this within Google Colab or using the Application Default Credentials in the Google Cloud CLI."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Project has been set to three-p-o\n"
]
}
],
"source": [
"\"\"\"Colab Auth\"\"\" \n",
"# from google.colab import auth\n",
"# auth.authenticate_user()\n",
"\n",
"\n",
"\"\"\"Google CLI Auth\"\"\"\n",
"# !gcloud auth application-default login\n",
"\n",
"\n",
"import google.auth\n",
"import os\n",
"\n",
"credentials, project_id = google.auth.default()\n",
"\n",
"os.environ['GOOGLE_CLOUD_QUOTA_PROJECT']=PROJECT_ID\n",
"os.environ['GOOGLE_CLOUD_PROJECT']=PROJECT_ID\n",
"\n",
"# Confirm which project the environment variables now point to.\n",
"print(f'Project has been set to {PROJECT_ID}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 💾 **3. Cache Known Good Queries**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Format of the Known Good SQL File (known_good_sql.csv)\n",
"\n",
"prompt | sql | user_grouping [3 columns]\n",
"\n",
"prompt ==> Natural Language Question corresponding to query\n",
"\n",
"sql ==> SQL for the user question (Note that the SQL should be enclosed in quotes and written on a single line; remove any line breaks.)\n",
"\n",
"user_grouping ==> This name should exactly match the user_grouping for the Postgres or BigQuery source as listed in data_source_list.csv"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Known Good SQL Found at Path :: /home/admin_/Open_Data_QnA/scripts/known_good_sql.csv\n",
" prompt \\\n",
"0 Which country had the highest population in 2023? \n",
"1 List 10 countries with highest fertility rate \n",
"2 What is the life expectancy for men and women ... \n",
"3 Which countries had the highest and lowest inf... \n",
"4 What is the population distribution by age and... \n",
"5 What is the sex ratio at birth for China in 2023? \n",
"6 What is the world average sex ratio at birth i... \n",
"7 Which country has the highest boy to girl rati... \n",
"8 What are the top 10 counties with lowest morta... \n",
"9 What were the birth, death, and growth rates i... \n",
"10 In 2023, what was the population for countrie... \n",
"11 What are the top 3 countries with the highest ... \n",
"12 What are the top 5 coutries with highest popul... \n",
"13 Which country has highest male life expectancy? \n",
"14 What are the birth and death rates of the top ... \n",
"\n",
" sql user_grouping \n",
"0 SELECT country_name, midyear_population\\r\\nFRO... WorldCensus-cloudsql-pg \n",
"1 SELECT country_name, total_fertility_rate\\r\\nF... WorldCensus-cloudsql-pg \n",
"2 SELECT country_name, life_expectancy_male, lif... WorldCensus-cloudsql-pg \n",
"3 WITH RankedInfantMortality AS (\\r\\n SELECT\\r\\... WorldCensus-cloudsql-pg \n",
"4 SELECT age, sex, population\\r\\nFROM census_bur... WorldCensus-cloudsql-pg \n",
"5 SELECT country_name, sex_ratio_at_birth, year\\... WorldCensus-cloudsql-pg \n",
"6 SELECT AVG(sex_ratio_at_birth) AS world_avg_se... WorldCensus-cloudsql-pg \n",
"7 SELECT country_name, sex_ratio_at_birth\\r\\nFRO... WorldCensus-cloudsql-pg \n",
"8 SELECT country_name, mortality_rate_under5\\r\\n... WorldCensus-cloudsql-pg \n",
"9 SELECT year, crude_birth_rate, crude_death_rat... WorldCensus-cloudsql-pg \n",
"10 SELECT mp.midyear_population as Population, mp... WorldCensus-cloudsql-pg \n",
"11 SELECT mp.country_name, mp.midyear_population,... WorldCensus-cloudsql-pg \n",
"12 SELECT mp.country_name, \\r\\n mp.midyear_... WorldCensus-cloudsql-pg \n",
"13 SELECT country_name, life_expectancy_male\\r\\nF... WorldCensus-cloudsql-pg \n",
"14 SELECT bdg.country_name, bdg.crude_birth_rate,... WorldCensus-cloudsql-pg \n",
"Known Good SQLs Loaded into a Dataframe\n"
]
}
],
"source": [
"# Find the csv file and load as dataframe\n",
"import pandas as pd\n",
"\n",
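"current_dir = os.getcwd()\n",
"root_dir = os.path.expanduser('~')  # Stop the upward search at the user's home directory\n",
"file_path = None\n",
"\n",
"# Search the current directory tree, then move up one level at a time until the file is found\n",
"while file_path is None and current_dir != root_dir:\n",
"    for dirpath, dirnames, filenames in os.walk(current_dir):\n",
"        if 'known_good_sql.csv' in filenames:\n",
"            file_path = os.path.join(dirpath, 'known_good_sql.csv')\n",
"            break  # Stop walking this tree once the file is found\n",
"\n",
"    parent_dir = os.path.dirname(current_dir)\n",
"    if parent_dir == current_dir:\n",
"        break  # Reached the filesystem root without finding the file\n",
"    current_dir = parent_dir\n",
"\n",
"if file_path is None:\n",
"    raise FileNotFoundError('known_good_sql.csv not found between the working directory and the home directory')\n",
"\n",
"print(\"Known Good SQL Found at Path :: \"+file_path)\n",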
"\n",
"# Load the file\n",
"df_kgq = pd.read_csv(file_path)\n",
"df_kgq = df_kgq.loc[:, [\"prompt\", \"sql\", \"user_grouping\"]]\n",
"df_kgq = df_kgq.dropna()\n",
"print(df_kgq)\n",
"\n",
"print('Known Good SQLs Loaded into a Dataframe')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Specify mode for loading the known good sql\n",
"\n",
"The known good SQL can be loaded in two modes:\n",
"* Append mode: Append the queries to the existing KGQ in the vector store \n",
"* Refresh mode: Delete the existing KGQ and create a fresh copy from the KGQ in the known_good_sql.csv file"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"loading_mode = 'refresh' # Options 'append' or 'refresh'\n",
"assert loading_mode in {'append', 'refresh'}, \"⚠️ Invalid loading_mode. Must be 'append' or 'refresh'\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"current dir: /home/admin_/Open_Data_QnA/notebooks\n",
"root_dir set to: /home/admin_/Open_Data_QnA\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
"I0000 00:00:1721735219.442999 2463 config.cc:230] gRPC experiments enabled: call_status_override_on_cancellation, event_engine_dns, event_engine_listener, http2_stats_fix, monitoring_experiment, pick_first_new, trace_record_callops, work_serializer_clears_time_cache\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Adding Known Good Queries to the Vector store.....\n",
"Done!!\n"
]
}
],
"source": [
"# If you have Known Good Queries, add them to the known_good_sql.csv file.\n",
"# They will be used as few-shot examples for query generation.\n",
"\n",
"from embeddings.kgq_embeddings import setup_kgq_table, store_kgq_embeddings\n",
"\n",
"if loading_mode == 'refresh':\n",
"    # Delete any old tables and create a new table for KGQ embeddings\n",
"    if VECTOR_STORE=='bigquery-vector':\n",
"        await(setup_kgq_table(project_id=PROJECT_ID,\n",
"                              instance_name=None,\n",
"                              database_name=None,\n",
"                              schema=BQ_OPENDATAQNA_DATASET_NAME,\n",
"                              database_user=None,\n",
"                              database_password=None,\n",
"                              region=BQ_DATASET_REGION,\n",
"                              VECTOR_STORE = VECTOR_STORE\n",
"                              ))\n",
"\n",
"    elif VECTOR_STORE=='cloudsql-pgvector':\n",
"        await(setup_kgq_table(project_id=PROJECT_ID,\n",
"                              instance_name=PG_INSTANCE,\n",
"                              database_name=PG_DATABASE,\n",
"                              schema=None,\n",
"                              database_user=PG_USER,\n",
"                              database_password=PG_PASSWORD,\n",
"                              region=PG_REGION,\n",
"                              VECTOR_STORE = VECTOR_STORE\n",
"                              ))\n",
"\n",
"\n",
"print(\"Adding Known Good Queries to the Vector store.....\")\n",
"# Add KGQ to the vector store\n",
"if VECTOR_STORE=='bigquery-vector':\n",
"    await(store_kgq_embeddings(df_kgq,\n",
"                               project_id=PROJECT_ID,\n",
"                               instance_name=None,\n",
"                               database_name=None,\n",
"                               schema=BQ_OPENDATAQNA_DATASET_NAME,\n",
"                               database_user=None,\n",
"                               database_password=None,\n",
"                               region=BQ_DATASET_REGION,\n",
"                               VECTOR_STORE = VECTOR_STORE\n",
"                               ))\n",
"\n",
"elif VECTOR_STORE=='cloudsql-pgvector':\n",
"    await(store_kgq_embeddings(df_kgq,\n",
"                               project_id=PROJECT_ID,\n",
"                               instance_name=PG_INSTANCE,\n",
"                               database_name=PG_DATABASE,\n",
"                               schema=None,\n",
"                               database_user=PG_USER,\n",
"                               database_password=PG_PASSWORD,\n",
"                               region=PG_REGION,\n",
"                               VECTOR_STORE = VECTOR_STORE\n",
"                               ))\n",
"print('Done!!')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "talktodata-Fy2pM2BF-py3.9",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}