nl2sql_library/app.py (247 lines of code) (raw):
"""
Main file serving the Executors modules exposing the APIs
for Linear Executor, Chain of Thought executor, RAG Executor
Updating user feedback etc
"""
# from nl2sql_lib_executors import NL2SQL_Executors
import sys
import inspect
import json
import os
from flask_cors import CORS
from flask import Flask, request
from dotenv import load_dotenv
from loguru import logger
from utils.utility_functions import initialize_db, config_project
from utils.utility_functions import execute_bq_query, log_update_feedback
from utils.utility_functions import result2nl, get_project_config, log_sql
from nl2sql_query_embeddings import Nl2Sql_embed
load_dotenv()
currentdir = os.path.dirname(
os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
app = Flask(__name__)
cors = CORS(app)
app.config["CORS_HEADERS"] = "Content-Type"
dataset_name = get_project_config()["config"]["dataset"] # "zoominfo"
# bigquery_connection_string = "bigquery://sl-test-project-363109/zoominfo"
bigquery_connection_string = initialize_db(
get_project_config()["config"]["proj_name"],
get_project_config()["config"]["dataset"],
)
data_file_name = get_project_config()["config"]["metadata_file"]
f = open(f"utils/{data_file_name}", encoding="utf-8")
spider_data = json.load(f)
data_dictionary_read = {
"nl2sql_spider": {
"description": "This dataset contains information about the concerts\
singers, country they belong to, stadiums where the \
concerts happened",
"tables": spider_data,
},
}
print("curr path = ", os.getcwd())
@app.route("/")
def spec():
"""
Default API Route
"""
return json.dumps({"response": "Multi mode NL2SQL Generation library"})
@app.route("/api/executor/linear", methods=["POST"])
def linear_executor():
"""
Invokes the Linear Executor
"""
question = request.json["question"]
execute_sql = request.json["execute_sql"]
logger.info(f"Linear Execution engine for question : [{question}]")
from nl2sql_lib_executors import NL2SQL_Executors
try:
nle = NL2SQL_Executors()
res_id, sql, df = nle.linear_executor(
question=question, data_dict=data_dictionary_read
)
sql_result = ""
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"error_msg": "",
}
log_sql(res_id, question, sql, "Linear Executor", execute_sql)
if execute_sql:
try:
result = execute_bq_query(sql)
sql_result = result2nl(question, result)
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"df": df.to_json(),
"error_msg": "",
}
except RuntimeError:
print("internal try catch")
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"df": df.to_json(),
"error_msg": "",
}
except RuntimeError:
logger.debug(f"Linear SQL Generation uncussessful : [{question}]")
response_string = {
"result_id": 0,
"generated_query": "",
"sql_result": "",
"error_msg": "Error encountered in Linear executor",
}
return response_string
@app.route("/api/executor/cot", methods=["POST"])
def cot_executor():
"""
Invokes the Chain of Thought executor
"""
question = request.json["question"]
execute_sql = request.json["execute_sql"]
logger.info("CoT SQL Generation engine for question : [{question}]")
from nl2sql_lib_executors import NL2SQL_Executors
try:
logger.info("CoT initialising the class")
nle = NL2SQL_Executors()
res_id, sql, df = nle.cot_executor(
question=question, data_dict=data_dictionary_read
)
sql_result = ""
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"error_msg": "",
}
sql2 = "\t".join([line.strip() for line in sql])
log_sql(res_id, question, str(sql2), "CoT Executor", execute_sql)
if execute_sql:
try:
result = execute_bq_query(sql)
sql_result = result2nl(question, result)
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"df": df.to_json(),
"error_msg": "",
}
except RuntimeError:
print("internal try catch")
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"df": df.to_json(),
"error_msg": "",
}
except RuntimeError:
logger.debug(f"CoT SQL generation unsuccessful : [{question}]")
response_string = {
"result_id": 0,
"generated_query": "",
"sql_result": "",
"error_msg": "Error encountered in CoT executor",
}
return response_string
@app.route("/api/executor/rag", methods=["POST"])
def rag_executor():
"""
Invokes the RAG Executor
"""
question = request.json["question"]
execute_sql = request.json["execute_sql"]
logger.info("RAG SQL Generation engine for question : [{question}]")
from nl2sql_lib_executors import NL2SQL_Executors
try:
nle = NL2SQL_Executors()
res_id, sql = nle.rag_executor(question=question)
# res_id, sql = nle.generate_query(question)
sql_result = ""
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"error_msg": "",
}
log_sql(res_id, question, sql, "Rag Executor", execute_sql)
if execute_sql:
try:
result = execute_bq_query(sql)
sql_result = result2nl(question, result)
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"error_msg": "",
}
except RuntimeError:
print("internal try catch")
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": sql_result,
"error_msg": "",
}
except RuntimeError:
logger.debug(f"RAG SQL generation unsuccessful : [{question}]")
response_string = {
"result_id": 0,
"generated_query": "",
"sql_result": "",
"error_msg": "Error encountered in RAG executor",
}
return json.dumps(response_string)
@app.route("/projconfig", methods=["POST"])
def project_config():
"""
Updates the Project Configuration details
"""
logger.info("Updating project configuration")
project = request.json["proj_name"]
dataset = request.json["bq_dataset"]
metadata_file = request.json["metadata_file"]
logger.info(f"Received info - {project}, {dataset}, {metadata_file}")
config_project(project, dataset, metadata_file)
return json.dumps({"status": "success"})
@app.route("/uploadfile", methods=["POST"])
def upload_file():
"""
Saves the data dictionary / metadata cache data
received over HTTP request into a file
"""
logger.info("File received")
try:
file = request.files["file"]
data = file.read()
my_json = data.decode("utf8")
data2 = json.loads(my_json)
data_to_save = json.dumps(data2, indent=4)
target_file = get_project_config()["config"]["metadata_file"]
logger.info(f"Saving file as : {target_file}")
with open(f"utils/{target_file}", "w", encoding="utf-8") as outfile:
outfile.write(data_to_save)
logger.info(f"List of files - {os.listdir('utils')}")
return json.dumps({"status": "Successfully uploaded file"})
except RuntimeError:
return json.dumps({"status": "Failed to upload file"})
@app.route("/userfb", methods=["POST"])
def user_feedback():
"""
Updates the User feedback sent from UI
"""
logger.info("Updating user feedback")
result_id = request.json["result_id"]
feedback = request.json["user_feedback"]
try:
log_update_feedback(result_id, feedback)
return json.dumps({"response": "successfully updated user feedback"})
except RuntimeError:
return json.dumps({"response": "failed to update user feedback"})
@app.route("/execsql", methods=["POST"])
def execute_sql_query():
"""
Executes the query on BQ
"""
sql = request.json["sql"]
result = execute_bq_query(sql)
print("result = ", result)
sql_result = result.to_dict() # orient="records")
res_id = ""
result_text = result2nl("", sql_result)
response_string = {
"result_id": res_id,
"generated_query": sql,
"sql_result": result_text,
"error_msg": "",
}
return json.dumps(response_string)
@app.route('/api/record/create', methods=['POST'])
def create_record():
"""
Insert record with Question and MappedSQL in the Table or Local file
"""
question = request.json['question']
mappedsql = request.json['sql']
logger.info(f"Inserting data. Input : {question} and {mappedsql}")
try:
# pge = PgSqlEmb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD)
# pge.insert_row(question, mappedsql)
embed = Nl2Sql_embed()
embed.insert_data(question=question, sql=mappedsql)
return json.dumps({"response": "Successfully inserted record"})
except RuntimeError:
return json.dumps({"response": "Unable to insert record"})
if __name__ == "__main__":
app.run(debug=True, port=5000)