nl2sql_src/nl2sql_generic.py (746 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import json
import traceback
import sqlalchemy
import pandas as pd
# import langchain
import sqlglot
# from prompts import *
from prompts import Table_filtering_prompt, Table_filtering_prompt_promptonly
from prompts import Auto_verify_sql_prompt, Sql_Generation_prompt
from prompts import Sql_Generation_prompt_few_shot, additional_context_prompt
from prompts import Table_info_template, join_prompt_template
from prompts import join_prompt_template_one_shot, multi_table_prompt
from prompts import follow_up_prompt, Sql_Generation_prompt_few_shot_multiturn
from prompts import Result2nl_insight_prompt, Result2nl_prompt
from langchain_google_vertexai import VertexAI
from google.cloud import bigquery
from nl2sql_query_embeddings import PgSqlEmb, Nl2Sql_embed
import os
from loguru import logger
# from google.cloud import aiplatform
from vertexai.preview.generative_models import GenerativeModel
# from vertexai.preview.generative_models import GenerationConfig
# from json import loads, dumps
# from vertexai.language_models import TextGenerationModel
from vertexai.language_models import CodeChatSession
from vertexai.language_models import CodeChatModel
# Module-level BigQuery client, created at import time; importing this
# module therefore requires Application Default Credentials to resolve.
client = bigquery.Client()
class Nl2sqlBq:
"""
NL2SQL Lite SQL Generator class
"""
def __init__(self,
project_id,
dataset_id,
metadata_json_path=None,
model_name="gemini-pro",
tuned_model=True):
"""
Init function
"""
self.dataset_id = f"{project_id}.{dataset_id}"
self.metadata_json = None
self.model_name = model_name
if model_name == 'text-bison@002' and tuned_model:
# self.llm = VertexAI(temperature=0,
# model_name=self.model_name,
# tuned_model_name='projects/862253555914/\
# locations/us-central1/\
# models/7566417909400993792',
# max_output_tokens=1024)
tuned_model_name = 'projects/174482663155/locations/' + \
'us-central1/models/6975883408262037504'
self.llm = VertexAI(temperature=0,
model_name=self.model_name,
tuned_model_name=tuned_model_name,
max_output_tokens=1024)
else:
self.llm = VertexAI(temperature=0,
model_name=self.model_name,
max_output_tokens=1024)
logger.info(f"Current LLM model : {self.model_name}")
# self.llm = VertexAI(temperature=0,
# model_name=self.model_name,
# tuned_model_name='projects/862253555914/\
# locations/us-central1/models/\
# 7566417909400993792',
# max_output_tokens=1024)
self.engine = sqlalchemy.engine.create_engine(
f"bigquery://{self.dataset_id.replace('.', '/')}")
if metadata_json_path:
f = open(metadata_json_path, encoding="utf-8")
self.metadata_json = json.loads(f.read())
def init_pgdb(self,
proj_id,
loc,
pg_inst,
pg_db,
pg_uname,
pg_pwd,
pg_table,
index_file='saved_index_pgdata'):
"""
Initialising the PG DB
"""
self.pge = PgSqlEmb(proj_id,
loc,
pg_inst,
pg_db,
pg_uname,
pg_pwd,
pg_table)
def get_all_table_names(self):
"""
Provides list of table names in dataset
"""
tables = client.list_tables(self.dataset_id)
all_table_names = [table.table_id for table in tables]
return all_table_names
def get_column_value_examples(self, tname, column_name, enum_option_limit):
"""
Provide example values for string columns
"""
examples_str = ""
if pd.read_sql(
sql=f"SELECT COUNT(DISTINCT {column_name}) <=\
{enum_option_limit} FROM {tname}",
con=self.engine).values[0][0]:
sql_string = f"SELECT DISTINCT {column_name} AS vals FROM {tname}"
examples_str = "It contains values : \"" + ("\", \"".join(
filter(
lambda x: x is not None,
pd.read_sql(
sql=sql_string,
con=self.engine
)["vals"].to_list()
)
)
) + "\"."
return examples_str
    def create_metadata_json(self,
                             metadata_json_dest_path,
                             data_dict_path=None,
                             col_values_distribution=False,
                             enum_option_limit=10):
        """
        Create the metadata JSON file describing every table in the dataset.

        Parameters:
        - metadata_json_dest_path (str): Destination path for the JSON file.
        - data_dict_path (str, optional): Path to a JSON data dictionary
          mapping "table" / "table.column" keys to override descriptions.
        - col_values_distribution (bool): When True, sample distinct values
          for STRING columns via get_column_value_examples.
        - enum_option_limit (int): Max distinct values for a column to have
          its values listed as examples.

        Raises:
        Exception: On any failure (note: print_exc() returns None, so the
        raised Exception carries no message; the traceback is only printed).
        """
        try:
            data_dict = dict()
            if data_dict_path:
                # NOTE(review): this file handle is never closed.
                f = open(data_dict_path, encoding="utf-8")
                data_dict = json.loads(f.read())
            table_ls = self.get_all_table_names()
            metadata_json = dict()
            for table_name in table_ls:
                table = client.get_table(f"{self.dataset_id}.{table_name}")
                table_description = ""
                # Data-dictionary overrides win over BigQuery descriptions.
                if table_name in data_dict and data_dict[table_name].strip():
                    table_description = data_dict[table_name]
                elif table.description:
                    table_description = table.description
                columns_info = dict()
                for schema in table.schema:
                    schema_description = ""
                    # Column overrides are keyed as "<table>.<column>".
                    if f"{table_name}.{schema.name}" in data_dict and \
                            data_dict[f"{table_name}.{schema.name}"].strip():
                        schema_description = data_dict[
                            f"{table_name}.{schema.name}"]
                    elif schema.description:
                        schema_description = schema.description
                    columns_info[schema.name] = {
                        "Name": schema.name,
                        "Type": schema.field_type,
                        "Description": schema_description,
                        "Examples": ""
                    }
                    # Only STRING columns get example values sampled.
                    if col_values_distribution and \
                            schema.field_type == "STRING":
                        all_examples = self.get_column_value_examples(
                            table_name, schema.name, enum_option_limit)
                        columns_info[schema.name]["Examples"] = all_examples
                metadata_json[table_name] = {"Name": table_name,
                                             "Description": table_description,
                                             "Columns": columns_info}
            with open(metadata_json_dest_path, 'w', encoding="utf-8") as f:
                json.dump(metadata_json, f)
            # Cache the freshly built metadata on the instance as well.
            self.metadata_json = metadata_json
        except Exception as exc:
            raise Exception(traceback.print_exc()) from exc
def table_filter(self, question):
"""
This function selects the relevant table(s) to the provided question
based on their description-keywords.
It assists in selecting a table from a list of tables based on
their description-keywords.
It presents a prompt containing a list of table names along
with their corresponding description-keywords.
The function uses a text-based model (text_bison) to analyze
the prompt and extract the selected table name(s).
Parameters:
- question (str): The question for which the relevant table
need to be identified.
Returns:
list: A list of table names most likely relevant to the provided
question.
"""
only_tables_info = ""
for table in self.metadata_json:
only_tables_info = only_tables_info + f"{table} | \
{self.metadata_json[table]['Description']}\n"
prompt = Table_filtering_prompt.format(
only_tables_info=only_tables_info,
question=question
)
result = self.llm.invoke(prompt)
segments = result.split(',')
tables_list = []
for segment in segments:
segment = segment.strip()
if ':' in segment:
value = segment.split(':')[-1].strip()
tables_list.append(value.strip())
elif '\n' in segment:
value = segment.split('\n')[-1].strip()
tables_list.append(value.strip())
else:
tables_list.append(segment)
return tables_list
def table_filter_promptonly(self, question):
"""
This function returns the prompt for the base question
in the Multi-turn execution
Parameters:
- question (str): The question for which the relevant table
need to be identified.
Returns:
list: A list of table names most likely relevant to the
provided question.
"""
only_tables_info = ""
for table in self.metadata_json:
only_tables_info = only_tables_info + f"{table} | \
{self.metadata_json[table]['Description']}\n"
prompt = Table_filtering_prompt_promptonly.format(
only_tables_info=only_tables_info
)
return prompt
    def case_handler_transform(self, sql_query: str) -> str:
        """
        Apply a case-insensitivity transformation to a SQL query.

        Intended to rewrite ``col = 'Literal'`` (inside a WHERE clause)
        into ``LOWER(col) = 'literal'``.

        Parameters:
        - sql_query (str): The original SQL query.

        Returns:
        str: The transformed SQL query, or the original query unchanged.
        """
        # print("Case handller transform", sql_query)
        node = sqlglot.parse_one(sql_query)
        # NOTE(review): parse_one returns the ROOT expression, which can
        # never satisfy both `isinstance(node, EQ)` and
        # `node.find_ancestor(Where)` (a root has no ancestors), so this
        # branch never fires and the method is effectively a no-op.
        # A tree walk (e.g. Expression.transform) was probably intended —
        # confirm before relying on this behaviour.
        if (
            isinstance(node, sqlglot.expressions.EQ) and
            node.find_ancestor(sqlglot.expressions.Where) and
            len(operands := list(node.unnest_operands())) == 2 and
            isinstance(
                literal := operands.pop(), sqlglot.expressions.Literal
            ) and
            isinstance(predicate := operands.pop(), sqlglot.expressions.Column)
        ):
            # Lower-case both sides: column via LOWER(), literal directly.
            transformed_query =\
                sqlglot.parse_one(f"LOWER({predicate}) =\
                '{literal.this.lower()}'")
            return str(transformed_query)
        else:
            return sql_query
    def add_dataset_to_query(self, sql_query):
        """
        Prefix the tables in FROM clauses with the configured dataset.

        Parameters:
        - sql_query (str): The original SQL query.

        Returns:
        str: Modified SQL query with `project.dataset.` prefixes added to
        tables in FROM clauses, or "" when sql_query is falsy.
        """
        logger.info(f"Original query : {sql_query}")
        dataset = self.dataset_id
        if sql_query:
            # Strip existing backticks; they are re-added below.
            sql_query = sql_query.replace('`', '')
            # Define a regular expression pattern to match the FROM clause
            pattern = re.compile(r'\bFROM\b\s+(\w+)', re.IGNORECASE)
            # Find all matches of the pattern in the SQL query
            matches = pattern.findall(sql_query)
            # Iterate through matches and replace the table name
            for match in matches:
                # check text following the match if it is a complete table name
                # (\w+ stops at characters like '-', so the remainder of a
                # hyphenated table id is captured here and glued back on).
                next_text = sql_query.split(match)[1].split('\n')[0]
                next_text = next_text.split(' ')[0]
                # Check if the previous word is not DAY, YEAR, or MONTH
                # (skips EXTRACT(DAY FROM ...) style expressions).
                if re.search(r'\b(?:DAY|YEAR|MONTH)\b',
                             sql_query[:sql_query.find(match)],
                             re.IGNORECASE) is None:
                    # Replace the next word after FROM with dataset.table
                    if match == dataset.split('.')[0]:
                        # checking if in generated SQL, table
                        # includes the project-id and dataset or not
                        replacement = f'`{match}'
                    else:
                        sql_query = sql_query.replace(next_text, '')
                        replacement = f'{dataset}.`{match}{next_text}`'
                        # replacement = f'{dataset}.{match}'
                    sql_query = re.sub(r'\bFROM\b\s+' + re.escape(match),
                                       f'FROM {replacement}',
                                       sql_query,
                                       flags=re.IGNORECASE
                                       )
                    if match == dataset.split('.')[0]:
                        # Close the backtick opened in `replacement` above.
                        sql_query = sql_query.replace(f'{match}{next_text}',
                                                      f'{match}{next_text}`'
                                                      )
            # SAFE_CAST returns NULL instead of failing on bad casts; the
            # second replace repairs any CAST that was already SAFE_CAST.
            sql_query = sql_query.replace('CAST', 'SAFE_CAST')
            sql_query = sql_query.replace('SAFE_SAFE_CAST', 'SAFE_CAST')
            return sql_query
        else:
            return ""
def generate_sql(self, question, table_name=None, logger_file="log.txt"):
"""
Main function which converts NL to SQL
"""
# step-1 table selection
try:
if not table_name:
if len(self.metadata_json.keys()) > 1:
table_list = self.table_filter(question)
table_name = table_list[0]
else:
table_name = list(self.metadata_json.keys())[0]
table_json = self.metadata_json[table_name]
columns_json = table_json["Columns"]
columns_info = ""
for column_name in columns_json:
column = columns_json[column_name]
column_info = f"""{column["Name"]} \
({column["Type"]}) : {column["Description"]}.\
{column["Examples"]}
\n"""
columns_info = columns_info + column_info
sql_prompt = Sql_Generation_prompt.format(
table_name=table_json["Name"],
table_description=table_json["Description"],
columns_info=columns_info,
question=question
)
response = self.llm.invoke(sql_prompt)
sql_query = response.replace('sql', '').replace('```', '')
# sql_query = self.case_handler_transform(sql_query)
sql_query = self.add_dataset_to_query(sql_query)
with open(logger_file, 'a', encoding="utf-8") as f:
f.write(f">>\nModel:{self.model_name} \n\nQuestion: {question}\
\n\nPrompt:{sql_prompt} \nSql_query:{sql_query}<<\n")
if sql_query.strip().startswith("Response:"):
sql_query = sql_query.split(":")[1].strip()
return sql_query
except Exception as exc:
raise Exception(traceback.print_exc()) from exc
def generate_sql_few_shot(self,
question,
table_name=None,
logger_file="log.txt"):
"""
Main function which converts NL to SQL using few shot prompting
"""
# step-1 table selection
try:
if not table_name:
if len(self.metadata_json.keys()) > 1:
table_list = self.table_filter(question)
table_name = table_list[0]
else:
table_name = list(self.metadata_json.keys())[0]
table_json = self.metadata_json[table_name]
columns_json = table_json["Columns"]
columns_info = ""
for column_name in columns_json:
column = columns_json[column_name]
column_info = f"""{column["Name"]} \
({column["Type"]}) : {column["Description"]}.\
{column["Examples"]}\n"""
columns_info = columns_info + column_info
# few_shot_json = self.pge.search_matching_queries(question)
embed = Nl2Sql_embed()
few_shot_json = embed.search_matching_queries(question)
logger.info(f"Few sjot examples : {few_shot_json}")
few_shot_examples = ""
for item in few_shot_json:
example_string = f"Question: {item['question']}"
few_shot_examples += example_string + "\n"
example_string = f"SQL : {item['sql']} "
few_shot_examples += example_string + "\n\n"
sql_prompt = Sql_Generation_prompt_few_shot.format(
table_name=table_json["Name"],
table_description=table_json["Description"],
columns_info=columns_info,
few_shot_examples=few_shot_examples,
question=question)
response = self.llm.invoke(sql_prompt)
sql_query = response.replace('sql', '').replace('```', '')
# sql_query = self.case_handler_transform(sql_query)
sql_query = self.add_dataset_to_query(sql_query)
with open(logger_file, 'a', encoding="utf-8") as f:
f.write(f">>\nModel:{self.model_name} \n\nQuestion: {question}\
\n\nPrompt:{sql_prompt} \n\nSql_query:{sql_query}<\n")
if sql_query.strip().startswith("Response:"):
sql_query = sql_query.split(":")[1].strip()
return sql_query
except Exception as exc:
raise Exception(traceback.print_exc()) from exc
    def generate_sql_few_shot_promptonly(self,
                                         question,
                                         table_name=None,
                                         prev_sql="",
                                         logger_file="log.txt"):
        """
        Build (but do not send) the few-shot SQL generation prompt used by
        the multi-turn / self-correction flow.

        Parameters:
        - question (str): The natural-language question.
        - table_name (str, optional): Target table; chosen automatically
          when falsy.
        - prev_sql (str): A previously generated (failing) query; when set
          it is embedded as additional context for self-correction.
        - logger_file (str): Unused here; kept for signature parity with
          generate_sql_few_shot.

        Returns:
        str: The formatted prompt.

        Raises:
        Exception: On failure (note: print_exc() returns None, so the
        raised Exception carries no message).
        """
        # step-1 table selection
        try:
            if not table_name:
                if len(self.metadata_json.keys()) > 1:
                    table_list = self.table_filter(question)
                    table_name = table_list[0]
                else:
                    table_name = list(self.metadata_json.keys())[0]
            table_json = self.metadata_json[table_name]
            columns_json = table_json["Columns"]
            columns_info = ""
            for column_name in columns_json:
                column = columns_json[column_name]
                column_info = f"""{column["Name"]} \
                ({column["Type"]}) : {column["Description"]}.\
                {column["Examples"]}\n"""
                columns_info = columns_info + column_info
            # Similar question/SQL pairs come from the Postgres embedding
            # store; init_pgdb must have been called first.
            few_shot_json = self.pge.search_matching_queries(question)
            few_shot_examples = ""
            for item in few_shot_json:
                example_string = f"Question: {item['question']}"
                few_shot_examples += example_string + "\n"
                example_string = f"SQL : {item['sql']} "
                few_shot_examples += example_string + "\n\n"
            if prev_sql:
                additional_context = additional_context_prompt.format(
                    prev_sql=prev_sql
                )
            else:
                additional_context = ""
            sql_prompt = Sql_Generation_prompt_few_shot_multiturn.format(
                table_name=table_json["Name"],
                table_description=table_json["Description"],
                columns_info=columns_info,
                few_shot_examples=few_shot_examples,
                question=question,
                additional_context=additional_context
            )
            return sql_prompt
        except Exception as exc:
            raise Exception(traceback.print_exc()) from exc
def execute_query_old(self, query):
"""
This function executes an SQL query using the configured
BigQuery client.
Parameters:
- query (str): The SQL query to be executed.
Returns:
pandas.DataFrame: The result of the executed query as a DataFrame.
"""
try:
# Run the SQL query
query_job = client.query(query)
# Wait for the job to complete
query_job.result()
# Fetch the result if needed
results = query_job.to_dataframe()
return results
except Exception as exc:
raise Exception(traceback.print_exc()) from exc
def execute_query(self, query, dry_run=False):
"""
This function executes an SQL query using the configured
BigQuery client.
Parameters:
- query (str): The SQL query to be executed.
Returns:
pandas.DataFrame: The result of the executed query as a DataFrame.
"""
if dry_run:
job_config = bigquery.QueryJobConfig(dry_run=True,
use_query_cache=False)
query_job = client.query(query, job_config=job_config)
if query_job.total_bytes_processed > 0:
logger.info("Query is valid")
return True, 'Query is valid'
else:
return False, 'Invalid query. Regenerate'
else:
try:
# Run the SQL query
query_job = client.query(query)
# Wait for the job to complete
query_job.result()
# Fetch the result if needed
results = query_job.to_dataframe()
return results
except Exception as exc:
raise Exception(traceback.print_exc()) from exc
def self_reflection(self, question, query, max_tries=5):
"""
Retries the query generation process in case of failure
for the specified number of times
"""
status, _ = self.execute_query(query, dry_run=True)
good_sql = False
if not status:
# Repeat generation of the sql
iter = 0
while iter < max_tries or good_sql:
prompt = self.generate_sql_few_shot_promptonly(question,
table_name="",
prev_sql=query)
query = self.invoke_llm(prompt)
good_sql, msg = self.execute_query(query, dry_run=True)
iter += 1
return good_sql, query
def text_to_sql_execute(self,
question,
table_name=None,
logger_file="log.txt"):
"""
Converts text to sql and also executes sql query
"""
try:
# query = self.text_to_sql(question,
# table_name,logger_file = logger_file)
query = self.generate_sql(question,
table_name,
logger_file=logger_file
)
results = self.execute_query(query)
return results
except Exception as exc:
raise Exception(traceback.print_exc()) from exc
def text_to_sql_execute_few_shot(self,
question,
table_name=None,
logger_file="log.txt"):
"""
Converts text to sql and also executes sql query
"""
try:
query = self.generate_sql_few_shot(question,
table_name,
logger_file=logger_file
)
logger.info(f"Executing query : {query}")
results = self.execute_query(query)
return results, query
except Exception as exc:
raise Exception(traceback.print_exc()) from exc
def result2nl(self, result, question, insight=True):
"""
The function converts an SQL query result into an insightful
and well-explained natural language summary, using text-bison model.
Parameters:
- result (str): The result of the SQL query.
- question (str): The natural language question corresponding
to the SQL query.
Returns:
str: A natural language summary of the SQL query result.
"""
try:
if insight:
prompt = Result2nl_insight_prompt.format(question=question,
result=str(result)
)
else:
prompt = Result2nl_prompt.format(question=question,
result=str(result)
)
return self.llm.invoke(prompt)
except Exception as exc:
raise Exception(traceback.print_exc()) from exc
def auto_verify(self, nl_description, ground_truth, llm_amswer):
"""
This function verifies the accuracy of SQL query based on a natural
language description and a ground truth query, using text-bison model.
Parameters:
- nl_description (str): The natural language description of the
SQL query.
- ground_truth (str): The ground truth SQL query.
- llm_amswer (str): The student's generated SQL query for validation.
Returns:
str: "Yes" if the student's answer matches the ground truth
and fits the NL description correctly,"No" otherwise.
"""
prompt = Auto_verify_sql_prompt.format(nl_description=nl_description,
ground_truth=ground_truth,
llm_amswer=llm_amswer
)
return self.llm.invoke(prompt)
    def batch_run(self,
                  test_file_name,
                  output_file_name,
                  execute_query=False,
                  result2nl=False,
                  insight=True,
                  logger_file="log.txt"):
        """
        Process a batch of questions from a CSV file, generate SQL queries
        and evaluate their accuracy.

        The input CSV must contain 'table', 'question' and
        'ground_truth_sql' columns.  Each question is converted with
        generate_sql, rated with auto_verify, and optionally executed
        (execute_query) and summarised (result2nl).  Results are written
        to ``output_file_name`` as CSV.

        Parameters:
        - test_file_name (str): CSV with test questions and ground truth.
        - output_file_name (str): Destination CSV for the results.
        - execute_query (bool): Also run each generated query.
        - result2nl (bool): Also summarise results (needs execute_query).
        - insight (bool): Use the insight prompt for summaries.
        - logger_file (str): Audit-log path forwarded to generate_sql.

        Returns:
        pandas.DataFrame: question, ground truth SQL, generated SQL, LLM
        rating, and optionally SQL result and NL response.

        Raises:
        Exception: On any failure (note: print_exc() returns None, so the
        raised Exception carries no message).
        """
        try:
            questions = pd.read_csv(test_file_name)
            out = []
            columns = ['question',
                       'ground_truth',
                       'llm_response',
                       'llm_rating'
                       ]
            # Optional columns are appended in the same order the row
            # values are appended below.
            if execute_query:
                columns.append('sql_result')
            if result2nl:
                columns.append('nl_response')
            for _, row in questions.iterrows():
                table_name = None
                if row["table"].strip():
                    table_name = row["table"]
                question = row["question"]
                # print(question)
                sql_gen = self.generate_sql(question,
                                            table_name=table_name,
                                            logger_file=logger_file
                                            )
                # print(sql_gen)
                rating = self.auto_verify(question,
                                          row["ground_truth_sql"], sql_gen
                                          )
                row_result = [question,
                              row["ground_truth_sql"], sql_gen, rating]
                if execute_query:
                    result = self.execute_query(sql_gen)
                    # print(result)
                    row_result.append(result)
                if execute_query and result2nl:
                    nl = self.result2nl(result, question, insight=insight)
                    row_result.append(nl)
                out.append(row_result)
                # print("\n\n")
            df = pd.DataFrame(out, columns=columns)
            df.to_csv(output_file_name, index=False)
            return df
        except Exception as exc:
            raise Exception(traceback.print_exc()) from exc
def table_details(self, table_name):
"""
Cretes the Table details required for Joins
"""
f = open(self.metadata_json, encoding="utf-8")
metadata_json = json.loads(f.read())
table_json = metadata_json[table_name]
columns_json = table_json["Columns"]
columns_info = ""
for column_name in columns_json:
column = columns_json[column_name]
column_info = f"""
{column["Name"]} \
({column["Type"]}) : {column["Description"]}.\
{column["Examples"]}\n"""
columns_info = columns_info + column_info
prompt = Table_info_template.format(
table_name=table_name,
table_description=metadata_json[table_name]['Description'],
columns_info=columns_info
)
return prompt
def get_join_prompt(self,
dataset,
table_1_name,
table_2_name,
question,
sample_question=None,
sample_sql=None,
one_shot=False):
"""
Crete the prompt for Joins
"""
prompt = ""
table_1 = self.table_details(table_1_name)
table_2 = self.table_details(table_2_name)
if one_shot:
prompt = join_prompt_template_one_shot.format(
data_set=dataset,
table_1=table_1,
table_2=table_2,
sample_question=sample_question,
sample_sql=sample_sql,
question=question
)
else:
prompt = join_prompt_template.format(data_set=dataset,
table_1=table_1,
table_2=table_2,
question=question)
return prompt
def invoke_llm(self, prompt):
"""
Invoke the LLM
"""
response = self.llm.invoke(prompt)
sql_query = response.replace('sql', '').replace('```', '')
# sql_query = self.case_handler_transform(sql_query)
sql_query = self.add_dataset_to_query(sql_query)
# with open(logger_file, 'a',encoding="utf-8") as f:
# f.write(f">>>>\nModel:{self.model_name} \n\nQuestion: {question}\
# \n\nPrompt:{join_prompt} \n\nSql_query:{sql_query}<<<<\n\n\n")
return sql_query
def multi_turn_table_filter(self,
table_1_name,
table_2_name,
sample_question,
sample_sql,
question):
"""
Table filter for multi-turn prompting
"""
table_info = self.table_filter_promptonly(question)
prompt = multi_table_prompt.format(table_info=table_info,
example_question=sample_question,
example_sql=sample_sql,
question=question,
table_1_name=table_1_name,
table_2_name=table_2_name)
model = GenerativeModel("gemini-1.0-pro")
multi_chat = model.start_chat()
_ = multi_chat.send_message(prompt) # response1
response2 = multi_chat.send_message(follow_up_prompt)
try:
identified_tables = response2.candidates[0].content.parts[0].text
except Exception:
identified_tables = ""
return identified_tables
    def gen_and_exec_and_self_correct_sql(self,
                                          prompt,
                                          genai_model_name="GeminiPro",
                                          max_tries=5,
                                          return_all=False):
        """
        Generate, execute and self-correct SQL in a chat loop.

        Each round asks the chat model for a query, executes it against
        BigQuery, and feeds any execution error back into the chat so the
        model can correct itself (up to ``max_tries`` rounds).

        Parameters:
        - prompt (str): The initial SQL-generation prompt.
        - genai_model_name (str): "GeminiPro" chats with Gemini; any other
          value uses the codechat-bison session.
        - max_tries (int): Number of generate/execute rounds.
        - return_all (bool): Currently unused.

        Returns:
        dict: {"dataframe": DataFrame of (Query, Result) pairs} when at
        least one query succeeded, otherwise {"error", "prompts", "errors"}.
        """
        tries = 0
        error_messages = []
        prompts = [prompt]
        successful_queries = []
        TEMPERATURE = 0.3
        MAX_OUTPUT_TOKENS = 8192
        MODEL_NAME = 'codechat-bison-32k'
        code_gen_model = CodeChatModel.from_pretrained(MODEL_NAME)
        model = GenerativeModel("gemini-1.0-pro")
        if genai_model_name == "GeminiPro":
            chat_session = model.start_chat()
        else:
            chat_session = CodeChatSession(model=code_gen_model,
                                           temperature=TEMPERATURE,
                                           max_output_tokens=MAX_OUTPUT_TOKENS
                                           )
        while tries < max_tries:
            try:
                if genai_model_name == "GeminiPro":
                    response = chat_session.send_message(prompt)
                else:
                    response = chat_session.send_message(
                        prompt,
                        temperature=TEMPERATURE,
                        max_output_tokens=MAX_OUTPUT_TOKENS
                    )
                generated_sql_query = response.text
                # Drop the first and last lines (assumed ``` fences).
                generated_sql_query = '\n'.join(
                    generated_sql_query.split('\n')[1:-1]
                )
                generated_sql_query = self.case_handler_transform(
                    generated_sql_query
                )
                generated_sql_query = self.add_dataset_to_query(
                    generated_sql_query
                )
                df = client.query(generated_sql_query).to_dataframe()
                successful_queries.append({
                    "query": generated_sql_query,
                    "dataframe": df
                })
                # NOTE(review): `> 1` means the prompt is only rewritten
                # after the SECOND success; confirm whether `>= 1` was
                # intended so the first success already triggers it.
                if len(successful_queries) > 1:
                    prompt = f"""Modify the last successful SQL query by
                    making changes to it and optimizing it for latency.
                    ENSURE that the NEW QUERY is DIFFERENT from the
                    previous one while prioritizing faster execution.
                    Reference the tables only from the above given
                    project and dataset
                    The last successful query was:
                    {successful_queries[-1]["query"]}"""
            except Exception as e:
                # NOTE(review): if the very first send_message raises,
                # `generated_sql_query` is unbound here and formatting the
                # error prompt raises NameError — confirm/guard.
                msg = str(e)
                error_messages.append(msg)
                prompt = f"""Encountered an error: {msg}.
                To address this, please generate an alternative SQL
                query response that avoids this specific error.
                Follow the instructions mentioned above to
                remediate the error.
                Modify the below SQL query to resolve the issue and
                ensure it is not a repetition of all previously
                generated queries.
                {generated_sql_query}
                Ensure the revised SQL query aligns precisely with the
                requirements outlined in the initial question.
                Keep the table names as it is. Do not change hyphen
                to underscore character
                Additionally, please optimize the query for latency
                while maintaining correctness and efficiency."""
            prompts.append(prompt)
            tries += 1
        if len(successful_queries) == 0:
            return {
                "error": "All attempts exhausted.",
                "prompts": prompts,
                "errors": error_messages
            }
        else:
            df = pd.DataFrame(
                [(q["query"], q["dataframe"])
                 for q in successful_queries], columns=["Query", "Result"]
            )
            return {
                "dataframe": df
            }
def generate_sql_with_join(self,
dataset,
table_1_name,
table_2_name,
question,
example_table1,
example_table2,
sample_question=None,
sample_sql=None,
one_shot=False,
join_gen="STANDARD"):
gen_join_sql = ""
match join_gen:
case 'STANDARD':
if not one_shot:
# Zero-shot Join query generation
join_prompt = self.get_join_prompt(dataset,
table_1_name,
table_2_name,
question)
gen_join_sql = self.invoke_llm(join_prompt)
else:
# One-shot Join query generation
join_prompt_one_shot = self.get_join_prompt(
dataset,
table_1_name,
table_2_name,
question,
sample_question,
sample_sql,
one_shot=True
)
gen_join_sql = self.invoke_llm(
join_prompt_one_shot
)
case 'MULTI_TURN':
table_1_name, table_2_name = \
self.multi_turn_table_filter(
table_1_name=example_table1,
table_2_name=example_table2,
sample_question=sample_question,
sample_sql=sample_sql,
question=question
)
# One-shot Join query generation
join_prompt_one_shot = self.get_join_prompt(
data_set,
table_1_name,
table_2_name,
question,
sample_question,
sample_sql,
one_shot=True
)
gen_join_sql = self.invoke_llm(
join_prompt_one_shot
)
case 'SELF_CORRECT':
join_prompt_one_shot = self.get_join_prompt(
data_set,
table_1_name,
table_2_name,
question,
sample_question,
sample_sql,
one_shot=True
)
# Self-Correction Approach
responses = self.gen_and_exec_and_self_correct_sql(
join_prompt_one_shot
)
gen_join_sql = responses[0]['query']
return gen_join_sql
if __name__ == '__main__':
    # Smoke-test driver: exercises single-table generation, join prompts,
    # multi-turn table identification and the self-correction loop.
    # Requires GCP credentials plus PROJECT_ID / DATASET_ID env vars.
    project_id = os.environ['PROJECT_ID']
    dataset_id = os.environ['DATASET_ID']
    print("Info =", project_id, dataset_id)
    meta_data_json_path = "utils/metadata_cache.json"
    nl2sqlbq_client = Nl2sqlBq(
        project_id=project_id,
        dataset_id=dataset_id,
        metadata_json_path=meta_data_json_path,
        # "../utils/metadata_cache.json",
        model_name="text-bison"
        # model_name="code-bison"
        )
    question = "What is the average, minimum, and maximum age \
        of all singers from France?"
    # NOTE(review): text_to_sql_execute_few_shot returns (results, query),
    # so `sql_query` below actually holds the RESULT SET, not the SQL —
    # the variable name and the prints are misleading.
    sql_query, _ = nl2sqlbq_client.text_to_sql_execute_few_shot(
        question,
        'medi-cal-and-calfresh-enrollment'
        )
    print("Generated query == ", sql_query)
    nl_resp = nl2sqlbq_client.result2nl(sql_query, question)
    print("Response in NL = ", nl_resp)
    # Placeholders: these must be filled with real values for the join
    # flows below to produce meaningful prompts.
    table_1_name = ""
    table_2_name = ""
    sample_question = ""
    sample_sql = ""
    data_set = ""
    example_table_1 = ""
    example_table_2 = ""
    # Zero-shot Join query generation
    join_prompt = nl2sqlbq_client.get_join_prompt(data_set,
                                                  table_1_name,
                                                  table_2_name,
                                                  question)
    gen_join_sql = nl2sqlbq_client.invoke_llm(join_prompt)
    print("SQL query wiith Join - ", gen_join_sql)
    # One-shot Join query generation
    join_prompt_one_shot = nl2sqlbq_client.get_join_prompt(data_set,
                                                           table_1_name,
                                                           table_2_name,
                                                           question,
                                                           sample_question,
                                                           sample_sql,
                                                           one_shot=True)
    gen_join_sql = nl2sqlbq_client.invoke_llm(join_prompt_one_shot)
    print("SQL query wiith Join - ", gen_join_sql)
    # Table Identification with Multi-turn approach
    example_table_1 = ""
    example_table_2 = ""
    # NOTE(review): multi_turn_table_filter returns a single string; this
    # 2-way unpack assumes the reply is a 2-element sequence — confirm.
    table_1_name, table_2_name = nl2sqlbq_client.multi_turn_table_filter(
        table_1_name=example_table_1,
        table_2_name=example_table_2,
        sample_question=sample_question,
        sample_sql=sample_sql,
        question=question
        )
    # One-shot Join query generation
    join_prompt_one_shot = nl2sqlbq_client.get_join_prompt(data_set,
                                                           table_1_name,
                                                           table_2_name,
                                                           question,
                                                           sample_question,
                                                           sample_sql,
                                                           one_shot=True)
    gen_join_sql = nl2sqlbq_client.invoke_llm(join_prompt_one_shot)
    print("SQL query wiith Join - ", gen_join_sql)
    # Self-Correction Approach
    responses = nl2sqlbq_client.gen_and_exec_and_self_correct_sql(
        join_prompt_one_shot
        )
    print(responses)
    # Common function to perform either of the operations
    gen_join_sql = nl2sqlbq_client.generate_sql_with_join(
        data_set,
        table_1_name,
        table_2_name,
        question,
        example_table_1,
        example_table_2,
        sample_question,
        sample_sql,
        True,
        "STANDARD"  # STANDARD or MULTI_TURN or SELF_CORRECT
        )