backend-apis/main.py (343 lines of code) (raw):

# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flask HTTP front-end for the ``opendataqna`` pipeline.

Every route reads a JSON body (POST) or nothing (GET), delegates to a function
imported from ``opendataqna`` and returns a JSON envelope whose
``ResponseCode`` field is 200/201 on success and 500 on failure. Note that the
HTTP status of the response itself is always 200 (``jsonify`` default).
"""

from flask import Flask, request, jsonify, render_template, Response
import asyncio
from collections.abc import Callable
import logging as log
import json
import datetime
import urllib
import re
import time
import textwrap
import pandas as pd
from flask_cors import CORS
import os
import sys
import firebase_admin
from firebase_admin import credentials, auth
from functools import wraps

firebase_admin.initialize_app()

# Make the working directory importable *before* importing the local
# ``opendataqna`` module (appending it after the import had no effect on it).
module_path = os.path.abspath(os.path.join('.'))
sys.path.append(module_path)

from opendataqna import get_all_databases, get_kgq, generate_sql, embed_sql, get_response, get_results, visualize


def jwt_authenticated(func: Callable[..., int]) -> Callable[..., int]:
    """Require a Firebase ID token in the ``Authorization`` header.

    The token is taken as the second space-separated part of the header
    (e.g. ``Bearer <token>``) and checked with
    ``firebase_admin.auth.verify_id_token``. On success the token's ``uid``
    claim is stored on ``request.uid`` and the wrapped view is called
    (awaited when it is a coroutine function).

    Returns:
        ``Response(401)`` when the header is missing or has no token part,
        ``Response(403)`` when token verification raises, otherwise the
        wrapped view's result.
    """
    @wraps(func)
    async def decorated_function(*args, **kwargs):
        header = request.headers.get("Authorization", None)
        if not header:
            return Response(status=401)
        parts = header.split(" ")
        # A header without a space used to raise IndexError (HTTP 500);
        # treat it as missing credentials instead.
        if len(parts) < 2:
            return Response(status=401)
        token = parts[1]
        try:
            # Do not print/log the token itself: it is a bearer credential.
            decoded_token = firebase_admin.auth.verify_id_token(token)
        except Exception as e:
            log.exception(e)
            return Response(status=403, response=f"Error with authentication: {e}")
        request.uid = decoded_token["uid"]
        print("USER:: " + str(request.uid))
        if asyncio.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        return func(*args, **kwargs)

    return decorated_function


# Pipeline settings forwarded positionally to ``opendataqna.generate_sql``;
# see that function for their exact meaning.
RUN_DEBUGGER = True
DEBUGGING_ROUNDS = 2
LLM_VALIDATION = False
EXECUTE_FINAL_SQL = True
Embedder_model = 'vertex'
SQLBuilder_model = 'gemini-1.5-pro'
SQLChecker_model = 'gemini-1.5-pro'
SQLDebugger_model = 'gemini-1.5-pro'
num_table_matches = 5
num_column_matches = 10
table_similarity_threshold = 0.3
column_similarity_threshold = 0.3
example_similarity_threshold = 0.3
num_sql_matches = 3

app = Flask(__name__)
cors = CORS(app, resources={r"/*": {"origins": "*"}})


def _parse_envelope(log_payload=False):
    """Decode the raw request body as UTF-8 and parse it as JSON.

    Args:
        log_payload: when true, print the raw body before parsing.

    Returns:
        The parsed JSON value (routes expect a dict and call ``.get`` on it).
    """
    raw = request.data.decode('utf-8')
    if log_payload:
        print("Here is the request payload " + raw)
    return json.loads(raw)


@app.route("/available_databases", methods=["GET"])
# @jwt_authenticated
def getBDList():
    """Return the databases known to ``get_all_databases`` under ``KnownDB``."""
    result, invalid_response = get_all_databases()
    if not invalid_response:
        responseDict = {
            "ResponseCode": 200,
            "KnownDB": result,
            "Error": "",
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "KnownDB": "",
            "Error": result,
        }
    return jsonify(responseDict)


@app.route("/embed_sql", methods=["POST"])
# @jwt_authenticated
async def embedSql():
    """Submit a (question, SQL) pair to ``embed_sql`` as a new example.

    Request JSON keys: ``user_grouping``, ``generated_sql``,
    ``user_question``, ``session_id``.
    """
    envelope = _parse_envelope()
    user_grouping = envelope.get('user_grouping')
    generated_sql = envelope.get('generated_sql')
    user_question = envelope.get('user_question')
    session_id = envelope.get('session_id')

    embedded, invalid_response = await embed_sql(session_id, user_grouping, user_question, generated_sql)
    if not invalid_response:
        responseDict = {
            "ResponseCode": 201,
            "Message": "Example SQL has been accepted for embedding",
            "SessionID": session_id,
            "Error": "",
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "KnownDB": "",
            "SessionID": session_id,
            "Error": embedded,
        }
    return jsonify(responseDict)


@app.route("/run_query", methods=["POST"])
# @jwt_authenticated
def getSQLResult():
    """Run ``generated_sql`` and add a natural-language answer for it.

    Request JSON keys: ``user_question``, ``user_grouping``,
    ``generated_sql``, ``session_id``.
    """
    envelope = _parse_envelope()
    user_question = envelope.get('user_question')
    user_grouping = envelope.get('user_grouping')
    generated_sql = envelope.get('generated_sql')
    session_id = envelope.get('session_id')

    result_df, invalid_response = get_results(user_grouping, generated_sql)
    if invalid_response:
        # On failure the first value returned by get_results is the error.
        responseDict = {
            "ResponseCode": 500,
            "KnownDB": "",
            "NaturalResponse": result_df,
            "SessionID": session_id,
            "Error": result_df,
        }
        return jsonify(responseDict)

    result_json = result_df.to_json(orient='records')
    _resp, invalid_response = get_response(session_id, user_question, result_json)
    responseDict = {
        "ResponseCode": 500 if invalid_response else 200,
        "KnownDB": result_json,
        "NaturalResponse": _resp,
        "SessionID": session_id,
        # Report the summarisation error instead of an empty string.
        "Error": _resp if invalid_response else "",
    }
    return jsonify(responseDict)


@app.route("/get_known_sql", methods=["POST"])
# @jwt_authenticated
def getKnownSQL():
    """Return the known example SQLs from ``get_kgq`` for ``user_grouping``."""
    print("Extracting the known SQLs from the example embeddings.")
    envelope = _parse_envelope()
    user_grouping = envelope.get('user_grouping')

    result, invalid_response = get_kgq(user_grouping)
    if not invalid_response:
        responseDict = {
            "ResponseCode": 200,
            "KnownSQL": result,
            "Error": "",
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "KnownSQL": "",
            "Error": result,
        }
    return jsonify(responseDict)


@app.route("/generate_sql", methods=["POST"])
# @jwt_authenticated
async def generateSQL():
    """Generate SQL for ``user_question`` via ``generate_sql``.

    Request JSON keys: ``user_question``, ``user_grouping``, ``session_id``,
    ``user_id``. The session id returned by ``generate_sql`` is echoed back.
    """
    envelope = _parse_envelope(log_payload=True)
    user_question = envelope.get('user_question')
    user_grouping = envelope.get('user_grouping')
    session_id = envelope.get('session_id')
    user_id = envelope.get('user_id')

    generated_sql, session_id, invalid_response = await generate_sql(
        session_id, user_question, user_grouping,
        RUN_DEBUGGER, DEBUGGING_ROUNDS, LLM_VALIDATION,
        Embedder_model, SQLBuilder_model, SQLChecker_model, SQLDebugger_model,
        num_table_matches, num_column_matches,
        table_similarity_threshold, column_similarity_threshold,
        example_similarity_threshold, num_sql_matches,
        user_id=user_id,
    )
    if not invalid_response:
        responseDict = {
            "ResponseCode": 200,
            "GeneratedSQL": generated_sql,
            "SessionID": session_id,
            "Error": "",
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "GeneratedSQL": "",
            "SessionID": session_id,
            "Error": generated_sql,
        }
    return jsonify(responseDict)


@app.route("/generate_viz", methods=["POST"])
# @jwt_authenticated
async def generateViz():
    """Generate chart code for a query result via ``visualize``.

    Request JSON keys: ``user_question``, ``generated_sql``, ``sql_results``,
    ``session_id``.
    """
    envelope = _parse_envelope()
    user_question = envelope.get('user_question')
    generated_sql = envelope.get('generated_sql')
    sql_results = envelope.get('sql_results')
    session_id = envelope.get('session_id')

    try:
        chart_js, invalid_response = visualize(session_id, user_question, generated_sql, sql_results)
    except Exception as e:
        # The client is told to "check the logs", so actually log it.
        log.exception(e)
        responseDict = {
            "ResponseCode": 500,
            "GeneratedSQL": "",
            "SessionID": session_id,
            "Error": "Issue was encountered while generating the Google Chart, please check the logs!" + str(e),
        }
        return jsonify(responseDict)

    if not invalid_response:
        responseDict = {
            "ResponseCode": 200,
            "GeneratedChartjs": chart_js,
            "Error": "",
            "SessionID": session_id,
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "GeneratedSQL": "",
            "SessionID": session_id,
            "Error": chart_js,
        }
    return jsonify(responseDict)


@app.route("/summarize_results", methods=["POST"])
# @jwt_authenticated
async def getSummary():
    """Summarise ``sql_results`` for ``user_question`` via ``get_response``.

    Request JSON keys: ``user_question``, ``sql_results`` and (optional)
    ``session_id``.
    """
    envelope = _parse_envelope()
    user_question = envelope.get('user_question')
    sql_results = envelope.get('sql_results')
    session_id = envelope.get('session_id')

    # get_response takes (session_id, user_question, results), as in
    # /run_query; the previous two-argument call shifted every argument.
    result, invalid_response = get_response(session_id, user_question, sql_results)
    if not invalid_response:
        responseDict = {
            "ResponseCode": 200,
            "summary_response": result,
            "Error": "",
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "summary_response": "",
            "Error": result,
        }
    return jsonify(responseDict)


@app.route("/natural_response", methods=["POST"])
# @jwt_authenticated
async def getNaturalResponse():
    """Generate SQL, run it and summarise the result in one call.

    Request JSON keys: ``user_question``, ``user_grouping`` and (optional)
    ``session_id``.
    """
    envelope = _parse_envelope()
    user_question = envelope.get('user_question')
    user_grouping = envelope.get('user_grouping')
    session_id = envelope.get('session_id')

    # Same positional order as /generate_sql: session_id comes first.
    # NOTE(review): presumably generate_sql creates a session when None is
    # passed (as /generate_sql also allows) — confirm.
    generated_sql, session_id, invalid_response = await generate_sql(
        session_id, user_question, user_grouping,
        RUN_DEBUGGER, DEBUGGING_ROUNDS, LLM_VALIDATION,
        Embedder_model, SQLBuilder_model, SQLChecker_model, SQLDebugger_model,
        num_table_matches, num_column_matches,
        table_similarity_threshold, column_similarity_threshold,
        example_similarity_threshold, num_sql_matches,
    )
    if invalid_response:
        responseDict = {
            "ResponseCode": 500,
            "GeneratedSQL": "",
            "Error": generated_sql,
        }
        return jsonify(responseDict)

    result_df, invalid_response = get_results(user_grouping, generated_sql)
    if invalid_response:
        responseDict = {
            "ResponseCode": 500,
            "KnownDB": "",
            "Error": result_df,
        }
        return jsonify(responseDict)

    result, invalid_response = get_response(session_id, user_question, result_df.to_json(orient='records'))
    if not invalid_response:
        responseDict = {
            "ResponseCode": 200,
            "summary_response": result,
            "Error": "",
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "summary_response": "",
            "Error": result,
        }
    return jsonify(responseDict)


@app.route("/get_results", methods=["POST"])
async def getResultsResponse():
    """Generate SQL for ``user_question`` and return the executed rows.

    Request JSON keys: ``user_question``, ``user_database`` and (optional)
    ``session_id``.
    """
    envelope = _parse_envelope()
    user_question = envelope.get('user_question')
    user_database = envelope.get('user_database')
    session_id = envelope.get('session_id')

    # generate_sql takes session_id first and returns three values (see
    # /generate_sql); the old two-value unpacking raised ValueError.
    generated_sql, session_id, invalid_response = await generate_sql(
        session_id, user_question, user_database,
        RUN_DEBUGGER, DEBUGGING_ROUNDS, LLM_VALIDATION,
        Embedder_model, SQLBuilder_model, SQLChecker_model, SQLDebugger_model,
        num_table_matches, num_column_matches,
        table_similarity_threshold, column_similarity_threshold,
        example_similarity_threshold, num_sql_matches,
    )
    if invalid_response:
        responseDict = {
            "ResponseCode": 500,
            "GeneratedResults": "",
            "Error": generated_sql,
        }
        return jsonify(responseDict)

    result_df, invalid_response = get_results(user_database, generated_sql)
    if not invalid_response:
        responseDict = {
            "ResponseCode": 200,
            "GeneratedResults": result_df.to_json(orient='records'),
            "Error": "",
        }
    else:
        responseDict = {
            "ResponseCode": 500,
            "GeneratedResults": "",
            "Error": result_df,
        }
    return jsonify(responseDict)


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser;
    # with host 0.0.0.0 it must be opt-in (set FLASK_DEBUG=1), not the default.
    debug_mode = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(debug=debug_mode, host="0.0.0.0", port=int(os.environ.get("PORT", 8080)))