in nl2sql_src/nl2sql_generic.py [0:0]
def generate_sql_few_shot(self,
                          question,
                          table_name=None,
                          logger_file="log.txt"):
    """
    Main function which converts NL to SQL using few shot prompting.

    Args:
        question: Natural-language question to translate into SQL.
        table_name: Key into ``self.metadata_json`` naming the table to
            query. When falsy, the table is chosen automatically: via
            ``self.table_filter(question)`` if the metadata holds more
            than one table, otherwise the only table present is used.
        logger_file: Path of a text file to which the model name,
            question, prompt and generated query are appended.

    Returns:
        The generated SQL query string, after Markdown code fences have
        been removed, ``self.add_dataset_to_query`` has been applied and
        a leading ``Response:`` label (if any) has been stripped.

    Raises:
        Exception: On any failure. The message is the formatted
            traceback and the original error is chained as
            ``__cause__``.
    """
    try:
        # step-1 table selection
        if not table_name:
            if not self.metadata_json:
                raise ValueError("metadata_json contains no tables")
            if len(self.metadata_json) > 1:
                table_list = self.table_filter(question)
                if not table_list:
                    raise ValueError(
                        "table_filter returned no table for the question")
                table_name = table_list[0]
            else:
                table_name = next(iter(self.metadata_json))

        table_json = self.metadata_json[table_name]

        # One line per column: "<name> (<type>) : <description>.<examples>"
        columns_info = "".join(
            f"{column['Name']} ({column['Type']}) : "
            f"{column['Description']}.{column['Examples']}\n"
            for column in table_json["Columns"].values()
        )

        # few_shot_json = self.pge.search_matching_queries(question)
        embed = Nl2Sql_embed()
        few_shot_json = embed.search_matching_queries(question)
        logger.info("Few shot examples : %s", few_shot_json)
        few_shot_examples = "".join(
            f"Question: {item['question']}\nSQL : {item['sql']} \n\n"
            for item in few_shot_json
        )

        sql_prompt = Sql_Generation_prompt_few_shot.format(
            table_name=table_json["Name"],
            table_description=table_json["Description"],
            columns_info=columns_info,
            few_shot_examples=few_shot_examples,
            question=question)
        response = self.llm.invoke(sql_prompt)

        # Remove only the Markdown fence markers. Replacing every "sql"
        # substring would corrupt identifiers such as "mysql_logs".
        sql_query = response
        for fence in ("```sql", "```SQL", "```"):
            sql_query = sql_query.replace(fence, "")

        # sql_query = self.case_handler_transform(sql_query)
        sql_query = self.add_dataset_to_query(sql_query)

        with open(logger_file, 'a', encoding="utf-8") as f:
            f.write(
                f">>\nModel:{self.model_name} \n\nQuestion: {question}"
                f"\n\nPrompt:{sql_prompt} \n\nSql_query:{sql_query}<\n"
            )

        if sql_query.strip().startswith("Response:"):
            # Split on the first colon only, so colons inside the query
            # itself (e.g. time literals) are preserved.
            sql_query = sql_query.split(":", 1)[1].strip()
        return sql_query
    except Exception as exc:
        # format_exc() returns the traceback text; print_exc() returns
        # None, which would leave the raised exception without a message.
        raise Exception(traceback.format_exc()) from exc