def gen_and_exec_and_self_correct_sql()

in nl2sql_src/nl2sql_generic.py [0:0]


    def gen_and_exec_and_self_correct_sql(self,
                                          prompt,
                                          genai_model_name="GeminiPro",
                                          max_tries=5,
                                          return_all=False):
        """
            Wrapper function for Standard, Multi-turn and Self Correct
            approach of SQL generation

            A chat session is opened with the selected model and the
            prompt is sent up to ``max_tries`` times. Each response is
            post-processed, executed through the module-level ``client``
            and, on success, stored together with its result. When an
            attempt raises, the error text is fed back to the model in a
            corrective prompt for the next attempt.

            Args:
                prompt: Initial prompt sent to the chat session.
                genai_model_name: "GeminiPro" selects the
                    "gemini-1.0-pro" generative model; any other value
                    selects the 'codechat-bison-32k' code chat model.
                max_tries: Number of loop iterations. The loop does not
                    stop early on success, so every iteration sends a
                    message to the model.
                return_all: Currently unused by this method; kept for
                    interface compatibility.

            Returns:
                dict: If no attempt succeeded, a dict with the keys
                "error", "prompts" (the initial prompt plus every
                corrective prompt) and "errors" (the collected exception
                messages). Otherwise ``{"dataframe": df}`` where ``df``
                has one row per successful attempt with the columns
                "Query" (SQL text) and "Result" (the query's dataframe).
        """
        # Generation settings, applied to the code chat model only.
        TEMPERATURE = 0.3
        MAX_OUTPUT_TOKENS = 8192
        MODEL_NAME = 'codechat-bison-32k'

        # Build only the model that is actually used, so a failure to load
        # the other one cannot break this call.
        use_gemini = genai_model_name == "GeminiPro"
        if use_gemini:
            chat_session = GenerativeModel("gemini-1.0-pro").start_chat()
        else:
            code_gen_model = CodeChatModel.from_pretrained(MODEL_NAME)
            chat_session = CodeChatSession(model=code_gen_model,
                                           temperature=TEMPERATURE,
                                           max_output_tokens=MAX_OUTPUT_TOKENS
                                           )

        error_messages = []
        prompts = [prompt]
        successful_queries = []
        # Pre-bind so the error prompt below can always be formatted, even
        # when the very first send_message() call raises before any SQL
        # has been produced.
        generated_sql_query = ""
        tries = 0

        while tries < max_tries:
            try:
                if use_gemini:
                    response = chat_session.send_message(prompt)
                else:
                    response = chat_session.send_message(
                        prompt,
                        temperature=TEMPERATURE,
                        max_output_tokens=MAX_OUTPUT_TOKENS
                    )

                # Drop the first and last line of the response.
                # NOTE(review): presumably these are the markdown code
                # fence lines around the SQL - confirm against the prompt.
                generated_sql_query = response.text
                generated_sql_query = '\n'.join(
                    generated_sql_query.split('\n')[1:-1]
                )

                generated_sql_query = self.case_handler_transform(
                    generated_sql_query
                )
                generated_sql_query = self.add_dataset_to_query(
                    generated_sql_query
                )
                # Any failure while running the query is handled by the
                # except branch below and triggers a corrective prompt.
                df = client.query(generated_sql_query).to_dataframe()
                successful_queries.append({
                    "query": generated_sql_query,
                    "dataframe": df
                })
                # Only once more than one query has succeeded is the
                # prompt switched to asking for a faster variant; after
                # the first success the current prompt is simply re-sent.
                if len(successful_queries) > 1:
                    prompt = f"""Modify the last successful SQL query by
                        making changes to it and optimizing it for latency.
                        ENSURE that the NEW QUERY is DIFFERENT from the
                        previous one while prioritizing faster execution.
                        Reference the tables only from the above given
                        project and dataset
                        The last successful query was:
                        {successful_queries[-1]["query"]}"""

            except Exception as e:
                # Self-correction: feed the error text and the most recent
                # SQL back to the model for the next attempt.
                msg = str(e)
                error_messages.append(msg)
                prompt = f"""Encountered an error: {msg}.
                    To address this, please generate an alternative SQL
                    query response that avoids this specific error.
                    Follow the instructions mentioned above to
                    remediate the error.

                    Modify the below SQL query to resolve the issue and
                    ensure it is not a repetition of all previously
                    generated queries.
                    {generated_sql_query}

                    Ensure the revised SQL query aligns precisely with the
                    requirements outlined in the initial question.
                    Keep the table names as it is. Do not change hyphen
                    to underscore character
                    Additionally, please optimize the query for latency
                    while maintaining correctness and efficiency."""
                prompts.append(prompt)

            tries += 1

        if not successful_queries:
            return {
                "error": "All attempts exhausted.",
                "prompts": prompts,
                "errors": error_messages
            }

        # One row per successful attempt: the SQL text and its result.
        df = pd.DataFrame(
            [(q["query"], q["dataframe"])
             for q in successful_queries], columns=["Query", "Result"]
        )
        return {
            "dataframe": df
        }