def generate_sql_few_shot()

in nl2sql_src/nl2sql_generic.py [0:0]


    def generate_sql_few_shot(self,
                              question,
                              table_name=None,
                              logger_file="log.txt"):
        """
        Convert a natural-language question to SQL using few-shot prompting.

        Steps:
          1. Pick the target table: ``table_name`` if given; otherwise the
             first result of ``self.table_filter(question)`` when
             ``self.metadata_json`` describes more than one table, or the
             only table it describes.
          2. Render that table's column metadata (Name, Type, Description,
             Examples) into one line per column.
          3. Fetch similar question/SQL pairs from
             ``Nl2Sql_embed().search_matching_queries`` as few-shot examples.
          4. Fill ``Sql_Generation_prompt_few_shot`` and call
             ``self.llm.invoke``; strip markdown code fences and a leading
             ``Response:`` label from the reply.
          5. Pass the cleaned query through ``self.add_dataset_to_query`` and
             append model name, question, prompt and query to
             ``logger_file``.

        Args:
            question: Natural-language question to translate.
            table_name: Key into ``self.metadata_json``; chosen automatically
                when falsy.
            logger_file: Path of a text file the generation record is
                appended to (UTF-8).

        Returns:
            The value returned by ``self.add_dataset_to_query`` for the
            cleaned LLM output (presumably the final SQL string).

        Raises:
            Exception: wrapping any error raised during generation; the
                original error is chained as ``__cause__``.
        """
        import re  # local import: only used to strip markdown code fences

        try:
            # step-1 table selection
            if not table_name:
                if not self.metadata_json:
                    raise ValueError("metadata_json contains no tables")
                if len(self.metadata_json) > 1:
                    table_list = self.table_filter(question)
                    if not table_list:
                        raise ValueError(
                            f"No table matched the question: {question!r}")
                    table_name = table_list[0]
                else:
                    table_name = next(iter(self.metadata_json))
            table_json = self.metadata_json[table_name]

            # step-2 column descriptions, one line per column. Adjacent
            # f-strings avoid the indentation runs that a backslash-continued
            # literal would embed in the prompt.
            columns_info = "".join(
                f"{column['Name']} ({column['Type']}) : "
                f"{column['Description']}. {column['Examples']}\n"
                for column in table_json["Columns"].values()
            )

            # step-3 few-shot examples retrieved by similarity search
            embed = Nl2Sql_embed()
            few_shot_json = embed.search_matching_queries(question)
            logger.info("Few shot examples : %s", few_shot_json)
            few_shot_examples = "".join(
                f"Question: {item['question']}\nSQL : {item['sql']} \n\n"
                for item in few_shot_json or []
            )

            # step-4 prompt the model and clean its reply
            sql_prompt = Sql_Generation_prompt_few_shot.format(
                table_name=table_json["Name"],
                table_description=table_json["Description"],
                columns_info=columns_info,
                few_shot_examples=few_shot_examples,
                question=question)
            # NOTE(review): assumes invoke() returns plain text, not a
            # message object -- confirm for the configured LLM.
            response = self.llm.invoke(sql_prompt)
            # Remove only markdown fences (optionally tagged "sql"); a blanket
            # replace('sql', '') would also corrupt identifiers that contain
            # "sql" (e.g. "nl2sql_table").
            sql_query = re.sub(r"```[ \t]*(?:sql\b)?", "", response,
                               flags=re.IGNORECASE).strip()
            if sql_query.startswith("Response:"):
                # Drop just the label; splitting on ":" would truncate
                # queries that contain colons (time literals, casts).
                sql_query = sql_query[len("Response:"):].strip()

            sql_query = self.add_dataset_to_query(sql_query)

            # step-5 append a record of this generation to the log file
            with open(logger_file, 'a', encoding="utf-8") as f:
                f.write(f">>\nModel:{self.model_name} \n\nQuestion: {question}"
                        f"\n\nPrompt:{sql_prompt} \n\nSql_query:{sql_query}<\n")
            return sql_query
        except Exception as exc:
            # traceback.print_exc() returns None, so wrapping its result lost
            # the message; log the traceback and raise a descriptive error.
            logger.exception("Few-shot SQL generation failed")
            raise Exception(
                f"Few-shot SQL generation failed for {question!r}: {exc}"
            ) from exc