in nl2sql_src/nl2sql_generic.py [0:0]
def generate_sql_with_join(self,
dataset,
table_1_name,
table_2_name,
question,
example_table1,
example_table2,
sample_question=None,
sample_sql=None,
one_shot=False,
join_gen="STANDARD"):
gen_join_sql = ""
match join_gen:
case 'STANDARD':
if not one_shot:
# Zero-shot Join query generation
join_prompt = self.get_join_prompt(dataset,
table_1_name,
table_2_name,
question)
gen_join_sql = self.invoke_llm(join_prompt)
else:
# One-shot Join query generation
join_prompt_one_shot = self.get_join_prompt(
dataset,
table_1_name,
table_2_name,
question,
sample_question,
sample_sql,
one_shot=True
)
gen_join_sql = self.invoke_llm(
join_prompt_one_shot
)
case 'MULTI_TURN':
table_1_name, table_2_name = \
self.multi_turn_table_filter(
table_1_name=example_table1,
table_2_name=example_table2,
sample_question=sample_question,
sample_sql=sample_sql,
question=question
)
# One-shot Join query generation
join_prompt_one_shot = self.get_join_prompt(
data_set,
table_1_name,
table_2_name,
question,
sample_question,
sample_sql,
one_shot=True
)
gen_join_sql = self.invoke_llm(
join_prompt_one_shot
)
case 'SELF_CORRECT':
join_prompt_one_shot = self.get_join_prompt(
data_set,
table_1_name,
table_2_name,
question,
sample_question,
sample_sql,
one_shot=True
)
# Self-Correction Approach
responses = self.gen_and_exec_and_self_correct_sql(
join_prompt_one_shot
)
gen_join_sql = responses[0]['query']
return gen_join_sql