agents/BuildSQLAgent.py (95 lines of code) (raw):
from abc import ABC
from vertexai.language_models import CodeChatModel
from vertexai.generative_models import GenerativeModel, Content, Part, GenerationConfig
from .core import Agent
import pandas as pd
import json
from datetime import datetime
from dbconnectors import pgconnector,bqconnector,firestoreconnector
from utilities import PROMPTS, format_prompt
from google.cloud.aiplatform import telemetry
import vertexai
from utilities import PROJECT_ID, PG_REGION
from vertexai.generative_models import GenerationConfig
vertexai.init(project=PROJECT_ID, location=PG_REGION)
class BuildSQLAgent(Agent, ABC):
    """Agent that generates a SQL query answering a natural-language question.

    Builds a prompt from the table/column schemas, few-shot similar SQL
    examples, optional per-usecase context, and the multi-turn session
    history, then asks either a Gemini chat model or codechat-bison-32k
    for the SQL.
    """

    # Discriminator used by the surrounding framework to identify this agent.
    agentType: str = "BuildSQLAgent"

    def __init__(self, model_id='gemini-1.5-pro'):
        super().__init__(model_id=model_id)

    def build_sql(self, source_type, user_grouping, user_question, session_history,
                  tables_schema, columns_schema, similar_sql,
                  max_output_tokens=2048, temperature=0.4, top_p=1, top_k=32):
        """Generate SQL for ``user_question`` against the given schemas.

        Args:
            source_type: 'bigquery' or anything else (treated as postgres).
            user_grouping: usecase key used to look up extra prompt context.
            user_question: the natural-language question to answer.
            session_history: iterable of dicts with 'user_question' and
                'bot_response' keys from earlier turns; may be empty or None.
            tables_schema: table schema text injected into the prompt.
            columns_schema: column schema text injected into the prompt.
            similar_sql: few-shot examples of similar question/SQL pairs.
            max_output_tokens, temperature, top_p, top_k: generation settings
                (applied for Gemini models only; codechat uses model defaults).

        Returns:
            The generated SQL string with markdown code fences stripped.

        Raises:
            ValueError: if ``self.model_id`` is neither 'codechat-bison-32k'
                nor a Gemini model.
        """
        not_related_msg = f'''select 'Question is not related to the dataset' as unrelated_answer;'''

        # Source-specific data-type notes for the prompt.
        if source_type == 'bigquery':
            from dbconnectors import bq_specific_data_types
            specific_data_types = bq_specific_data_types()
        else:
            from dbconnectors import pg_specific_data_types
            specific_data_types = pg_specific_data_types()

        # Optional usecase-specific context, keyed by source and grouping.
        if f'usecase_{source_type}_{user_grouping}' in PROMPTS:
            usecase_context = PROMPTS[f'usecase_{source_type}_{user_grouping}']
        else:
            usecase_context = "No extra context for the usecase is provided"

        context_prompt = PROMPTS[f'buildsql_{source_type}']
        context_prompt = format_prompt(context_prompt,
                                       specific_data_types=specific_data_types,
                                       not_related_msg=not_related_msg,
                                       usecase_context=usecase_context,
                                       similar_sql=similar_sql,
                                       tables_schema=tables_schema,
                                       columns_schema=columns_schema)
        # print(f"Prompt to Build SQL: \n{context_prompt}")

        # Rebuild the multi-turn chat history as alternating user/model turns.
        chat_history = []
        for entry in session_history:
            user_message = Content(
                parts=[Part.from_text(entry["user_question"])],
                role="user"
            )
            # BUG FIX: the Vertex AI SDK only accepts roles "user" and "model";
            # the previous role "assistant" is rejected for Gemini history.
            bot_message = Content(
                parts=[Part.from_text(entry["bot_response"])],
                role="model"
            )
            chat_history.extend([user_message, bot_message])

        config = None
        if self.model_id == 'codechat-bison-32k':
            with telemetry.tool_context_manager('opendataqna-buildsql-v2'):
                chat_session = self.model.start_chat(context=context_prompt)
        elif 'gemini' in self.model_id:
            with telemetry.tool_context_manager('opendataqna-buildsql-v2'):
                config = GenerationConfig(
                    max_output_tokens=max_output_tokens, temperature=temperature,
                    top_p=top_p, top_k=top_k
                )
                chat_session = self.model.start_chat(history=chat_history, response_validation=False)
                # BUG FIX: apply the generation config (it was previously built
                # but never passed to any call).
                chat_session.send_message(context_prompt, generation_config=config)
        else:
            raise ValueError('Invalid Model Specified')

        if not session_history:
            previous_question = None
            previous_sql = None
        else:
            # NOTE(review): the rewritten question costs an extra LLM call but is
            # not used below; kept for its logging side effect — confirm whether
            # it should feed the final prompt instead of the raw user_question.
            concated_questions, re_written_qe = self.rewrite_question(user_question, session_history)
            previous_question, previous_sql = self.get_last_sql(session_history)

        build_context_prompt = f"""
        Below is the previous user question from this conversation and its generated sql.
        Previous Question: {previous_question}
        Previous Generated SQL : {previous_sql}
        Respond with
        Generate SQL for User Question : {user_question}
        """
        # print("BUILD CONTEXT ::: " + str(build_context_prompt))

        with telemetry.tool_context_manager('opendataqna-buildsql-v2'):
            if config is not None:
                # Gemini path: honor the caller's generation settings.
                response = chat_session.send_message(build_context_prompt,
                                                     generation_config=config,
                                                     stream=False)
            else:
                # codechat path: its send_message does not take generation_config.
                response = chat_session.send_message(build_context_prompt, stream=False)

        # Strip markdown code fences so the result is executable SQL.
        generated_sql = (str(response.text)).replace("```sql", "").replace("```", "")
        return generated_sql

    def rewrite_question(self, question, session_history):
        """Rewrite a follow-up question into a self-contained one.

        Uses the session history of prior questions and generated SQL so the
        rewritten question needs no extra context for SQL generation.

        Returns:
            Tuple ``(concatenated_questions, rewritten_question)`` of strings.
        """
        formatted_history = ''
        concat_questions = ''
        for i, _row in enumerate(session_history, start=1):
            user_question = _row['user_question']
            sql_query = _row['bot_response']
            formatted_history += f"User Question - Turn :: {i} : {user_question}\n"
            formatted_history += f"Generated SQL - Turn :: {i}: {sql_query}\n\n"
            concat_questions += f"{user_question} "

        context_prompt = f"""
        Your main objective is to rewrite and refine the question passed based on the session history of question and sql generated.
        Refine the given question using the provided session history to produce a queryable statement. The refined question should be self-contained, requiring no additional context for accurate SQL generation.
        Make sure all the information is included in the re-written question
        Below is the previous session history:
        {formatted_history}
        Question to rewrite:
        {question}
        """
        re_written_qe = str(self.generate_llm_response(context_prompt))
        print("*"*25 + "Re-written question for the follow up:: " + "*"*25 + "\n" + str(re_written_qe))
        return str(concat_questions), str(re_written_qe)

    def get_last_sql(self, session_history):
        """Return the most recent ``(user_question, bot_response)`` pair.

        Scans the history from newest to oldest and returns the first turn
        that has a non-empty bot response.

        Returns:
            ``(user_question, bot_response)``, or ``(None, None)`` when no turn
            has a bot response. (BUG FIX: previously returned a bare ``None``,
            which broke callers that unpack two values.)
        """
        for entry in reversed(session_history):
            if entry.get("bot_response"):
                return entry["user_question"], entry["bot_response"]
        return None, None