agents/ValidateSQLAgent.py (22 lines of code) (raw):

import json from abc import ABC from .core import Agent from utilities import PROMPTS, format_prompt class ValidateSQLAgent(Agent, ABC): """ An agent that validates the syntax and semantic correctness of SQL queries. This agent leverages a language model (currently Gemini) to analyze a given SQL query against a provided database schema. It assesses whether the query is valid according to a set of predefined guidelines and generates a JSON response indicating the validity status and any potential errors. Attributes: agentType (str): Indicates the type of agent, fixed as "ValidateSQLAgent". Methods: check(user_question, tables_schema, columns_schema, generated_sql) -> dict: Determines the validity of an SQL query and identifies potential errors. Args: user_question (str): The original question posed by the user (used for context). tables_schema (str): A description of the database tables and their relationships. columns_schema (str): Detailed descriptions of the columns within the tables. generated_sql (str): The SQL query to be validated. Returns: dict: A JSON-formatted dictionary with the following keys: - "valid": A boolean value indicating whether the query is valid or not. - "errors": A string describing any errors found in the query (empty if valid). """ agentType: str = "ValidateSQLAgent" def check(self,source_type, user_question, tables_schema, columns_schema, generated_sql): context_prompt = PROMPTS['validatesql'] context_prompt = format_prompt(context_prompt, source_type = source_type, user_question = user_question, tables_schema = tables_schema, columns_schema = columns_schema, generated_sql=generated_sql) # print(f"Prompt to Validate SQL after formatting: \n{context_prompt}") if "gemini" in self.model_id: context_query = self.model.generate_content(context_prompt, stream=False) generated_sql = str(context_query.candidates[0].text) else: context_query = self.model.predict(context_prompt, max_output_tokens = 8000, temperature=0) generated_sql = str(context_query.candidates[0]) json_syntax_result = json.loads(str(generated_sql).replace("```json","").replace("```","")) # print('\n SQL Syntax Validity:' + str(json_syntax_result['valid'])) # print('\n SQL Syntax Error Description:' +str(json_syntax_result['errors']) + '\n') return json_syntax_result