orchestration/strategies/chat_with_fabric_strategy.py (134 lines of code) (raw):

import os
import re
from typing import Annotated, List, Optional

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core.tools import FunctionTool
from pydantic import BaseModel

from tools import (
    ExecuteQueryResult,
    execute_dax_query,
    execute_sql_query,
    get_all_datasources_info,
    get_all_tables_info,
    get_schema_info,
    get_time,
    get_today_date,
    measures_retrieval,
    queries_retrieval,
    tables_retrieval,
    validate_sql_query,
)
from .base_agent_strategy import BaseAgentStrategy
from ..constants import Strategy


## Agent Response Types

class ChatGroupResponse(BaseModel):
    """Structured response format for the chat-closure agent (answer plus reasoning)."""

    answer: str
    reasoning: str


class ChatGroupTextOnlyResponse(BaseModel):
    """Structured response format for the chat-closure agent when optimizing for audio."""

    answer: str


class DataSource(BaseModel):
    """A datasource identified by name with a human-readable description."""

    name: str
    description: str


class TriageAgentResponse(BaseModel):
    """Triage answer together with the datasources it considers relevant (may be None)."""

    answer: str
    datasources: Optional[List[DataSource]]


## Group-chat routing

# The query agents signal completion by ending their message with this marker.
_QUESTION_ANSWERED_PATTERN = re.compile(r"QUESTION_ANSWERED\.?$")

# Query agent name -> keyword the triage agent includes (alongside
# "DATASOURCE_SELECTED") when it picks a datasource that agent handles.
# Checked in insertion order.
_QUERY_AGENT_KEYWORDS = {
    "sql_query_agent": "sql_endpoint",
    "dax_query_agent": "semantic_model",
}


def _message_text(message) -> str:
    """Return the stripped string content of *message*, or "" if it has none.

    Some items in a group-chat thread (e.g. tool-call events) carry non-string
    content; treating them as empty text keeps the selector from crashing.
    """
    content = getattr(message, "content", None)
    return content.strip() if isinstance(content, str) else ""


def _select_next_agent(messages):
    """Choose the next speaker from the conversation so far.

    Routing rules, evaluated on the last message:
      * from the user                              -> "triage_agent"
      * a TextMessage ending in QUESTION_ANSWERED  -> "chat_closure"
      * from a query agent                         -> the same query agent
      * from the triage agent, containing DATASOURCE_SELECTED and a
        datasource keyword                         -> the matching query agent
      * anything else (including an empty thread)  -> "triage_agent"
    """
    if not messages:
        return "triage_agent"

    last_msg = messages[-1]
    if last_msg.source == "user":
        return "triage_agent"

    text = _message_text(last_msg)
    if isinstance(last_msg, TextMessage) and _QUESTION_ANSWERED_PATTERN.search(text):
        return "chat_closure"

    datasource_selected = (
        last_msg.source == "triage_agent" and "DATASOURCE_SELECTED" in text
    )
    for agent, keyword in _QUERY_AGENT_KEYWORDS.items():
        if last_msg.source == agent or (datasource_selected and keyword in text):
            return agent

    return "triage_agent"


# Agents Strategy Class

