{
"cells": [
{
"cell_type": "markdown",
"id": "0c370a9a",
"metadata": {},
"source": [
"# Multi-Agent Supervisor\n",
"---\n",
"\n",
"### What is Multi-Agent Supervisor?\n",
"\n",
"A Multi-Agent Supervisor is a control mechanism that oversees and coordinates multiple autonomous agents operating within a system. It ensures that agents work collaboratively and efficiently by managing tasks, resolving conflicts, optimizing resource allocation, and enforcing system constraints. The supervisor can be centralized, decentralized, or distributed, depending on the system architecture. It is commonly used in multi-robot systems, industrial automation, and AI-driven applications to enhance coordination, adaptability, and decision-making.\n",
"\n",
"As the number of agents increases, the branching logic also becomes more complex. The Supervisor agent gathers various specialized agents together and operates them as a single team. The Supervisor agent observes the progress of the team and performs logic such as calling the appropriate agent for each step or terminating the task.\n",
"\n",
"**Reference**\n",
"\n",
"- [Multi-Agent Supervisor Concept](https://langchain-ai.github.io/langgraph/concepts/multi_agent/#supervisor) \n",
"- [LangChain `create_react_agent` built-in function](https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent)"
]
},
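{
"cell_type": "markdown",
"id": "3fa9c210",
"metadata": {},
"source": [
"Before turning to LangGraph, here is a minimal, framework-free sketch of the supervisor loop this notebook builds: the supervisor inspects the conversation so far and either picks the next worker or finishes, and each worker's result is appended back into the shared history. The routing rule and workers below are toy stand-ins, not the notebook's actual agents."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7be4d582",
"metadata": {},
"outputs": [],
"source": [
"# Toy sketch of the supervisor pattern (no LLMs, no LangGraph).\n",
"# In the rest of this notebook, `pick_next_worker` is an LLM call with structured output.\n",
"def pick_next_worker(history: list[str]) -> str:\n",
"    if not any(m.startswith(\"research:\") for m in history):\n",
"        return \"Researcher\"\n",
"    if not any(m.startswith(\"code:\") for m in history):\n",
"        return \"Coder\"\n",
"    return \"FINISH\"\n",
"\n",
"\n",
"workers = {\n",
"    \"Researcher\": lambda history: \"research: facts gathered\",\n",
"    \"Coder\": lambda history: \"code: script written\",\n",
"}\n",
"\n",
"\n",
"def run(task: str) -> list[str]:\n",
"    history = [task]\n",
"    while (choice := pick_next_worker(history)) != \"FINISH\":\n",
"        history.append(workers[choice](history))  # worker output feeds back to the supervisor\n",
"    return history\n",
"\n",
"\n",
"run(\"Plot Microsoft's stock price over the past 5 years\")"
]
},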
{
"cell_type": "code",
"execution_count": null,
"id": "b86d4cf1",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from dotenv import load_dotenv\n",
"from azure_genai_utils.tracer import get_langchain_api_key, set_langsmith\n",
"\n",
"load_dotenv(override=True)\n",
"\n",
"# If you want to trace your RAG API calls, please set the tracing=True. You need to have a valid Langchain API key.\n",
"langchain_key, has_langchain_key = get_langchain_api_key()\n",
"set_langsmith(\"[RAG Innv Lab] 1_Agentic-Design-Pattern\", tracing=False)\n",
"\n",
"azure_openai_chat_deployment_name = os.getenv(\"AZURE_OPENAI_CHAT_DEPLOYMENT_NAME\")"
]
},
{
"cell_type": "markdown",
"id": "561f72a1",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 1. Test and Construct each module\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "0d4c516f",
"metadata": {},
"source": [
"### Define your LLM\n",
"This hands-on only uses the `gpt-4o-mini`, but you can utilize multiple models in the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68aa3c3e",
"metadata": {},
"outputs": [],
"source": [
"from langchain_openai import AzureChatOpenAI\n",
"\n",
"llm = AzureChatOpenAI(model=azure_openai_chat_deployment_name, temperature=0)"
]
},
{
"cell_type": "markdown",
"id": "7c335a99",
"metadata": {},
"source": [
"### Tools\n",
"\n",
"Before building the entire the graph pipeline, we will test and construct each module separately.\n",
"\n",
"- **Researcher**\n",
"- **Coder**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c8fd5ed",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.tools import BingSearch\n",
"from langchain_experimental.tools import PythonREPLTool\n",
"\n",
"WEB_SEARCH_FORMAT_OUTPUT = True\n",
"\n",
"web_search_tool = BingSearch(\n",
" max_results=3,\n",
" locale=\"en-US\",\n",
" include_news=False,\n",
" include_entity=False,\n",
" format_output=WEB_SEARCH_FORMAT_OUTPUT,\n",
")\n",
"\n",
"python_repl_tool = PythonREPLTool()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "744c2cb3",
"metadata": {},
"outputs": [],
"source": [
"web_search_tool.invoke(\"Where is Seoul?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d49f91db",
"metadata": {},
"outputs": [],
"source": [
"python_repl_tool.invoke(\"print('Hello, World!')\")"
]
},
{
"cell_type": "markdown",
"id": "3f8b177c",
"metadata": {},
"source": [
"### ReAct agent test (Not required. Just for testing)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "887d7892",
"metadata": {},
"outputs": [],
"source": [
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"research_agent = create_react_agent(llm, tools=[web_search_tool])\n",
"research_agent.invoke({\"messages\": \"Where is Seoul?\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7a9ce4d",
"metadata": {},
"outputs": [],
"source": [
"coder_agent = create_react_agent(llm, tools=[python_repl_tool], prompt=None)\n",
"coder_agent.invoke({\"messages\": [(\"user\", \"print 'Hello, World!'\")]})"
]
},
{
"cell_type": "markdown",
"id": "82f1e2d0",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 2. Define the Graph\n",
"---\n",
"\n",
"### State Definition\n",
"\n",
"- `messages`: Messages to be passed between agents\n",
"- `next`: Next agent to be called"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c01e2a7d",
"metadata": {},
"outputs": [],
"source": [
"import operator\n",
"from typing import Sequence, Annotated\n",
"from typing_extensions import TypedDict\n",
"from langchain_core.messages import BaseMessage\n",
"\n",
"\n",
"class AgentState(TypedDict):\n",
" messages: Annotated[Sequence[BaseMessage], operator.add]\n",
" next: Annotated[str, \"Next agent to be called\"]"
]
},
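{
"cell_type": "markdown",
"id": "91c0aa37",
"metadata": {},
"source": [
"Because `messages` is annotated with the `operator.add` reducer, LangGraph appends each node's returned messages to the existing list instead of overwriting it; `next` has no reducer, so it is simply replaced. The toy values below only illustrate these reducer semantics."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d2e8b4c",
"metadata": {},
"outputs": [],
"source": [
"# How LangGraph applies the annotations above (illustrative values only)\n",
"state = {\"messages\": [\"hello\"], \"next\": \"\"}\n",
"update = {\"messages\": [\"world\"], \"next\": \"Researcher\"}\n",
"\n",
"state[\"messages\"] = operator.add(state[\"messages\"], update[\"messages\"])  # appended\n",
"state[\"next\"] = update[\"next\"]  # overwritten\n",
"state"
]
},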
{
"cell_type": "markdown",
"id": "2501de5e",
"metadata": {},
"source": [
"### Create Agent Supervisor\n",
"\n",
"Create a supervisor agent that manages the agents."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b367fab5",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel\n",
"from typing import Literal\n",
"\n",
"# Member agents\n",
"members = [\"Researcher\", \"Coder\"]\n",
"\n",
"# Our team supervisor is an LLM node. It just picks the next agent to process\n",
"# and decides when the work is completed\n",
"options_for_next = [\"FINISH\"] + members\n",
"\n",
"\n",
"class RouteResponse(BaseModel):\n",
" \"\"\"Worker to route to next. If no workers needed, route to FINISH.\"\"\"\n",
"\n",
" next: Literal[*options_for_next]"
]
},
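{
"cell_type": "markdown",
"id": "a3f7e019",
"metadata": {},
"source": [
"As a quick, optional sanity check: pydantic enforces the unpacked `Literal`, so the supervisor's structured output can only ever name a real worker or `FINISH`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4b8d26e",
"metadata": {},
"outputs": [],
"source": [
"# Values inside the Literal validate; anything else raises a ValidationError\n",
"print(RouteResponse(next=\"Researcher\"))\n",
"try:\n",
"    RouteResponse(next=\"Writer\")  # hypothetical worker, not in options_for_next\n",
"except Exception as e:\n",
"    print(type(e).__name__)"
]
},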
{
"cell_type": "code",
"execution_count": null,
"id": "a1e6a43c",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
"\n",
"\n",
"def supervisor_agent(state: AgentState):\n",
"\n",
" system_prompt = (\n",
" \"You are a supervisor tasked with managing a conversation between the\"\n",
" \" following workers: {members}. Given the following user request,\"\n",
" \" respond with the worker to act next. Each worker will perform a\"\n",
" \" task and respond with their results and status. When finished,\"\n",
" \" respond with FINISH.\"\n",
" )\n",
"\n",
" prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\"system\", system_prompt),\n",
" MessagesPlaceholder(variable_name=\"messages\"),\n",
" (\n",
" \"system\",\n",
" \"Given the conversation above, who should act next? \"\n",
" \"Or should we FINISH? Select one of: {options}\",\n",
" ),\n",
" ]\n",
" ).partial(options=str(options_for_next), members=\", \".join(members))\n",
"\n",
" supervisor_chain = prompt | llm.with_structured_output(RouteResponse)\n",
" return supervisor_chain.invoke(dict(state))"
]
},
{
"cell_type": "markdown",
"id": "74d82e23",
"metadata": {},
"source": [
"Not required, but it is good to test the agent before creating the graph workflow."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d00cc20",
"metadata": {},
"outputs": [],
"source": [
"supervisor_agent({\"messages\": [], \"next\": \"Researcher\"})"
]
},
{
"cell_type": "markdown",
"id": "c1e36a96",
"metadata": {},
"source": [
"### Create Agents\n",
"\n",
"Create agents that perform sub-tasks.\n",
"- `Researcher`: Researches the topic\n",
"- `Coder`: Codes the solution"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd5e5a07",
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"from langgraph.prebuilt import create_react_agent\n",
"from langchain_core.messages import HumanMessage\n",
"from langchain_core.prompts import load_prompt\n",
"\n",
"\n",
"# Agent node that invokes the agent with the given state\n",
"def agent_node(state: AgentState, agent, name) -> AgentState:\n",
" agent_response = agent.invoke(state)\n",
" # Return the last message of the agent as a HumanMessage and set the next agent\n",
" return {\n",
" \"messages\": [\n",
" HumanMessage(content=agent_response[\"messages\"][-1].content, name=name)\n",
" ],\n",
" \"next\": \"Supervisor\", # Set the next agent to Supervisor or any other logic\n",
" }\n",
"\n",
"\n",
"# Create Research Agent\n",
"research_agent = create_react_agent(llm, tools=[web_search_tool])\n",
"research_node = functools.partial(agent_node, agent=research_agent, name=\"Researcher\")\n",
"\n",
"# Create Coder Agent\n",
"code_system_prompt = load_prompt(\"../../../prompts/code-system-prompt-kr.yaml\").format()\n",
"coder_agent = create_react_agent(\n",
" llm, tools=[python_repl_tool], prompt=code_system_prompt\n",
")\n",
"coder_node = functools.partial(agent_node, agent=coder_agent, name=\"Coder\")"
]
},
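{
"cell_type": "markdown",
"id": "6e1f9a03",
"metadata": {},
"source": [
"Optionally, you can smoke-test a wrapped node before wiring the graph. The call below runs a real web search and should return a single `HumanMessage` named `Researcher` with `next` set to `Supervisor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d85c3b7a",
"metadata": {},
"outputs": [],
"source": [
"# Optional smoke test of the wrapped node (runs a real Bing search)\n",
"research_node({\"messages\": [HumanMessage(content=\"Where is Seoul?\")], \"next\": \"\"})"
]
},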
{
"cell_type": "markdown",
"id": "c555f131",
"metadata": {},
"source": [
"### Construct the Graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5029ab17",
"metadata": {},
"outputs": [],
"source": [
"from langgraph.graph import END, StateGraph, START\n",
"from langgraph.checkpoint.memory import MemorySaver\n",
"\n",
"\n",
"def get_next(state):\n",
" return state[\"next\"]\n",
"\n",
"\n",
"workflow = StateGraph(AgentState)\n",
"\n",
"# Node definition\n",
"workflow.add_node(\"Researcher\", research_node)\n",
"workflow.add_node(\"Coder\", coder_node)\n",
"workflow.add_node(\"Supervisor\", supervisor_agent)\n",
"\n",
"# Edge definition\n",
"workflow.add_edge(START, \"Supervisor\")\n",
"\n",
"conditional_map = {k: k for k in members}\n",
"conditional_map[\"FINISH\"] = END\n",
"workflow.add_conditional_edges(\"Supervisor\", get_next, conditional_map)\n",
"\n",
"for member in members:\n",
" workflow.add_edge(member, \"Supervisor\")\n",
"\n",
"# Compile the workflow\n",
"app = workflow.compile(checkpointer=MemorySaver())"
]
},
{
"cell_type": "markdown",
"id": "c597bbe1",
"metadata": {},
"source": [
"### Visualize the graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d0a8e3c",
"metadata": {},
"outputs": [],
"source": [
"from azure_genai_utils.graphs import visualize_langgraph\n",
"\n",
"visualize_langgraph(app, xray=True)"
]
},
{
"cell_type": "markdown",
"id": "0c2bf2b1",
"metadata": {},
"source": [
"<br>\n",
"\n",
"## 🧪 Step 3. Execute the Graph\n",
"---\n",
"\n",
"### Execute the graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "215d94b4",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.runnables import RunnableConfig\n",
"from azure_genai_utils.messages import stream_graph, invoke_graph, random_uuid\n",
"\n",
"config = RunnableConfig(recursion_limit=10, configurable={\"thread_id\": random_uuid()})\n",
"inputs = {\n",
" \"messages\": [\n",
" HumanMessage(\n",
" content=\"Visualize Microsoft's stock price over the past 5 years on a graph.\"\n",
" )\n",
" ],\n",
"}\n",
"\n",
"invoke_graph(app, inputs, config)"
]
},
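{
"cell_type": "markdown",
"id": "2a7e4f96",
"metadata": {},
"source": [
"### Stream the graph (optional)\n",
"\n",
"`stream_graph` is imported above alongside `invoke_graph`. Assuming it shares the same `(app, inputs, config)` signature, you can stream intermediate agent outputs as in the sketch below; a fresh `thread_id` starts a new conversation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0b61c58",
"metadata": {},
"outputs": [],
"source": [
"# A minimal streaming sketch. Assumption: `stream_graph` mirrors invoke_graph's\n",
"# (app, inputs, config) signature from azure_genai_utils.messages.\n",
"config = RunnableConfig(recursion_limit=10, configurable={\"thread_id\": random_uuid()})\n",
"inputs = {\n",
"    \"messages\": [\n",
"        HumanMessage(\n",
"            content=\"Search for the current population of Seoul and compute its square root.\"\n",
"        )\n",
"    ],\n",
"}\n",
"\n",
"stream_graph(app, inputs, config)"
]
}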
],
"metadata": {
"kernelspec": {
"display_name": "venv_agent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}