orchestration/strategies/nl2sql_fewshot_strategy.py (60 lines of code) (raw):

import os

from pydantic import BaseModel

from autogen_agentchat.agents import AssistantAgent
from autogen_core.tools import FunctionTool

from .nl2sql_base_agent_strategy import NL2SQLBaseStrategy
from ..constants import Strategy
from tools import (
    get_time,
    get_today_date,
    queries_retrieval,
    get_all_datasources_info,
    get_all_tables_info,
    get_schema_info,
    validate_sql_query,
    execute_sql_query,
)


# Agents Strategy Class
class NL2SQLFewshotStrategy(NL2SQLBaseStrategy):
    """NL2SQL few-shot strategy: one tool-using assistant plus a chat closure agent.

    Besides the datasource/table/schema lookup, SQL validation and SQL
    execution tools, the assistant is given ``queries_retrieval``, which
    returns similar question/query examples for the model to draw on.
    """

    def __init__(self):
        # Assigned before super().__init__(); presumably the base class reads
        # strategy_type during initialization — verify in NL2SQLBaseStrategy.
        self.strategy_type = Strategy.NL2SQL_FEWSHOT
        super().__init__()

    async def create_agents(self, history, client_principal=None, access_token=None, output_mode=None, output_format=None):
        """
        Creates agents and registers tools for the NL2SQL few-shot scenario.

        Args:
            history: Conversation history; passed to ``self._get_model_context``
                to build the model context shared with the assistant.
            client_principal: Accepted for interface compatibility; not used here.
            access_token: Accepted for interface compatibility; not used here.
            output_mode: Forwarded to ``self._create_chat_closure_agent``.
            output_format: Forwarded to ``self._create_chat_closure_agent``.

        Returns:
            The result of ``self._get_agents_configuration()``, called after
            ``self.max_rounds``, ``self.selector_func`` and ``self.agents``
            have been set on this instance.

        Raises:
            ValueError: If the ``MAX_ROUNDS`` environment variable is set to a
                value that ``int()`` cannot parse.
        """
        # Model Context
        shared_context = await self._get_model_context(history)

        # Wrapper Functions for Tools
        get_all_datasources_info_tool = FunctionTool(
            get_all_datasources_info,
            description="Retrieve a list of all datasources."
        )
        get_all_tables_info_tool = FunctionTool(
            get_all_tables_info,
            description="Retrieve a list of tables filtering by the given datasource."
        )
        get_schema_info_tool = FunctionTool(
            get_schema_info,
            description="Retrieve information about tables and columns from the data dictionary."
        )
        queries_retrieval_tool = FunctionTool(
            queries_retrieval,
            description="Retrieve QueriesRetrievalResult a list of similar QueryItem containing a question, the correspondent query and reasoning."
        )
        validate_sql_query_tool = FunctionTool(
            validate_sql_query,
            description="Validate the syntax of an SQL query."
        )
        execute_sql_query_tool = FunctionTool(
            execute_sql_query,
            description="Execute an SQL query and return the results."
        )

        # Agents

        ## Assistant Agent
        assistant_prompt = await self._read_prompt("nl2sql_assistant")
        assistant = AssistantAgent(
            name="assistant",
            system_message=assistant_prompt,
            model_client=self._get_model_client(),
            # get_today_date / get_time are passed as plain callables rather
            # than FunctionTool wrappers; AssistantAgent accepts callables in
            # its tools list (per autogen docs) — their docstrings then serve
            # as the tool descriptions.
            tools=[
                get_all_datasources_info_tool,
                get_schema_info_tool,
                validate_sql_query_tool,
                queries_retrieval_tool,
                get_all_tables_info_tool,
                execute_sql_query_tool,
                get_today_date,
                get_time,
            ],
            reflect_on_tool_use=True,
            model_context=shared_context
        )

        ## Chat Closure Agent
        chat_closure = await self._create_chat_closure_agent(output_format, output_mode)

        # Group Chat Configuration
        self.max_rounds = int(os.getenv('MAX_ROUNDS', 20))

        def custom_selector_func(messages):
            """
            Selects the next agent based on the source of the last message.

            Transition Rules:
                user -> assistant
                assistant -> None (SelectorGroupChat will handle transition)
                no messages -> None (defer to SelectorGroupChat instead of
                raising IndexError)
            """
            if not messages:
                return None
            last_msg = messages[-1]
            if last_msg.source == "user":
                return "assistant"
            return None

        self.selector_func = custom_selector_func
        self.agents = [assistant, chat_closure]

        return self._get_agents_configuration()