# dbai_src/dbai.py
"""The main module for the NL2SQL chat Agent which is multi-turn."""
import os
import json
import vertexai
from google.cloud import bigquery
from vertexai import generative_models
from vertexai.generative_models import (
GenerativeModel,
Part,
Tool,
# ToolConfig
)
import streamlit as st
from bot_functions import (
list_tables_func,
get_table_metadata_func,
sql_query_func,
plot_chart_auto_func
)
# Relaxed safety configuration: only high-probability harmful content is
# blocked, so benign database/SQL conversation is not filtered out.
safety_settings = {
    harm_category: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH
    for harm_category in (
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
}
_MODEL_NAME = "gemini-1.5-pro-001"

# Plain Gemini model (no tools attached) used to generate the short table
# descriptions stored in the metadata cache; near-deterministic output
# (temperature 0.05).
gemini = GenerativeModel(
    _MODEL_NAME,
    generation_config={"temperature": 0.05},
    safety_settings=safety_settings,
)
class Response:
    """Container for one DBAI answer.

    Holds the model's final natural-language text together with the log of
    intermediate tool calls that produced it (for display in the UI).
    """

    def __init__(self, text, interim_steps) -> None:
        # Log of tool calls (list of dicts) recorded while answering.
        self.interim_steps = interim_steps
        # Final natural-language answer from the model.
        self.text = text
class DBAI:
    """
    Multi-turn NL2SQL chat agent over a BigQuery dataset.

    Gemini is given function-calling tools to list tables, fetch table
    metadata, execute SQL, and auto-generate Plotly charts (rendered via
    Streamlit). ``ask`` drives the tool-calling loop until the model
    produces a final text answer.

    Args:
        proj_id: Google Cloud project that owns the BigQuery dataset.
        dataset_id: BigQuery dataset the agent may query.
        tables_list: Tables exposed to the agent; an empty value means
            "all tables in the dataset" (discovered when the metadata
            cache is first built).
    """

    def __init__(
        self,
        proj_id="proj-kous",
        dataset_id="Albertsons",
        tables_list=['camain_oracle_hcm', 'camain_ps']
    ):  # pylint: disable=dangerous-default-value
        self.proj_id = proj_id
        self.dataset_id = dataset_id
        self.tables_list = tables_list
        # Tool set for the interactive agent; the chart tool is only
        # meaningful in the Streamlit UI.
        self.sql_query_tool = Tool(
            function_declarations=[
                list_tables_func,
                get_table_metadata_func,
                sql_query_func,
                plot_chart_auto_func,
            ],
        )
        self.agent = GenerativeModel(
            "gemini-1.5-pro-001",
            generation_config={"temperature": 0.05},
            safety_settings=safety_settings,
            tools=[self.sql_query_tool],
        )
        self.bq_client = bigquery.Client(project=self.proj_id)
        self.system_prompt = """
You are a fluent person who efficiently communicates with the user over
different Database queries. Please always call the functions at your disposal
whenever you need to know something, and do not reply unless you feel you have
all information to answer the question satisfactorily.
Only use information that you learn from BigQuery, do not make up information.
Always use date or time functions instead of hard-coded values in SQL
to reflect true current value.
"""
        self.load_metadata()
        vertexai.init(project=self.proj_id)

    def load_metadata(self):
        """
        Load the per-dataset metadata cache from disk, building and
        persisting it on first use.
        """
        metadata_cache_path = f"./metadata_cache_{self.dataset_id}.json"
        if os.path.exists(metadata_cache_path):
            with open(metadata_cache_path, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)
        else:
            self.metadata = self.create_metadata_cache()
            with open(metadata_cache_path, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f)

    def create_metadata_cache(self):
        """
        Build the metadata dict for every table in ``self.tables_list``
        (or all tables in the dataset when the list is empty), including
        a brief Gemini-generated description per table.

        Returns:
            dict mapping table_id -> {table_name, columns_info,
            table_description}.
        """
        gen_description_prompt = """
Based on the columns information of this table.
Generate a very brief description for this table.
TABLE: {table_id}
columns_info: {columns_info}
"""
        # An empty/blank tables list means "use every table in the dataset".
        if self.tables_list in [[], [''], '']:
            api_response = self.bq_client.list_tables(self.dataset_id)
            self.tables_list = [table.table_id for table in api_response]
        metadata = {}
        for table_id in self.tables_list:
            columns_info = self.bq_client.get_table(
                f'{self.dataset_id}.{table_id}'
            ).to_api_repr()['schema']
            # Drop verbose per-column details (e.g. 'mode') to keep the
            # prompt and the cache small.
            for field in columns_info.get('fields', []):
                field.pop('mode', None)
            prompt = gen_description_prompt.format(
                table_id=table_id,
                columns_info=columns_info
            )
            metadata[table_id] = {
                "table_name": table_id,
                "columns_info": columns_info,
                "table_description": gemini.generate_content(prompt).text,
            }
        return metadata

    def api_list_tables(self):
        """Gemini tool: return the cached metadata for all known tables."""
        try:
            return self.metadata
        except AttributeError:
            # Metadata cache not built yet; fall back to bare table names.
            return self.tables_list

    def api_get_table_metadata(self, table_id):
        """Gemini tool: return cached metadata for one table as a string.

        Accepts either a bare table name or a qualified
        ``dataset_id.table_id`` form.
        """
        try:
            return str(self.metadata[table_id])
        except KeyError:
            # table_id may be qualified as dataset_id.table_id — retry
            # with just the trailing table name.
            return str(self.metadata[table_id.split('.')[-1]])

    def execute_sql_query(self, query):
        """Gemini tool: run ``query`` in BigQuery.

        Returns the result rows (or the error message, so the model can
        retry with corrected SQL) as a string.
        """
        job_config = bigquery.QueryJobConfig(
            default_dataset=f'{self.proj_id}.{self.dataset_id}'
        )
        try:
            # Model-generated SQL often contains escaped newlines and stray
            # backslashes; flatten it before execution.
            cleaned_query = query.replace("\\n", " ").replace("\n", "")
            # BUG FIX: was `cleaned_query.replace.replace("\\", "")`, which
            # raised AttributeError on the bound method and made every
            # query return an error string instead of results.
            cleaned_query = cleaned_query.replace("\\", "")
            query_job = self.bq_client.query(
                cleaned_query,
                job_config=job_config
            )
            rows = query_job.result()
            api_response = str([dict(row) for row in rows])
            api_response = api_response.replace("\\", "").replace("\n", "")
        except Exception as e:  # pylint: disable=broad-except
            # Deliberate best-effort: surface the failure text to the model.
            api_response = f"{str(e)}"
        return api_response

    def format_interim_steps(self, interim_steps):
        """Render the intermediate tool calls as a markdown log for the UI."""
        detailed_log = ""
        for step in interim_steps:
            detailed_log += f'''### Function call:\n
##### Function name:
```
{str(step['function_name'])}
```
\n\n
##### Function parameters:
```
{str(step['function_params'])}
```
\n\n
##### API response:
```
{str(step['API_response'])}
```
\n\n'''
        return detailed_log

    def ask(self, question, chat):
        """Answer ``question`` in the given multi-turn ``chat`` session.

        Drives Gemini's function-calling loop: each requested tool is
        executed locally and its result sent back to the model until it
        replies with plain text.

        Returns:
            Response with the final text and the intermediate tool calls.
        """
        prompt = (
            question
            + f"\n The dataset_id is {self.dataset_id}"
            + self.system_prompt
        )
        response = chat.send_message(prompt)
        response = response.candidates[0].content.parts[0]
        intermediate_steps = []
        while True:
            try:
                # Raises AttributeError when the part is plain text,
                # i.e. the model has finished calling tools.
                function_name = response.function_call.name
                params = dict(response.function_call.args.items())
                api_response = ''
                if function_name == "list_tables":
                    api_response = self.api_list_tables()
                if function_name == "get_table_metadata":
                    api_response = self.api_get_table_metadata(
                        params["table_id"]
                    )
                if function_name == "sql_query":
                    api_response = self.execute_sql_query(params["query"])
                if function_name == "plot_chart_auto":
                    # SECURITY NOTE: exec of model-generated code — only
                    # acceptable in a trusted, sandboxed deployment.
                    # The code string is expected to define `fig`.
                    local_namespace = {}
                    exec(  # pylint: disable=exec-used
                        params['code'].replace('\r\n', '\n'),
                        globals(),
                        local_namespace
                    )
                    st.plotly_chart(local_namespace['fig'])
                    api_response = ("here is the plot of the data shown "
                                    "below in separate tab.")
                response = chat.send_message(
                    Part.from_function_response(
                        name=function_name,
                        response={
                            "content": api_response,
                        },
                    ),
                )
                response = response.candidates[0].content.parts[0]
                intermediate_steps.append({
                    'function_name': function_name,
                    'function_params': params,
                    'API_response': api_response,
                    'response': response
                })
            except AttributeError:
                break
        return Response(text=response.text, interim_steps=intermediate_steps)
class NL2SQLResp:
    """NL2SQL output format class.

    Bundles the three artifacts of a single NL2SQL turn: the model's
    natural-language answer, the SQL it generated, and that SQL's
    stringified execution result.
    """

    def __init__(self, nl_output, generated_sql, sql_output) -> None:
        # Final natural-language answer produced by the model.
        self.nl_output = nl_output
        # SQL string the model generated ('' if no sql_query call was made).
        self.generated_sql = generated_sql
        # Stringified rows (or error message) from executing generated_sql.
        self.sql_output = sql_output

    def __str__(self) -> str:
        # The literal's embedded newlines/indentation are part of the
        # user-visible formatting — do not "clean up" this string.
        return f''' NL_OUTPUT: {self.nl_output}\n
        GENERATED_SQL: {self.generated_sql}\n
        SQL_OUTPUT: {self.sql_output}'''
class DBAI_nl2sql(DBAI):  # pylint: disable=invalid-name
    """
    Single-question NL2SQL variant of ``DBAI``.

    Instead of a multi-turn chat agent, this answers one question per call
    and reports the generated SQL together with its execution result.
    """

    def __init__(
        self,
        proj_id="proj-kous",
        dataset_id="Albertsons",
        tables_list=['camain_oracle_hcm', 'camain_ps']
    ):  # pylint: disable=dangerous-default-value
        super().__init__(proj_id, dataset_id, tables_list)
        # Same tools as the parent minus charting — this agent only needs
        # to inspect schema and run SQL.
        self.nl2sql_tool = Tool(
            function_declarations=[
                list_tables_func,
                get_table_metadata_func,
                sql_query_func,
            ],
        )
        self.agent = GenerativeModel(
            "gemini-1.5-pro-001",
            generation_config={"temperature": 0.05},
            safety_settings=safety_settings,
            tools=[self.nl2sql_tool],
        )

    def get_sql(self, question):
        """
        For a given question, return the generated SQL, its execution
        result, and the model's natural-language answer.

        Returns:
            NL2SQLResp(nl_output, generated_sql, sql_output); the SQL
            fields are '' when the model never called the sql_query tool.
        """
        chat = self.agent.start_chat()
        prompt = (
            question
            + f"\n The dataset_id is {self.dataset_id}"
            + self.system_prompt
        )
        response = chat.send_message(prompt)
        response = response.candidates[0].content.parts[0]
        intermediate_steps = []
        while True:
            try:
                # Raises AttributeError when the part is plain text,
                # i.e. the model has finished calling tools.
                function_name = response.function_call.name
                params = dict(response.function_call.args.items())
                api_response = ''
                if function_name == "list_tables":
                    api_response = self.api_list_tables()
                if function_name == "get_table_metadata":
                    api_response = self.api_get_table_metadata(
                        params["table_id"]
                    )
                if function_name == "sql_query":
                    api_response = self.execute_sql_query(params["query"])
                response = chat.send_message(
                    Part.from_function_response(
                        name=function_name,
                        response={
                            "content": api_response,
                        },
                    ),
                )
                response = response.candidates[0].content.parts[0]
                intermediate_steps.append({
                    'function_name': function_name,
                    'function_params': params,
                    'API_response': api_response,
                    'response': response
                })
            except AttributeError:
                break
        generated_sql, sql_output = '', ''
        # Walk backwards so the *most recent* sql_query call wins.
        # BUG FIX: the original loop had no break, so the last assignment
        # came from the FIRST sql_query call — the oldest query won,
        # contradicting the reversed iteration.
        for step in intermediate_steps[::-1]:
            if step['function_name'] == 'sql_query':
                generated_sql = step['function_params']['query']
                sql_output = step['API_response']
                break
        return NL2SQLResp(response.text, generated_sql, sql_output)