agents/BuildSQLAgent.py (95 lines of code) (raw):

from abc import ABC
from datetime import datetime
import json

import pandas as pd
import vertexai
from google.cloud.aiplatform import telemetry
from vertexai.generative_models import Content, GenerationConfig, GenerativeModel, Part
from vertexai.language_models import CodeChatModel

from .core import Agent
from dbconnectors import pgconnector, bqconnector, firestoreconnector
from utilities import PROMPTS, format_prompt
from utilities import PROJECT_ID, PG_REGION

vertexai.init(project=PROJECT_ID, location=PG_REGION)

# Tool name handed to telemetry.tool_context_manager around every model call
# made by this agent.
_TELEMETRY_TOOL = 'opendataqna-buildsql-v2'


class BuildSQLAgent(Agent, ABC):
    """Agent that asks a chat model to write SQL for a natural-language question.

    The system prompt is assembled from the ``PROMPTS`` templates together with
    the table/column schema, similar known-good SQL and (optionally) the prior
    turns of the conversation.

    ``self.model``, ``self.model_id`` and ``self.generate_llm_response`` are
    used here but not defined in this file; presumably they come from
    ``Agent`` -- confirm against ``core``.
    """

    agentType: str = "BuildSQLAgent"

    def __init__(self, model_id='gemini-1.5-pro'):
        super().__init__(model_id=model_id)

    def build_sql(self, source_type, user_grouping, user_question, session_history,
                  tables_schema, columns_schema, similar_sql,
                  max_output_tokens=2048, temperature=0.4, top_p=1, top_k=32):
        """Generate a SQL statement answering ``user_question``.

        Args:
            source_type: ``'bigquery'`` selects the BigQuery data-type hints;
                any other value selects the PostgreSQL ones. Also used to pick
                the ``buildsql_<source_type>`` prompt template.
            user_grouping: Suffix of the optional
                ``usecase_<source_type>_<user_grouping>`` prompt.
            user_question: The question to translate into SQL.
            session_history: Prior turns as an iterable of mappings with
                ``"user_question"`` and ``"bot_response"`` keys, or ``None`` /
                empty when this is the first question.
            tables_schema: Table schema text injected into the prompt.
            columns_schema: Column schema text injected into the prompt.
            similar_sql: Example SQL injected into the prompt.
            max_output_tokens, temperature, top_p, top_k: Sampling settings.
                They are applied to Gemini models only; the
                ``codechat-bison-32k`` path does not receive them.

        Returns:
            The model's reply as a string with Markdown code fences removed.

        Raises:
            ValueError: If ``self.model_id`` is neither ``codechat-bison-32k``
                nor a Gemini model.
        """
        # Treat "no history" uniformly; the loop below would otherwise fail on None.
        if session_history is None:
            session_history = []

        not_related_msg = '''select 'Question is not related to the dataset' as unrelated_answer;'''

        if source_type == 'bigquery':
            from dbconnectors import bq_specific_data_types
            specific_data_types = bq_specific_data_types()
        else:
            from dbconnectors import pg_specific_data_types
            specific_data_types = pg_specific_data_types()

        usecase_key = f'usecase_{source_type}_{user_grouping}'
        if usecase_key in PROMPTS:
            usecase_context = PROMPTS[usecase_key]
        else:
            usecase_context = "No extra context for the usecase is provided"

        context_prompt = format_prompt(
            PROMPTS[f'buildsql_{source_type}'],
            specific_data_types=specific_data_types,
            not_related_msg=not_related_msg,
            usecase_context=usecase_context,
            similar_sql=similar_sql,
            tables_schema=tables_schema,
            columns_schema=columns_schema,
        )

        # Replay earlier turns as chat history. The Gemini API documents only
        # the "user" and "model" roles for Content, so bot turns use "model".
        chat_history = []
        for entry in session_history:
            chat_history.append(
                Content(parts=[Part.from_text(entry["user_question"])], role="user")
            )
            chat_history.append(
                Content(parts=[Part.from_text(entry["bot_response"])], role="model")
            )

        # Extra keyword arguments for send_message; only the Gemini chat
        # session accepts a generation_config.
        send_kwargs = {}
        if self.model_id == 'codechat-bison-32k':
            with telemetry.tool_context_manager(_TELEMETRY_TOOL):
                chat_session = self.model.start_chat(context=context_prompt)
        elif 'gemini' in self.model_id:
            send_kwargs["generation_config"] = GenerationConfig(
                max_output_tokens=max_output_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
            with telemetry.tool_context_manager(_TELEMETRY_TOOL):
                chat_session = self.model.start_chat(
                    history=chat_history, response_validation=False
                )
                # Prime the session with the schema/context prompt; the reply
                # to this message is intentionally discarded.
                chat_session.send_message(context_prompt, **send_kwargs)
        else:
            raise ValueError('Invalid Model Specified')

        previous_question = None
        previous_sql = None
        if session_history:
            # NOTE(review): the rewritten question is not used when building
            # the prompt below; the call is kept because it prints the rewrite
            # and callers may rely on that output. Confirm whether it can go.
            self.rewrite_question(user_question, session_history)

            # get_last_sql returns None when no turn has a bot response, so
            # only unpack an actual (question, sql) pair.
            last_turn = self.get_last_sql(session_history)
            if last_turn is not None:
                previous_question, previous_sql = last_turn

        build_context_prompt = (
            " Below is the previous user question from this conversation"
            " and its generated sql. \n"
            f"Previous Question: {previous_question}"
            f" Previous Generated SQL : {previous_sql}"
            f" Respond with Generate SQL for User Question : {user_question} "
        )

        with telemetry.tool_context_manager(_TELEMETRY_TOOL):
            response = chat_session.send_message(
                build_context_prompt, stream=False, **send_kwargs
            )

        # Strip Markdown fences so only the SQL text remains.
        generated_sql = str(response.text).replace("```sql", "").replace("```", "")
        return generated_sql

    def rewrite_question(self, question, session_history):
        """Ask the LLM to rewrite a follow-up question so it is self-contained.

        Args:
            question: The follow-up question to rewrite.
            session_history: Iterable of mappings with ``"user_question"`` and
                ``"bot_response"`` keys, oldest turn first.

        Returns:
            A ``(concatenated_questions, rewritten_question)`` tuple of
            strings. The first element is every prior user question, each
            followed by a single space.
        """
        history_chunks = []
        prior_questions = []
        for turn, _row in enumerate(session_history, start=1):
            user_question = _row['user_question']
            sql_query = _row['bot_response']
            history_chunks.append(
                f"User Question - Turn :: {turn} : {user_question}\n"
                f"Generated SQL - Turn :: {turn}: {sql_query}\n\n"
            )
            prior_questions.append(f"{user_question} ")

        formatted_history = "".join(history_chunks)
        concat_questions = "".join(prior_questions)

        context_prompt = (
            " Your main objective is to rewrite and refine the question passed"
            " based on the session history of question and sql generated."
            " Refine the given question using the provided session history"
            " to produce a queryable statement."
            " The refined question should be self-contained, requiring no"
            " additional context for accurate SQL generation."
            " Make sure all the information is included in the re-written question"
            f" Below is the previous session history: {formatted_history}"
            f" Question to rewrite: {question} "
        )

        re_written_qe = str(self.generate_llm_response(context_prompt))

        print("*"*25 +"Re-written question for the follow up:: "+"*"*25+"\n"+str(re_written_qe))

        return str(concat_questions), str(re_written_qe)

    def get_last_sql(self, session_history):
        """Return the most recent turn that has a bot response.

        Args:
            session_history: Sequence of mappings, oldest turn first.

        Returns:
            ``(user_question, bot_response)`` for the latest entry whose
            ``"bot_response"`` is truthy, or ``None`` if there is no such
            entry.
        """
        for entry in reversed(session_history):
            if entry.get("bot_response"):
                return entry["user_question"], entry["bot_response"]
        return None