def generate_sql_with_join()

in nl2sql_src/nl2sql_generic.py [0:0]


    def generate_sql_with_join(self,
                               dataset,
                               table_1_name,
                               table_2_name,
                               question,
                               example_table1,
                               example_table2,
                               sample_question=None,
                               sample_sql=None,
                               one_shot=False,
                               join_gen="STANDARD"):
        gen_join_sql = ""
        match join_gen:
            case 'STANDARD':
                if not one_shot:
                    # Zero-shot Join query generation
                    join_prompt = self.get_join_prompt(dataset,
                                                       table_1_name,
                                                       table_2_name,
                                                       question)
                    gen_join_sql = self.invoke_llm(join_prompt)
                else:
                    # One-shot Join query generation
                    join_prompt_one_shot = self.get_join_prompt(
                        dataset,
                        table_1_name,
                        table_2_name,
                        question,
                        sample_question,
                        sample_sql,
                        one_shot=True
                    )
                    gen_join_sql = self.invoke_llm(
                        join_prompt_one_shot
                    )

            case 'MULTI_TURN':
                table_1_name, table_2_name = \
                    self.multi_turn_table_filter(
                        table_1_name=example_table1,
                        table_2_name=example_table2,
                        sample_question=sample_question,
                        sample_sql=sample_sql,
                        question=question
                        )
                # One-shot Join query generation
                join_prompt_one_shot = self.get_join_prompt(
                    data_set,
                    table_1_name,
                    table_2_name,
                    question,
                    sample_question,
                    sample_sql,
                    one_shot=True
                )
                gen_join_sql = self.invoke_llm(
                    join_prompt_one_shot
                )

            case 'SELF_CORRECT':
                join_prompt_one_shot = self.get_join_prompt(
                    data_set,
                    table_1_name,
                    table_2_name,
                    question,
                    sample_question,
                    sample_sql,
                    one_shot=True
                )
                # Self-Correction Approach
                responses = self.gen_and_exec_and_self_correct_sql(
                    join_prompt_one_shot
                )
                gen_join_sql = responses[0]['query']

        return gen_join_sql