nl2sql_library/nl2sql_lib_executors.py (115 lines of code) (raw):

"""
Wrapper file for invoking the Executors from the app.py file
"""

import json
import sys

import vertexai
from loguru import logger

from nl2sql.executors.linear_executor.core import CoreLinearExecutor
from utils.utility_functions import get_project_config, initialize_db
from nl2sql.llms.vertexai import text_bison_32k
from nl2sql.tasks.table_selection.core import CoreTableSelector
from nl2sql.tasks.table_selection.core import prompts as cts_prompts
from nl2sql.tasks.column_selection.core import (
    CoreColumnSelector,
    prompts as ccs_prompts,
)
from nl2sql.tasks.sql_generation.core import CoreSqlGenerator
from nl2sql.tasks.sql_generation.core import prompts as csg_prompts

# from sample_executors.rag_executor import RAG_Executor

# Fetch the project configuration once and reuse it for every setting below
# instead of calling get_project_config() repeatedly (it presumably re-reads
# the config source on each call -- confirm in utils.utility_functions).
_config = get_project_config()["config"]

vertexai.init(project=_config["proj_name"], location="us-central1")

dataset_name = _config["dataset"]  # e.g. "nl2sql_spider"
bigquery_connection_string = initialize_db(_config["proj_name"], dataset_name)
data_file_name = _config["metadata_file"]
logger.info(
    f"Data = {bigquery_connection_string}, {dataset_name}, {data_file_name}"
)

# Default question used by the executor wrappers when none is supplied.
# Implicit literal concatenation avoids the stray backslash/indentation that a
# backslash continuation inside the string would embed in the question text.
question_to_gen = (
    "What is the average, minimum, and maximum age"
    " of all singers from France ?"
)
class NL2SQL_Executors:
    """
    Class with wrapper functions for all Executors
    """

    def __init__(self):
        # Not read by any method in this class; kept so existing callers that
        # inspect or set it keep working.
        self.executor = ""

    def linear_executor(
        self,
        question=question_to_gen,
        bq_conn_string=bigquery_connection_string,
        data_dict=None,
    ):
        """
        SQL Generation using Linear Executor

        Args:
            question: Natural-language question to translate into SQL.
                Defaults to the module-level ``question_to_gen``.
            bq_conn_string: Connection string registered under the
                module-level ``dataset_name``.
            data_dict: Optional data dictionary forwarded to the executor
                as ``data_dictionary``.

        Returns:
            Tuple ``(result_id, generated_query, df)`` where ``df`` is the
            value returned by the executor's ``fetch_result``.
        """
        executor_linear = CoreLinearExecutor.from_connection_string_map(
            {
                dataset_name: bq_conn_string,
            },
            data_dictionary=data_dict,
        )
        # loguru substitutes positional args into "{}" placeholders; without
        # a placeholder the question would be silently dropped from the log.
        logger.info("Linear Executor Executing for {}", question)
        result = executor_linear(db_name=dataset_name, question=question)
        df = executor_linear.fetch_result(result)
        logger.info(f"Linear executor output: [{result.generated_query}]")
        return result.result_id, result.generated_query, df

    def cot_executor(
        self,
        question=question_to_gen,
        bq_conn_string=bigquery_connection_string,
        data_dict=None,
    ):
        """
        SQL Generation using Chain of Thought Executor

        Builds a CoreLinearExecutor whose table selector, column selector and
        SQL generator all use the ``CURATED_FEW_SHOT_COT_PROMPT`` prompts with
        a ``text_bison_32k`` LLM.

        Args:
            question: Natural-language question to translate into SQL.
                Defaults to the module-level ``question_to_gen``.
            bq_conn_string: Connection string registered under the
                module-level ``dataset_name``.
            data_dict: Optional data dictionary forwarded to the executor
                as ``data_dictionary``.

        Returns:
            Tuple ``(result_id, generated_query, df)`` where ``df`` is the
            value returned by the executor's ``fetch_result``.
        """
        llm = text_bison_32k()
        logger.info("Chain of Thought executor LLM initialised")

        # Disabling logs because these steps generate a LOT of logs.
        # try/finally guarantees the logger is re-enabled even if building
        # one of the components raises.
        logger.disable("nl2sql.datasets.base")
        try:
            core_table_selector = CoreTableSelector(
                llm=llm, prompt=cts_prompts.CURATED_FEW_SHOT_COT_PROMPT
            )
            core_column_selector = CoreColumnSelector(
                llm=llm, prompt=ccs_prompts.CURATED_FEW_SHOT_COT_PROMPT
            )
            core_sql_generator = CoreSqlGenerator(
                llm=llm, prompt=csg_prompts.CURATED_FEW_SHOT_COT_PROMPT
            )
        finally:
            logger.enable("nl2sql.datasets.base")

        executor_cot = CoreLinearExecutor.from_connection_string_map(
            {dataset_name: bq_conn_string},
            core_table_selector=core_table_selector,
            core_column_selector=core_column_selector,
            core_sql_generator=core_sql_generator,
            data_dictionary=data_dict,
        )
        # "{}" placeholder so loguru actually includes the question.
        logger.info("Chain of Thought Executor executing for {}", question)
        result = executor_cot(db_name=dataset_name, question=question)
        df = executor_cot.fetch_result(result)
        logger.info(f"Chain of Thought output: [{result.generated_query}]")
        return result.result_id, result.generated_query, df

    # def rag_executor(self, question=question_to_gen):
    #     """
    #     SQL Generation using RAG Executor
#     """
#     ragexec = RAG_Executor()
#     res_id, sql = ragexec.generate_sql(question)
#     return res_id, sql


if __name__ == "__main__":
    # Select the executor from the first CLI argument; anything missing or
    # unrecognised falls back to the linear executor.
    if len(sys.argv) < 2:
        print("Usage : python nl2sql_lib_executors.py <executor name>")
        print("Types of Executors : linear, cot")
        print("Default is Linear executor")
        executor = "linear"
    elif sys.argv[1] in ("linear", "cot"):
        executor = sys.argv[1]
    else:
        print("Invalid executor type")
        print("Defaulting to Linear executor")
        executor = "linear"

    # The metadata path is relative to the current working directory; the
    # context manager closes the file once the JSON has been parsed.
    with open(f"utils/{data_file_name}", encoding="utf-8") as metadata_file:
        spider_data = json.load(metadata_file)

    data_dictionary_read = {
        # NOTE(review): this key is hard-coded, while the executors register
        # the connection under the config-driven ``dataset_name``; confirm
        # the two are expected to match.
        "nl2sql_spider": {
            "description": (
                "This dataset contains information about the "
                "concerts singers, country they belong to, "
                "stadiums where the concerts happened"
            ),
            "tables": spider_data,
        },
    }

    nle = NL2SQL_Executors()
    run_executor = (
        nle.cot_executor if executor == "cot" else nle.linear_executor
    )
    # Both executor wrappers return a 3-tuple (result_id, sql, dataframe).
    result_id, gen_sql, result_df = run_executor(data_dict=data_dictionary_read)
    print(executor)

    print("Generated SQL = ", gen_sql)