in dbai_src/dbai.py [0:0]
def get_sql(self, question):
    """
    For the given question, return the generated SQL, its result and the
    model's description.

    Starts a fresh chat with ``self.agent`` and lets the model drive a
    function-calling loop. Whenever the model requests one of the tools
    ``list_tables``, ``get_table_metadata`` or ``sql_query``, the matching
    method on ``self`` is invoked and its result is sent back to the model
    as a function response. The loop ends as soon as the model replies with
    a part that carries no function call.

    Args:
        question: The user's natural-language question.

    Returns:
        NL2SQLResp built from the model's final text answer, the most
        recently executed SQL query (``''`` if the model never ran one) and
        that query's output (``''`` if none).

    Raises:
        KeyError: if the model calls ``get_table_metadata`` without
            ``table_id`` or ``sql_query`` without ``query``.
        Any exception raised by the tool methods propagates unchanged.
    """
    chat = self.agent.start_chat()
    # NOTE(review): no separator is inserted between the dataset_id line and
    # system_prompt; presumably system_prompt starts with whitespace — confirm.
    prompt = question + \
        f"\n The dataset_id is {self.dataset_id}" + \
        self.system_prompt
    response = chat.send_message(prompt)
    response = response.candidates[0].content.parts[0]
    intermediate_steps = []
    # NOTE(review): this loop has no upper bound; a model that keeps
    # requesting tools keeps it running. Consider capping the rounds.
    while True:
        # Detect a function call explicitly rather than waiting for an
        # AttributeError: a blanket try/except would also hide
        # AttributeErrors raised inside the tool methods themselves.
        # getattr covers both SDK behaviours for a text-only part (missing
        # attribute, or an empty FunctionCall whose name is "").
        function_call = getattr(response, "function_call", None)
        function_name = getattr(function_call, "name", None)
        if not function_name:
            break
        # Argument-less tools (e.g. list_tables) may carry empty/absent args.
        params = {
            key: value
            for key, value in (function_call.args or {}).items()
        }
        if function_name == "list_tables":
            api_response = self.api_list_tables()
        elif function_name == "get_table_metadata":
            api_response = self.api_get_table_metadata(
                params["table_id"]
            )
        elif function_name == "sql_query":
            api_response = self.execute_sql_query(params["query"])
        else:
            # Unknown tool name: reply with an empty result.
            api_response = ''
        response = chat.send_message(
            Part.from_function_response(
                name=function_name,
                response={
                    "content": api_response,
                },
            ),
        )
        response = response.candidates[0].content.parts[0]
        intermediate_steps.append({
            'function_name': function_name,
            'function_params': params,
            'API_response': api_response,
            'response': response
        })
    # Walk the steps newest-first and stop at the first SQL call, so the
    # reported query/output are from the last query the model executed.
    generated_sql, sql_output = '', ''
    for step in reversed(intermediate_steps):
        if step['function_name'] == 'sql_query':
            generated_sql = step['function_params']['query']
            sql_output = step['API_response']
            break
    return NL2SQLResp(response.text, generated_sql, sql_output)