class ChatWithFabricStrategy(BaseAgentStrategy):
    """Multi-agent strategy for answering questions over Fabric datasources.

    A triage agent identifies the relevant datasource. A DAX query agent
    (keyword "semantic_model") or an SQL query agent (keyword "sql_endpoint")
    then works on the question. A chat-closure agent produces the final
    structured response. Turn order is decided by ``_select_next_agent``.
    """

    def __init__(self):
        # Initialize the strategy type
        super().__init__()
        self.strategy_type = Strategy.CHAT_WITH_FABRIC

    async def create_agents(self, history, client_principal=None, access_token=None,
                            text_only=False, optimize_for_audio=False):
        """Build the agents, tools and speaker selector for this strategy.

        Args:
            history: Conversation history used to build the shared model context.
            client_principal: Accepted for interface compatibility; not used here.
            access_token: Forwarded to ``execute_dax_query`` by the DAX tool wrapper.
            text_only: Stored on ``self.text_only``.
            optimize_for_audio: Stored on ``self.optimize_for_audio``. When true,
                the chat-closure agent uses the "chat_closure_audio" prompt and
                the answer-only response format.

        Returns:
            The result of ``self._get_agents_configuration()`` after
            ``self.agents``, ``self.selector_func`` and ``self.max_rounds`` are set.
        """
        # Response configuration
        self.text_only = text_only
        self.optimize_for_audio = optimize_for_audio

        # Model context shared by the triage and query agents.
        shared_context = await self._get_model_context(history)

        # Wrapper Functions for Tools
        tables_retrieval_tool = FunctionTool(
            tables_retrieval,
            description="Retrieve all tables that are relevant to the user question."
        )
        get_all_datasources_info_tool = FunctionTool(
            get_all_datasources_info,
            description="Retrieve a list of all datasources."
        )
        get_all_tables_info_tool = FunctionTool(
            get_all_tables_info,
            description="Retrieve a list of tables filtering by the given datasource."
        )
        get_schema_info_tool = FunctionTool(
            get_schema_info,
            description="Retrieve information about tables and columns from the data dictionary."
        )
        measures_retrieval_tool = FunctionTool(
            measures_retrieval,
            description="Retrieve a list of measures filtering by the given datasource."
        )
        queries_retrieval_tool = FunctionTool(
            queries_retrieval,
            description="Retrieve QueriesRetrievalResult a list of similar QueryItem containing a question, the correspondent query and reasoning."
        )

        async def execute_dax_query_wrapper(
            datasource: Annotated[str, "Target datasource"],
            query: Annotated[str, "DAX Query"],
        ) -> ExecuteQueryResult:
            # Closes over access_token so it is not a parameter of the tool
            # signature exposed to the model.
            return await execute_dax_query(datasource, query, access_token)

        execute_dax_query_tool = FunctionTool(
            execute_dax_query_wrapper,
            name="execute_dax_query",
            description="Execute a DAX query and return the results."
        )
        validate_sql_query_tool = FunctionTool(
            validate_sql_query,
            description="Validate the syntax of an SQL query."
        )
        execute_sql_query_tool = FunctionTool(
            execute_sql_query,
            description="Execute a SQL query against the datasource provided by the Triage Agent and return the results."
        )

        # Agents

        ## Triage Agent
        triage_prompt = await self._read_prompt("triage_agent")
        triage_agent = AssistantAgent(
            name="triage_agent",
            system_message=triage_prompt,
            model_client=self._get_model_client(),
            tools=[get_all_datasources_info_tool, tables_retrieval_tool, get_today_date, get_time],
            reflect_on_tool_use=True,
            model_context=shared_context
        )

        ## DAX Query Agent
        dax_query_prompt = await self._read_prompt("dax_query_agent")
        dax_query_agent = AssistantAgent(
            name="dax_query_agent",
            system_message=dax_query_prompt,
            model_client=self._get_model_client(),
            tools=[queries_retrieval_tool, measures_retrieval_tool, get_schema_info_tool,
                   execute_dax_query_tool, get_today_date, get_time],
            reflect_on_tool_use=True,
            model_context=shared_context
        )

        ## SQL Query Agent
        sql_query_prompt = await self._read_prompt("sql_query_agent")
        sql_query_agent = AssistantAgent(
            name="sql_query_agent",
            system_message=sql_query_prompt,
            model_client=self._get_model_client(),
            tools=[queries_retrieval_tool, get_schema_info_tool, validate_sql_query_tool,
                   execute_sql_query_tool, get_today_date, get_time],
            reflect_on_tool_use=True,
            model_context=shared_context
        )

        ## Chat Closure Agent
        # The audio variant uses a dedicated prompt and drops the reasoning field.
        if optimize_for_audio:
            prompt_name = "chat_closure_audio"
            chat_group_response_type = ChatGroupTextOnlyResponse
        else:
            prompt_name = "chat_closure"
            chat_group_response_type = ChatGroupResponse

        chat_closure = AssistantAgent(
            name="chat_closure",
            system_message=await self._read_prompt(prompt_name),
            model_client=self._get_model_client(response_format=chat_group_response_type)
        )

        # Group Chat Configuration
        self.max_rounds = int(os.getenv('MAX_ROUNDS', 40))
        self.selector_func = _select_next_agent
        self.agents = [triage_agent, dax_query_agent, sql_query_agent, chat_closure]

        return self._get_agents_configuration()