components/llm_service/notebooks/RoutingAgent.ipynb (260 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "routing-agent-intro",
   "metadata": {},
   "source": [
    "# Routing Agent\n",
    "\n",
    "Scratch notebook for exercising the routing agent of the LLM service: build a LangChain `AgentExecutor` for a configured agent, run sample prompts through it, and parse a route line.\n",
    "\n",
    "The notebook is meant to run top to bottom on a fresh kernel (Restart Kernel → Run All)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61f9ba3c-8703-4b16-8af2-9d5d4d327b2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import sys\n",
    "\n",
    "# Extend the import path before the project imports in the next cell.\n",
    "sys.path.append(\"../../common/src\")\n",
    "sys.path.append(\"../src\")\n",
    "PROJECT_ID = os.getenv(\"PROJECT_ID\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8fa76111-f1f8-4789-b202-4ca10a69abfd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO: [config/config.py:55 - <module>()] Namespace File not found, setting job namespace as default\n",
      "INFO: [config/config.py:105 - <module>()] ENABLE_GOOGLE_LLM = True\n",
      "INFO: [config/config.py:106 - <module>()] ENABLE_OPENAI_LLM = True\n",
      "INFO: [config/config.py:107 - <module>()] ENABLE_COHERE_LLM = True\n",
      "INFO: [config/config.py:108 - <module>()] ENABLE_GOOGLE_MODEL_GARDEN = True\n",
      "INFO: [config/config.py:109 - <module>()] ENABLE_TRUSS_LLAMA2 = True\n",
      "INFO: [config/vector_store_config.py:40 - <module>()] Default vector store = [matching_engine]\n",
      "INFO: [config/vector_store_config.py:49 - <module>()] PG_HOST = [127.0.0.1]\n",
      "INFO: [config/vector_store_config.py:50 - <module>()] PG_DBNAME = [pgvector]\n",
      "ERROR: [config/vector_store_config.py:77 - <module>()] Cannot connect to pgvector instance at 127.0.0.1: (psycopg2.OperationalError) connection to server at \"127.0.0.1\", port 5432 failed: FATAL: database \"pgvector\" does not exist\n",
      "\n",
      "(Background on this error at: https://sqlalche.me/e/14/e3q8)\n",
      "INFO: [utils/text_helper.py:43 - <module>()] using default spacy model\n"
     ]
    }
   ],
   "source": [
    "from config.utils import set_agent_config\n",
    "from common.models import (User, UserChat, QueryResult,\n",
    "                           QueryEngine, UserPlan, PlanStep)\n",
    "from common.models.llm import CHAT_AI\n",
    "from common.models.agent import AgentCapability\n",
    "from common.utils.logging_handler import Logger\n",
    "from config import (get_model_config, PROVIDER_LANGCHAIN,\n",
    "                    OPENAI_LLM_TYPE_GPT4,\n",
    "                    VERTEX_LLM_TYPE_BISON_CHAT_LANGCHAIN,\n",
    "                    OPENAI_LLM_TYPE_GPT4_LATEST)\n",
    "\n",
    "from services.agents.routing_agent import run_intent, run_routing_agent\n",
    "from services.agents.agents import BaseAgent\n",
    "from langchain.agents import AgentExecutor\n",
    "from services.agents.routing_agent import get_dispatch_prompt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "agent-config-md",
   "metadata": {},
   "source": [
    "## Agent configuration\n",
    "\n",
    "**TODO:** `AGENT_CONFIG` is defined below, but no cell passes it to `set_agent_config` (imported above). Confirm whether the agents built later should use this mapping or the service's own configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ed5fdda3-f67e-4295-af3d-a8ad5296e522",
   "metadata": {},
   "outputs": [],
   "source": [
    "AGENT_CONFIG = {\n",
    "  \"Agents\":\n",
    "  {\n",
    "    \"Router\": {\n",
    "      \"llm_type\": OPENAI_LLM_TYPE_GPT4,\n",
    "      \"agent_type\": \"langchain_Conversational\",\n",
    "      \"agent_class\": \"RoutingAgent\",\n",
    "      \"tools\": \"\"\n",
    "    },\n",
    "    \"Chat\": {\n",
    "      \"llm_type\": OPENAI_LLM_TYPE_GPT4,\n",
    "      \"agent_type\": \"langchain_Conversational\",\n",
    "      \"agent_class\": \"ChatAgent\",\n",
    "      \"tools\": \"search_tool,query_tool\",\n",
    "      \"query_engines\": \"ALL\"\n",
    "    },\n",
    "    \"Task\": {\n",
    "      \"llm_type\": OPENAI_LLM_TYPE_GPT4_LATEST,\n",
    "      \"agent_type\": \"langchain_StructuredChatAgent\",\n",
    "      \"agent_class\": \"TaskAgent\",\n",
    "      \"tools\": \"ALL\"\n",
    "    },\n",
    "    \"Plan\": {\n",
    "      \"llm_type\": OPENAI_LLM_TYPE_GPT4_LATEST,\n",
    "      \"agent_type\": \"langchain_ZeroShot\",\n",
    "      \"agent_class\": \"PlanAgent\",\n",
    "      \"query_engines\": \"ALL\",\n",
    "      \"tools\": \"ALL\"\n",
    "    }\n",
    "  }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "agent-executor-md",
   "metadata": {},
   "source": [
    "## Build the agent executor\n",
    "\n",
    "Helper definitions live in one cell and the calls in the cells after it, so prompts can be re-run without redefining the helpers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c412498-b192-4795-93c2-f59cc4d2e9a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_agent_executor(agent_name: str):\n",
    "  \"\"\"Build a LangChain AgentExecutor for the named llm-service agent.\n",
    "\n",
    "  Args:\n",
    "    agent_name: name handed to BaseAgent.get_llm_service_agent.\n",
    "\n",
    "  Returns:\n",
    "    Tuple (llm_service_agent, agent_executor).\n",
    "  \"\"\"\n",
    "  llm_service_agent = BaseAgent.get_llm_service_agent(agent_name)\n",
    "\n",
    "  # load corresponding langchain agent and instantiate agent_executor\n",
    "  langchain_agent = llm_service_agent.load_langchain_agent()\n",
    "  agent_tools = llm_service_agent.get_tools()\n",
    "  print(f\"Routing agent tools [{agent_tools}]\")\n",
    "\n",
    "  agent_executor = AgentExecutor.from_agent_and_tools(\n",
    "      agent=langchain_agent, tools=agent_tools)\n",
    "  return llm_service_agent, agent_executor\n",
    "\n",
    "\n",
    "def run_prompt(agent_executor, prompt: str, prefix: str = \"\"):\n",
    "  \"\"\"Run one prompt through the executor with an empty chat history.\n",
    "\n",
    "  Args:\n",
    "    agent_executor: executor returned by build_agent_executor.\n",
    "    prompt: user prompt text.\n",
    "    prefix: text prepended to the prompt (empty by default).\n",
    "\n",
    "  Returns:\n",
    "    Whatever agent_executor.run returns for these inputs.\n",
    "  \"\"\"\n",
    "  agent_inputs = {\n",
    "    \"input\": prefix + prompt,\n",
    "    \"chat_history\": []\n",
    "  }\n",
    "  return agent_executor.run(agent_inputs)\n",
    "\n",
    "\n",
    "def run_intent_local(agent_name: str, prompt: str):\n",
    "  \"\"\"Run a prompt, prefixed with the agent's dispatch prompt, through a fresh executor.\n",
    "\n",
    "  Named run_intent_local so it does not shadow the run_intent imported\n",
    "  from services.agents.routing_agent. It is a plain (non-async) function\n",
    "  so the cell compiles and can be called without an event loop.\n",
    "\n",
    "  Args:\n",
    "    agent_name: name handed to BaseAgent.get_llm_service_agent.\n",
    "    prompt: user prompt text appended to the dispatch prompt.\n",
    "\n",
    "  Returns:\n",
    "    Whatever agent_executor.run returns for the combined input.\n",
    "  \"\"\"\n",
    "  llm_service_agent, agent_executor = build_agent_executor(agent_name)\n",
    "\n",
    "  # get dispatch prompt\n",
    "  dispatch_prompt = get_dispatch_prompt(llm_service_agent)\n",
    "\n",
    "  Logger.info(\"Running agent executor to get best matched route.... \")\n",
    "  return run_prompt(agent_executor, prompt, prefix=dispatch_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe4bcb17-8c43-4184-ae2c-150c9bc2d7e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same key as the \"Router\" entry in AGENT_CONFIG.\n",
    "# NOTE(review): assumes BaseAgent.get_llm_service_agent resolves this name - confirm.\n",
    "ROUTING_AGENT_NAME = \"Router\"\n",
    "\n",
    "# The prompt cells below rely on this module-level agent_executor.\n",
    "llm_service_agent, agent_executor = build_agent_executor(ROUTING_AGENT_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "parse-route-md",
   "metadata": {},
   "source": [
    "## Parse a route line\n",
    "\n",
    "`parse_plan_step` pulls the bracketed route name and the trailing detail out of a numbered step such as `1. [Chat] ...`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2072b76b-877c-43f0-8d52-50ad75c92622",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_plan_step(text: str) -> list:\n",
    "  \"\"\"Extract (route, detail) pairs from a numbered step.\n",
    "\n",
    "  The pattern expects \"<digits>. <anything>[<route>] <detail>\".\n",
    "  Both wildcards are greedy and re.DOTALL is set, so pass one step\n",
    "  at a time rather than a multi-step block of text.\n",
    "\n",
    "  Args:\n",
    "    text: a single step line.\n",
    "\n",
    "  Returns:\n",
    "    List of (route, detail) tuples as produced by re.findall;\n",
    "    empty when the text does not match.\n",
    "  \"\"\"\n",
    "  step_regex = re.compile(\n",
    "      r\"\\d+\\.\\s.*\\[(.*)\\]\\s?(.*)\", re.DOTALL)\n",
    "  matches = step_regex.findall(text)\n",
    "  return matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11bc3c44-5699-4fae-8d1f-a089a17c0b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written example in the shape the pattern expects.\n",
    "# NOTE(review): replace with the route lines taken from the router output\n",
    "# once that parsing step is added to this notebook.\n",
    "routes = [\"1. [Chat] Answer the question directly\"]\n",
    "\n",
    "route, detail = parse_plan_step(routes[0])[0]\n",
    "route"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sample-prompts-md",
   "metadata": {},
   "source": [
    "## Sample prompts\n",
    "\n",
    "Each cell sends one prompt (no dispatch prefix, empty chat history) through the executor built above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfbb59c5-dcc3-4636-bf2c-88534f5ba4f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"How do I apply for a driver's license?\"\n",
    "output = run_prompt(agent_executor, prompt)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9edfff6-d39c-4895-a0b3-e6c7d64464f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"What's the latest news in Jordan?\"\n",
    "output = run_prompt(agent_executor, prompt)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7019d7db-e4a3-47c5-bbd1-9c86e587b635",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"Compose and send an email to all the medicaid applicants that are missing income verification asking them to provide a pay stub from their employers\"\n",
    "output = run_prompt(agent_executor, prompt)\n",
    "print(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}