components/llm_service/notebooks/RoutingAgent.ipynb (260 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "61f9ba3c-8703-4b16-8af2-9d5d4d327b2d",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import sys\n",
"# Make the shared common sources and this service's sources importable.\n",
"sys.path.extend([\"../../common/src\", \"../src\"])\n",
"# PROJECT_ID is None when the environment variable is not set.\n",
"PROJECT_ID = os.environ.get(\"PROJECT_ID\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8fa76111-f1f8-4789-b202-4ca10a69abfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO: [config/config.py:55 - <module>()] Namespace File not found, setting job namespace as default\n",
"INFO: [config/config.py:105 - <module>()] ENABLE_GOOGLE_LLM = True\n",
"INFO: [config/config.py:106 - <module>()] ENABLE_OPENAI_LLM = True\n",
"INFO: [config/config.py:107 - <module>()] ENABLE_COHERE_LLM = True\n",
"INFO: [config/config.py:108 - <module>()] ENABLE_GOOGLE_MODEL_GARDEN = True\n",
"INFO: [config/config.py:109 - <module>()] ENABLE_TRUSS_LLAMA2 = True\n",
"INFO: [config/vector_store_config.py:40 - <module>()] Default vector store = [matching_engine]\n",
"INFO: [config/vector_store_config.py:49 - <module>()] PG_HOST = [127.0.0.1]\n",
"INFO: [config/vector_store_config.py:50 - <module>()] PG_DBNAME = [pgvector]\n",
"ERROR: [config/vector_store_config.py:77 - <module>()] Cannot connect to pgvector instance at 127.0.0.1: (psycopg2.OperationalError) connection to server at \"127.0.0.1\", port 5432 failed: FATAL: database \"pgvector\" does not exist\n",
"\n",
"(Background on this error at: https://sqlalche.me/e/14/e3q8)\n",
"INFO: [utils/text_helper.py:43 - <module>()] using default spacy model\n"
]
}
],
"source": [
"from config.utils import set_agent_config\n",
"from common.models import (User, UserChat, QueryResult,\n",
" QueryEngine, UserPlan, PlanStep)\n",
"from common.models.llm import CHAT_AI\n",
"from common.models.agent import AgentCapability\n",
"from common.utils.logging_handler import Logger\n",
"from config import (get_model_config, PROVIDER_LANGCHAIN,\n",
" OPENAI_LLM_TYPE_GPT4,\n",
" VERTEX_LLM_TYPE_BISON_CHAT_LANGCHAIN,\n",
" OPENAI_LLM_TYPE_GPT4_LATEST)\n",
"\n",
"from services.agents.routing_agent import run_intent, run_routing_agent\n",
"from services.agents.agents import BaseAgent\n",
"from langchain.agents import AgentExecutor\n",
"from services.agents.routing_agent import get_dispatch_prompt\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ed5fdda3-f67e-4295-af3d-a8ad5296e522",
"metadata": {},
"outputs": [],
"source": [
"# Per-agent settings keyed by agent name under the top-level \"Agents\" key.\n",
"# NOTE(review): no visible cell passes AGENT_CONFIG to set_agent_config\n",
"# (imported above) -- confirm how this configuration is meant to be applied.\n",
"AGENT_CONFIG = {\n",
"    \"Agents\": {\n",
"        \"Router\": {\n",
"            \"llm_type\": OPENAI_LLM_TYPE_GPT4,\n",
"            \"agent_type\": \"langchain_Conversational\",\n",
"            \"agent_class\": \"RoutingAgent\",\n",
"            \"tools\": \"\",\n",
"        },\n",
"        \"Chat\": {\n",
"            \"llm_type\": OPENAI_LLM_TYPE_GPT4,\n",
"            \"agent_type\": \"langchain_Conversational\",\n",
"            \"agent_class\": \"ChatAgent\",\n",
"            \"tools\": \"search_tool,query_tool\",\n",
"            \"query_engines\": \"ALL\",\n",
"        },\n",
"        \"Task\": {\n",
"            \"llm_type\": OPENAI_LLM_TYPE_GPT4_LATEST,\n",
"            \"agent_type\": \"langchain_StructuredChatAgent\",\n",
"            \"agent_class\": \"TaskAgent\",\n",
"            \"tools\": \"ALL\",\n",
"        },\n",
"        \"Plan\": {\n",
"            \"llm_type\": OPENAI_LLM_TYPE_GPT4_LATEST,\n",
"            \"agent_type\": \"langchain_ZeroShot\",\n",
"            \"agent_class\": \"PlanAgent\",\n",
"            \"query_engines\": \"ALL\",\n",
"            \"tools\": \"ALL\",\n",
"        },\n",
"    },\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c412498-b192-4795-93c2-f59cc4d2e9a7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"async def run_intent(agent_name, prompt=None):\n",
"    \"\"\"Run the agent named `agent_name` on the dispatch prompt plus a user prompt.\n",
"\n",
"    This notebook-local definition shadows the `run_intent` imported from\n",
"    `services.agents.routing_agent` in the imports cell. It is a coroutine:\n",
"    call it from a cell as `await run_intent(\"Router\", prompt)`.\n",
"\n",
"    Args:\n",
"        agent_name: name passed to `BaseAgent.get_llm_service_agent`.\n",
"        prompt: user prompt appended to the dispatch prompt. When omitted,\n",
"            the notebook-level `prompt` variable is used, which is what the\n",
"            original cell read implicitly.\n",
"\n",
"    Returns:\n",
"        The `(output, agent_logs)` pair returned by\n",
"        `agent_executor_arun_with_logs`.\n",
"    \"\"\"\n",
"    if prompt is None:\n",
"        # Fall back to the notebook global instead of silently depending on it.\n",
"        prompt = globals()[\"prompt\"]\n",
"\n",
"    llm_service_agent = BaseAgent.get_llm_service_agent(agent_name)\n",
"\n",
"    # load corresponding langchain agent and instantiate agent_executor\n",
"    langchain_agent = llm_service_agent.load_langchain_agent()\n",
"    intent_agent_tools = llm_service_agent.get_tools()\n",
"    print(f\"Routing agent tools [{intent_agent_tools}]\")\n",
"\n",
"    agent_executor = AgentExecutor.from_agent_and_tools(\n",
"        agent=langchain_agent, tools=intent_agent_tools)\n",
"\n",
"    # get dispatch prompt\n",
"    dispatch_prompt = get_dispatch_prompt(llm_service_agent)\n",
"\n",
"    agent_inputs = {\n",
"        \"input\": dispatch_prompt + prompt,\n",
"        \"chat_history\": []\n",
"    }\n",
"\n",
"    Logger.info(\"Running agent executor to get best matched route.... \")\n",
"    # NOTE(review): agent_executor_arun_with_logs is not imported anywhere in\n",
"    # this notebook -- import it from the module that defines it before\n",
"    # calling run_intent.\n",
"    output, agent_logs = await agent_executor_arun_with_logs(\n",
"        agent_executor, agent_inputs)\n",
"    # Hand the result back; the original computed it and then dropped it.\n",
"    return output, agent_logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe4bcb17-8c43-4184-ae2c-150c9bc2d7e4",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "11bc3c44-5699-4fae-8d1f-a089a17c0b09",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c9b0a53-5e82-4c04-93ac-50b7cebdaff5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad202d48-a1ce-486b-ae61-f21205644c3d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2072b76b-877c-43f0-8d52-50ad75c92622",
"metadata": {},
"outputs": [],
"source": [
"def parse_plan_step(text: str) -> list[tuple[str, str]]:\n",
"    \"\"\"Extract the bracketed route and the trailing detail from a numbered step.\n",
"\n",
"    Example: `\"1. Use [Chat] to answer\"` yields `[(\"Chat\", \"to answer\")]`.\n",
"\n",
"    Args:\n",
"        text: text containing a numbered step such as `1. Use [Route] detail`.\n",
"\n",
"    Returns:\n",
"        The list produced by `findall`: `(route, detail)` string tuples, or an\n",
"        empty list when the pattern does not match.\n",
"    \"\"\"\n",
"    # With DOTALL the final `(.*)` runs to the end of the text, so at most one\n",
"    # match is returned even for multi-line input.\n",
"    step_regex = re.compile(\n",
"        r\"\\d+\\.\\s.*\\[(.*)\\]\\s?(.*)\", re.DOTALL)\n",
"    return step_regex.findall(text)\n",
"\n",
"# NOTE(review): `routes` is not assigned in any cell of this notebook; it must\n",
"# already exist in the kernel. The trailing [0] raises IndexError when the\n",
"# first route has no match.\n",
"route, detail = parse_plan_step(routes[0])[0]\n",
"route"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfbb59c5-dcc3-4636-bf2c-88534f5ba4f6",
"metadata": {},
"outputs": [],
"source": [
"# Run the agent executor on a general-information question.\n",
"# NOTE(review): `agent_executor` is only assigned inside run_intent() above,\n",
"# never at notebook scope, so it must already exist in the kernel.\n",
"prompt = \"How do I apply for a driver's license?\"\n",
"agent_inputs = dict(input=prompt, chat_history=[])\n",
"output = agent_executor.run(agent_inputs)\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9edfff6-d39c-4895-a0b3-e6c7d64464f3",
"metadata": {},
"outputs": [],
"source": [
"# Run the agent executor on a current-events question.\n",
"# NOTE(review): `agent_executor` is only assigned inside run_intent() above,\n",
"# never at notebook scope, so it must already exist in the kernel.\n",
"prompt = \"What's the latest news in Jordan?\"\n",
"agent_inputs = dict(input=prompt, chat_history=[])\n",
"output = agent_executor.run(agent_inputs)\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7019d7db-e4a3-47c5-bbd1-9c86e587b635",
"metadata": {},
"outputs": [],
"source": [
"# Run the agent executor on a multi-step task request.\n",
"# NOTE(review): `agent_executor` is only assigned inside run_intent() above,\n",
"# never at notebook scope, so it must already exist in the kernel.\n",
"prompt = \"Compose and send an email to all the medicaid applicants that are missing income verification asking them to provide a pay stub from their employers\"\n",
"agent_inputs = dict(input=prompt, chat_history=[])\n",
"output = agent_executor.run(agent_inputs)\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16d2212e-d5d9-4350-924f-f871b3725cc4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}