in nl2sql_src/nl2sql_generic.py [0:0]
def gen_and_exec_and_self_correct_sql(self,
                                      prompt,
                                      genai_model_name="GeminiPro",
                                      max_tries=5,
                                      return_all=False):
    """
    Wrapper function for Standard, Multi-turn and Self Correct
    approach of SQL generation

    A single chat session is opened and ``prompt`` is sent to it. Each
    response is stripped of its first and last line, passed through
    ``self.case_handler_transform`` and ``self.add_dataset_to_query`` and
    then executed with the module-level ``client``. When any of these
    steps raises, the error text and the last generated query are fed
    back to the model in a follow-up prompt.

    The loop always performs ``max_tries`` iterations; it does not stop
    at the first successful query. Once two or more queries have
    succeeded, the next prompt asks the model for a different,
    latency-optimised variant of the most recent successful query.

    Args:
        prompt: Initial prompt sent to the chat session.
        genai_model_name: ``"GeminiPro"`` selects the ``gemini-1.0-pro``
            chat; any other value selects ``codechat-bison-32k``.
        max_tries: Number of generate-and-execute iterations.
        return_all: Currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        ``{"dataframe": df}`` where ``df`` has the columns ``Query`` and
        ``Result`` (one row per successfully executed query) when at
        least one query succeeded. Otherwise a dict with the keys
        ``"error"``, ``"prompts"`` and ``"errors"``.
    """
    temperature = 0.3
    max_output_tokens = 8192
    use_gemini = genai_model_name == "GeminiPro"

    # Only build the backend that is actually going to be used.
    if use_gemini:
        chat_session = GenerativeModel("gemini-1.0-pro").start_chat()
    else:
        code_gen_model = CodeChatModel.from_pretrained('codechat-bison-32k')
        chat_session = CodeChatSession(
            model=code_gen_model,
            temperature=temperature,
            max_output_tokens=max_output_tokens
        )

    error_messages = []
    prompts = [prompt]
    successful_queries = []
    # Bound up front: the error prompt below interpolates this name, and
    # the very first send_message() call may fail before any query has
    # been generated.
    generated_sql_query = ""

    for _ in range(max_tries):
        try:
            if use_gemini:
                response = chat_session.send_message(prompt)
            else:
                response = chat_session.send_message(
                    prompt,
                    temperature=temperature,
                    max_output_tokens=max_output_tokens
                )

            # Drop the first and last line of the response.
            # NOTE(review): presumably these are the markdown code-fence
            # lines; a response without fences loses its first and last
            # line of SQL here - confirm against real model output.
            generated_sql_query = '\n'.join(
                response.text.split('\n')[1:-1]
            )
            generated_sql_query = self.case_handler_transform(
                generated_sql_query
            )
            generated_sql_query = self.add_dataset_to_query(
                generated_sql_query
            )

            df = client.query(generated_sql_query).to_dataframe()
            successful_queries.append({
                "query": generated_sql_query,
                "dataframe": df
            })

            # After the first success the prompt is left unchanged, so
            # the same prompt is sent again on the next iteration.
            # The prompt text below is runtime data; its line breaks are
            # kept exactly as they are.
            if len(successful_queries) > 1:
                prompt = f"""Modify the last successful SQL query by
making changes to it and optimizing it for latency.
ENSURE that the NEW QUERY is DIFFERENT from the
previous one while prioritizing faster execution.
Reference the tables only from the above given
project and dataset
The last successful query was:
{successful_queries[-1]["query"]}"""
        except Exception as e:
            # Deliberately broad: any failure (model call, transform or
            # query execution) is turned into a self-correction prompt.
            msg = str(e)
            error_messages.append(msg)
            prompt = f"""Encountered an error: {msg}.
To address this, please generate an alternative SQL
query response that avoids this specific error.
Follow the instructions mentioned above to
remediate the error.
Modify the below SQL query to resolve the issue and
ensure it is not a repetition of all previously
generated queries.
{generated_sql_query}
Ensure the revised SQL query aligns precisely with the
requirements outlined in the initial question.
Keep the table names as it is. Do not change hyphen
to underscore character
Additionally, please optimize the query for latency
while maintaining correctness and efficiency."""
            prompts.append(prompt)

    if not successful_queries:
        return {
            "error": "All attempts exhausted.",
            "prompts": prompts,
            "errors": error_messages
        }

    df = pd.DataFrame(
        [(q["query"], q["dataframe"]) for q in successful_queries],
        columns=["Query", "Result"]
    )
    return {
        "dataframe": df
    }