def get_sql()

in dbai_src/dbai.py [0:0]


    def get_sql(self, question):
        """
        Answer a natural-language question, letting the model call tools.

        Starts a fresh chat with ``self.agent`` and sends the question together
        with ``self.dataset_id`` and ``self.system_prompt``. It then services
        the model's function calls (``list_tables``, ``get_table_metadata``,
        ``sql_query``) until the model replies with a part that carries no
        function call.

        Args:
            question: The user's question in natural language.

        Returns:
            NL2SQLResp built from the model's final text, the SQL of the most
            recent ``sql_query`` call, and that call's result. The SQL and
            result are '' when the model never ran a query.
        """
        chat = self.agent.start_chat()
        prompt = (
            question
            + f"\n The dataset_id is {self.dataset_id}"
            + self.system_prompt
        )

        response = chat.send_message(prompt)
        response = response.candidates[0].content.parts[0]
        intermediate_steps = []

        while True:
            # Reading the function call off a part that has none raises
            # AttributeError; that is the signal the model has answered.
            # The try is kept narrow so AttributeErrors raised by the tools
            # (or by response parsing) propagate instead of being mistaken
            # for "done".
            try:
                function_call = response.function_call
                function_name = function_call.name
                params = dict(function_call.args.items())
            except AttributeError:
                break

            if function_name == "list_tables":
                api_response = self.api_list_tables()
            elif function_name == "get_table_metadata":
                api_response = self.api_get_table_metadata(params["table_id"])
            elif function_name == "sql_query":
                api_response = self.execute_sql_query(params["query"])
            else:
                # Unknown tool name: reply with an empty result.
                api_response = ''

            response = chat.send_message(
                Part.from_function_response(
                    name=function_name,
                    response={
                        "content": api_response,
                    },
                ),
            )
            response = response.candidates[0].content.parts[0]
            intermediate_steps.append({
                'function_name': function_name,
                'function_params': params,
                'API_response': api_response,
                'response': response,
            })

        # The model may issue several queries while refining its answer.
        # Scan newest-first and stop at the first match so the LATEST query
        # is reported (the old reversed loop had no break and so returned
        # the earliest one).
        last_sql_step = next(
            (
                step for step in reversed(intermediate_steps)
                if step['function_name'] == 'sql_query'
            ),
            None,
        )
        if last_sql_step is None:
            generated_sql, sql_output = '', ''
        else:
            generated_sql = last_sql_step['function_params']['query']
            sql_output = last_sql_step['API_response']

        return NL2SQLResp(response.text, generated_sql, sql_output)