def batch_run()

in nl2sql_src/nl2sql_generic.py [0:0]


    def batch_run(self,
                  test_file_name,
                  output_file_name,
                  execute_query=False,
                  result2nl=False,
                  insight=True,
                  logger_file="log.txt"):
        """
        Generate and rate SQL for every question in a CSV test file.

        Each row of the input CSV is passed to `self.generate_sql`, and the
        generated query is rated against the row's ground truth SQL with
        `self.auto_verify`. Optionally the generated query is executed with
        `self.execute_query` and the execution result is turned into a
        natural language answer with `self.result2nl`.
        The collected rows are written to `output_file_name` as CSV
        (without the index column) and also returned.

        Parameters:
        - test_file_name (str):
        Path of the CSV file to read. It must provide the columns
        'question' and 'ground_truth_sql'. An optional 'table' column names
        the table to query; an empty, blank or missing value means that no
        table name is passed on (table_name=None).

        - output_file_name (str):
        Path of the CSV file the results are written to.

        - execute_query (bool, optional):
        Execute every generated query and store the result in the
        'sql_result' column. Defaults to False.

        - result2nl (bool, optional):
        Convert the execution result to natural language and store it in
        the 'nl_response' column. Only has an effect together with
        execute_query=True, because the answer is built from the execution
        result. Defaults to False.

        - insight (bool, optional):
        Forwarded to `self.result2nl` as `insight`. Defaults to True.

        - logger_file (str, optional):
        Forwarded to `self.generate_sql` as `logger_file`.
        Defaults to "log.txt".

        Returns:
        pandas.DataFrame: One row per question with the columns 'question',
        'ground_truth', 'llm_response' and 'llm_rating', plus 'sql_result'
        and 'nl_response' when the corresponding options are enabled.

        Raises:
        Exception: If any step fails. The message holds the formatted
        traceback and the original error is chained as `__cause__`.
        """
        # The NL answer needs an execution result, so the column is only
        # produced when both flags are set. Deciding this once keeps the
        # header and the row width in sync.
        include_nl = execute_query and result2nl
        try:
            questions = pd.read_csv(test_file_name)

            columns = ['question',
                       'ground_truth',
                       'llm_response',
                       'llm_rating'
                       ]
            if execute_query:
                columns.append('sql_result')
            if include_nl:
                columns.append('nl_response')

            out = []
            for _, row in questions.iterrows():
                # pandas reads an empty cell as NaN (a float), so only a
                # non-blank string counts as a table name.
                raw_table = row.get("table")
                table_name = None
                if isinstance(raw_table, str) and raw_table.strip():
                    table_name = raw_table

                question = row["question"]
                ground_truth = row["ground_truth_sql"]

                sql_gen = self.generate_sql(question,
                                            table_name=table_name,
                                            logger_file=logger_file
                                            )
                rating = self.auto_verify(question, ground_truth, sql_gen)

                row_result = [question, ground_truth, sql_gen, rating]
                if execute_query:
                    result = self.execute_query(sql_gen)
                    row_result.append(result)
                    if include_nl:
                        nl = self.result2nl(result, question,
                                            insight=insight)
                        row_result.append(nl)
                out.append(row_result)

            df = pd.DataFrame(out, columns=columns)
            df.to_csv(output_file_name, index=False)
            return df
        except Exception as exc:
            # format_exc() returns the traceback text; print_exc() returns
            # None, which would leave the raised exception without a message.
            raise Exception(traceback.format_exc()) from exc