nl2sql_src/app.py (139 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Main file serving the Executors modules exposing the APIs for Linear Executor, Chain of Thought executor, RAG Executor Updating user feedback etc """ # from nl2sql_lib_executors import NL2SQL_Executors import json import os from flask_cors import CORS from flask import Flask, request from dotenv import load_dotenv from loguru import logger import uuid from nl2sql_generic import Nl2sqlBq from utils.utility_functions import config_project, get_project_config from utils.utility_functions import log_sql, log_update_feedback from nl2sql_query_embeddings import Nl2Sql_embed PROJECT_ID = 'sl-test-project-363109' LOCATION = 'us-central1' DATASET_ID = 'nl2sql_spider' load_dotenv() app = Flask(__name__) cors = CORS(app) app.config["CORS_HEADERS"] = "Content-Type" dataset_name = 'nl2sql_spider' @app.route("/") def spec(): """ Default API Route """ logger.info("Welcome to NL2SQL Lite") return json.dumps({"response": "NL2SQL Lite Generation library"}) @app.route("/api/lite/generate", methods=["POST"]) def nl2sql_lite_generate(): """ Invokes the NL2SQL Lite SQL Generator """ question = request.json["question"] execute_sql = request.json["execute_sql"] few_shot = request.json["few_shot"] logger.info(f"NL2SQL Lite engine for question : [{question}]") try: logger.info("Reading the configuration file") curdir = os.getcwd() proj_conf = get_project_config()["config"] print(proj_conf) data_file_name = proj_conf["metadata_file"] logger.info(f"Using the metadata file : {data_file_name}") metadata_json_path = f"{curdir}/utils/{data_file_name}" logger.info(f"path {metadata_json_path}") nl2sqlbq_client_base = Nl2sqlBq(project_id=proj_conf["proj_name"], dataset_id=proj_conf["dataset"], metadata_json_path=metadata_json_path, model_name="text-bison@002", tuned_model=False) if few_shot: logger.info("NL2SQL Studio Lite - Few shot SQL generation") sql = nl2sqlbq_client_base.generate_sql_few_shot(question) else: logger.info("NL2SQL Studio Lite - SQL generation") sql = nl2sqlbq_client_base.generate_sql(question) logger.info(f"NL2SQL Studio Lite generated SQL = {sql}") sql_result = "" res_id = str(uuid.uuid4()) # "lite" print(res_id) response_string = { "result_id": res_id, "generated_query": sql, "sql_result": sql_result, "error_msg": "", } log_sql(res_id, question, sql, "Lite", False) if execute_sql: try: results = nl2sqlbq_client_base.execute_query(sql) sql_result = nl2sqlbq_client_base.result2nl(result=results, question=question) response_string = { "result_id": res_id, "generated_query": sql, "sql_result": sql_result, "error_msg": "", } except Exception: logger.error("Error executing the query on BigQuery") response_string = { "result_id": res_id, "generated_query": sql, "sql_result": sql_result, "error_msg": "Error - NL2SQL Studio Lite Query Generation", } except Exception as e: logger.error( f"NL2SQL Lite SQL Generation unsuccessful: [{question}] {e}" ) response_string = { "result_id": 0, "generated_query": "", "sql_result": "", "error_msg": "Error in NL2SQL Studio Lite Query generation", } return json.dumps(response_string) @app.route("/projconfig", methods=["POST"]) def project_config(): """ Updates the Project Configuration details """ logger.info("Updating project configuration") project = request.json["proj_name"] dataset = request.json["bq_dataset"] metadata_file = request.json["metadata_file"] logger.info(f"Received info - {project}, {dataset}, {metadata_file}") config_project(project, dataset, metadata_file) return json.dumps({"status": "success"}) @app.route("/uploadfile", methods=["POST"]) def upload_file(): """ Saves the data dictionary / metadata cache data received over HTTP request into a file """ logger.info("File received") try: file = request.files["file"] data = file.read() my_json = data.decode("utf8") data2 = json.loads(my_json) data_to_save = json.dumps(data2, indent=4) target_file = get_project_config()["config"]["metadata_file"] logger.info(f"Saving file as : {target_file}") with open(f"utils/{target_file}", "w", encoding="utf-8") as outfile: outfile.write(data_to_save) logger.info(f"List of files - {os.listdir('utils')}") return json.dumps({"status": "Successfully uploaded file"}) except RuntimeError: return json.dumps({"status": "Failed to upload file"}) @app.route("/userfb", methods=["POST"]) def user_feedback(): """ Updates the User feedback sent from UI """ logger.info("Updating user feedback") result_id = request.json["result_id"] feedback = request.json["user_feedback"] try: log_update_feedback(result_id, feedback) return json.dumps({"response": "successfully updated user feedback"}) except RuntimeError: return json.dumps({"response": "failed to update user feedback"}) @app.route('/api/record/create', methods=['POST']) def create_record(): """ Insert record with Question and MappedSQL in the Table or Local file """ question = request.json['question'] mappedsql = request.json['sql'] logger.info(f"Inserting data. Input : {question} and {mappedsql}") try: # pge = PgSqlEmb(PGPROJ, PGLOCATION, PGINSTANCE, PGDB, PGUSER, PGPWD) # pge.insert_row(question, mappedsql) embed = Nl2Sql_embed() embed.insert_data(question=question, sql=mappedsql) return json.dumps({"response": "Successfully inserted record"}) except RuntimeError: return json.dumps({"response": "Unable to insert record"}) if __name__ == "__main__": app.run(debug=True, port=5000)