in nl2sql_src/nl2sql_generic.py [0:0]
def generate_sql(self, question, table_name=None, logger_file="log.txt"):
    """
    Main function which converts NL to SQL.

    Builds a prompt from the selected table's metadata, asks ``self.llm``
    for a query, cleans the model output and qualifies it with
    ``self.add_dataset_to_query``.

    Args:
        question: Natural-language question to translate into SQL.
        table_name: Key into ``self.metadata_json`` of the table to query.
            When falsy, the table is chosen automatically: the first
            candidate from ``self.table_filter(question)`` if more than one
            table is known, otherwise the single known table.
        logger_file: Path of a text file to which a record of the model
            name, question, prompt and final query is appended.

    Returns:
        The generated SQL string, with markdown code fences and a leading
        ``Response:`` label removed, after ``self.add_dataset_to_query``.

    Raises:
        Exception: wraps (and is chained from) any error raised while
            selecting the table, building the prompt, calling the model or
            writing the log.
    """
    try:
        # step-1 table selection
        if not table_name:
            if len(self.metadata_json) > 1:
                table_list = self.table_filter(question)
                if not table_list:
                    raise ValueError(
                        f"No relevant table found for question: {question!r}"
                    )
                table_name = table_list[0]
            else:
                if not self.metadata_json:
                    raise ValueError("metadata_json contains no tables")
                table_name = next(iter(self.metadata_json))
        table_json = self.metadata_json[table_name]

        # step-2 prompt construction: one line per column, built with join
        # (not repeated +=) and without leaking source indentation into the
        # text sent to the model.
        columns_info = "".join(
            f"{column['Name']} ({column['Type']}) : "
            f"{column['Description']}. {column['Examples']}\n"
            for column in table_json["Columns"].values()
        )
        sql_prompt = Sql_Generation_prompt.format(
            table_name=table_json["Name"],
            table_description=table_json["Description"],
            columns_info=columns_info,
            question=question,
        )

        # step-3 generation and cleanup of the raw model output
        response = self.llm.invoke(sql_prompt)
        # Remove only markdown fences; a blanket replace('sql', '') would
        # also corrupt identifiers such as "mysql_id" or "nl2sql_table".
        sql_query = response.replace("```sql", "").replace("```", "").strip()
        if sql_query.startswith("Response:"):
            # Drop only the label: the query itself may contain colons
            # (time literals, "::" casts), which split(":")[1] would cut off.
            sql_query = sql_query[len("Response:"):].strip()
        # sql_query = self.case_handler_transform(sql_query)
        sql_query = self.add_dataset_to_query(sql_query)

        with open(logger_file, "a", encoding="utf-8") as log:
            log.write(
                f">>\nModel:{self.model_name} \n\nQuestion: {question}"
                f"\n\nPrompt:{sql_prompt} \nSql_query:{sql_query}<<\n"
            )
        return sql_query
    except Exception as exc:
        # traceback.print_exc() returns None, so the previous wrapper raised
        # an Exception with no message. Keep the exception type callers
        # catch, but give it a real message; the original error (with its
        # traceback) remains available as __cause__.
        raise Exception(
            f"SQL generation failed for question {question!r}: {exc}"
        ) from exc