{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "ur8xi4C7S06n" }, "outputs": [], "source": [ "# Copyright 2025 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "JAPoU8Sm5E6e" }, "source": [ "# Intro to Model Context Protocol (MCP) integration with Vertex AI\n", "\n", "<table align=\"left\">\n", " <td style=\"text-align: center\">\n", " <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\">\n", " <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fmcp%2Fintro_to_mcp.ipynb\">\n", " <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n", " </a>\n", " </td>\n", " <td style=\"text-align: center\">\n", " <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/mcp/intro_to_mcp.ipynb\">\n", " <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n", " </a>\n", " </td>\n", " \n", " \n", " <td style=\"text-align: center\">\n", " <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\">\n", " <img width=\"32px\" src=\"https://www.svgrepo.com/download/217753/github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n", " </a>\n", " </td>\n", "</table>\n", "\n", "<div style=\"clear: both;\"></div>\n", "\n", "<b>Share to:</b>\n", "\n", "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n", "</a>\n", "\n", "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n", "</a>\n", "\n", "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n", "</a>\n", 
"\n", "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n", "</a>\n", "\n", "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/mcp/intro_to_mcp.ipynb\" target=\"_blank\">\n", " <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n", "</a>" ] }, { "cell_type": "markdown", "metadata": { "id": "84f0f73a0f76" }, "source": [ "| Author |\n", "| --- |\n", "| [Dave Wang](https://github.com/wadave) |" ] }, { "cell_type": "markdown", "metadata": { "id": "tvgnzT1CKxrO" }, "source": [ "## Overview\n", "The Model Context Protocol (MCP) is an open standard that streamlines the integration of AI assistants with external data sources, tools, and systems. [MCP standardizes how applications provide context to LLMs](https://modelcontextprotocol.io/introduction). MCP establishes the essential standardized interface allowing AI models to connect directly with diverse external systems and services.\n", "\n", "Developers have the option to use third-party MCP servers or create custom ones when building applications. \n", "\n", "\n", "This notebook shows two ways to use MCP with Vertex AI\n", "- Build a custom MCP server, and use it with Gemini on Vertex AI\n", "- Use pre-built MCP server with Vertex AI" ] }, { "cell_type": "markdown", "metadata": { "id": "61RBz8LLbxCR" }, "source": [ "## Get started" ] }, { "cell_type": "markdown", "metadata": { "id": "No17Cw5hgx12" }, "source": [ "### Install Google Gen AI SDK and other required packages\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tFy3H3aPgx12" }, "outputs": [], "source": [ "%pip install --upgrade --quiet google-genai mcp geopy uv" ] }, { "cell_type": "markdown", "metadata": { "id": "dmWOrTJ3gx13" }, "source": [ "### Authenticate your notebook environment (Colab only)\n", "\n", "If you're running this notebook on Google Colab, run the cell below to authenticate your environment." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NyKGtVQjgx13" }, "outputs": [], "source": [ "import sys\n", "\n", "if \"google.colab\" in sys.modules:\n", " from google.colab import auth\n", "\n", " auth.authenticate_user()" ] }, { "cell_type": "markdown", "metadata": { "id": "DF4l8DTdWgPY" }, "source": [ "### Set Google Cloud project information\n", "\n", "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n", "\n", "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Nqwi-5ufWp_B" }, "outputs": [], "source": [ "# Use the environment variable if the user doesn't provide Project ID.\n", "import os\n", "\n", "from google import genai\n", "\n", "# TODO set up your own project id\n", "PROJECT_ID = \"[your-project-id]\" # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n", "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n", " PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n", "\n", "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n", "\n", "client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)" ] }, { "cell_type": "markdown", "metadata": { "id": "5303c05f7aa6" }, "source": [ "### Import libraries" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "6fc324893334" }, "outputs": [], "source": [ "from typing import Any\n", "\n", "from google import genai\n", "from google.genai import types\n", "from mcp import ClientSession, StdioServerParameters\n", "from mcp.client.stdio import stdio_client" ] }, { "cell_type": "markdown", "metadata": { "id": "e43229f3ad4f" }, "source": [ "### Load model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "cf93d5f0ce00" }, "outputs": [], "source": [ "MODEL_ID = \"gemini-2.0-flash-001\" # @param {type:\"string\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "1e7b40e87f22" }, "source": [ "### Create an MCP weather server\n", "The [Server development guide](https://modelcontextprotocol.io/quickstart/server) shows the details of creation of an MCP Server.\n", "\n", "Here we modify the server sample to include three tools:\n", "\n", "- Get weather alert by state\n", "- Get forecast by coordinates\n", "- Get forecast by city name" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "199464a1c1d4" }, "outputs": [], "source": [ "%%writefile server/weather_server.py\n", "import json\n", "from typing import Any, Dict, Optional\n", "import httpx\n", "from mcp.server.fastmcp import FastMCP\n", "from geopy.geocoders import Nominatim\n", "from geopy.exc import GeocoderTimedOut, GeocoderServiceError\n", "\n", "# Initialize FastMCP server\n", "mcp = FastMCP(\"weather\")\n", "\n", "# --- Configuration & Constants ---\n", "BASE_URL = \"https://api.weather.gov\"\n", "USER_AGENT = \"weather-agent\"\n", "REQUEST_TIMEOUT = 20.0\n", "GEOCODE_TIMEOUT = 10.0 # Timeout for geocoding requests\n", "\n", "# --- Shared HTTP Client ---\n", "http_client = httpx.AsyncClient(\n", " base_url=BASE_URL,\n", " headers={\"User-Agent\": USER_AGENT, \"Accept\": \"application/geo+json\"},\n", " timeout=REQUEST_TIMEOUT,\n", " follow_redirects=True,\n", ")\n", "\n", "# --- Geocoding Setup ---\n", "# Initialize the geocoder (Nominatim requires a unique user_agent)\n", "geolocator = Nominatim(user_agent=USER_AGENT)\n", "\n", "\n", "async def get_weather_response(endpoint: str) -> Optional[Dict[str, Any]]:\n", " \"\"\"\n", " Make a request to the NWS API using the shared client with error handling.\n", " Returns None if an error occurs.\n", " \"\"\"\n", " try:\n", " response = await http_client.get(endpoint)\n", " response.raise_for_status() # Raises HTTPStatusError for 4xx/5xx responses\n", " return response.json()\n", " except httpx.HTTPStatusError:\n", " # Specific HTTP errors (like 404 Not Found, 500 Server Error)\n", " return None\n", " except httpx.TimeoutException:\n", " # Request timed out\n", " return None\n", " except 
httpx.RequestError:\n", " # Other request errors (connection, DNS, etc.)\n", " return None\n", " except json.JSONDecodeError:\n", " # Response was not valid JSON\n", " return None\n", " except Exception:\n", " # Any other unexpected errors\n", " return None\n", "\n", "\n", "def format_alert(feature: Dict[str, Any]) -> str:\n", " \"\"\"Format an alert feature into a readable string.\"\"\"\n", " props = feature.get(\"properties\", {}) # Safer access\n", " # Use .get() with default values for robustness\n", " return f\"\"\"\n", " Event: {props.get('event', 'Unknown Event')}\n", " Area: {props.get('areaDesc', 'N/A')}\n", " Severity: {props.get('severity', 'N/A')}\n", " Certainty: {props.get('certainty', 'N/A')}\n", " Urgency: {props.get('urgency', 'N/A')}\n", " Effective: {props.get('effective', 'N/A')}\n", " Expires: {props.get('expires', 'N/A')}\n", " Description: {props.get('description', 'No description provided.').strip()}\n", " Instructions: {props.get('instruction', 'No instructions provided.').strip()}\n", " \"\"\"\n", "\n", "\n", "def format_forecast_period(period: Dict[str, Any]) -> str:\n", " \"\"\"Formats a single forecast period into a readable string.\"\"\"\n", " return f\"\"\"\n", " {period.get('name', 'Unknown Period')}:\n", " Temperature: {period.get('temperature', 'N/A')}°{period.get ('temperatureUnit', 'F')}\n", " Wind: {period.get('windSpeed', 'N/A')} {period.get('windDirection', 'N/A')}\n", " Short Forecast: {period.get('shortForecast', 'N/A')}\n", " Detailed Forecast: {period.get('detailedForecast', 'No detailed forecast provided.').strip()}\n", " \"\"\"\n", "\n", "\n", "# --- MCP Tools ---\n", "\n", "@mcp.tool()\n", "async def get_alerts(state: str) -> str:\n", " \"\"\"\n", " Get active weather alerts for a specific US state.\n", "\n", " Args:\n", " state: The two-letter US state code (e.g., CA, NY, TX). Case-insensitive.\n", " \"\"\"\n", " # Input validation and normalization\n", " if not isinstance(state, str) or len(state) != 2 or not state.isalpha():\n", " return \"Invalid input. Please provide a two-letter US state code (e.g., CA).\"\n", " state_code = state.upper()\n", "\n", " endpoint = f\"/alerts/active/area/{state_code}\"\n", " data = await get_weather_response(endpoint)\n", "\n", " if data is None:\n", " # Error occurred during request\n", " return f\"Failed to retrieve weather alerts for {state_code}.\"\n", "\n", " features = data.get(\"features\")\n", " if not features: # Handles both null and empty list\n", " return f\"No active weather alerts found for {state_code}.\"\n", "\n", " alerts = [format_alert(feature) for feature in features]\n", " return \"\\n---\\n\".join(alerts)\n", "\n", "\n", "@mcp.tool()\n", "async def get_forecast(latitude: float, longitude: float) -> str:\n", " \"\"\"\n", " Get the weather forecast for a specific location using latitude and longitude.\n", "\n", " Args:\n", " latitude: The latitude of the location (e.g., 34.05).\n", " longitude: The longitude of the location (e.g., -118.25).\n", " \"\"\"\n", " # Input validation\n", " if not (-90 <= latitude <= 90 and -180 <= longitude <= 180):\n", " return \"Invalid latitude or longitude provided. 
Latitude must be between -90 and 90, Longitude between -180 and 180.\"\n", "\n", " # NWS API requires latitude,longitude format with up to 4 decimal places\n", " point_endpoint = f\"/points/{latitude:.4f},{longitude:.4f}\"\n", " points_data = await get_weather_response(point_endpoint)\n", "\n", " if points_data is None or \"properties\" not in points_data:\n", " return f\"Unable to retrieve NWS gridpoint information for {latitude:.4f},{longitude:.4f}.\"\n", "\n", " # Extract forecast URLs from the gridpoint data\n", " forecast_url = points_data[\"properties\"].get(\"forecast\")\n", "\n", " if not forecast_url:\n", " return f\"Could not find the NWS forecast endpoint for {latitude:.4f},{longitude:.4f}.\"\n", "\n", " # Make the request to the specific forecast URL\n", " forecast_data = None\n", " try:\n", " response = await http_client.get(forecast_url)\n", " response.raise_for_status()\n", " forecast_data = response.json()\n", " except httpx.HTTPStatusError:\n", " pass # Error handled by returning None below\n", " except httpx.RequestError:\n", " pass # Error handled by returning None below\n", " except json.JSONDecodeError:\n", " pass # Error handled by returning None below\n", " except Exception:\n", " pass # Error handled by returning None below\n", "\n", " if forecast_data is None or \"properties\" not in forecast_data:\n", " return \"Failed to retrieve detailed forecast data from NWS.\"\n", "\n", " periods = forecast_data[\"properties\"].get(\"periods\")\n", " if not periods:\n", " return \"No forecast periods found for this location from NWS.\"\n", "\n", " # Format the first 5 periods\n", " forecasts = [format_forecast_period(period) for period in periods[:5]]\n", "\n", " return \"\\n---\\n\".join(forecasts)\n", "\n", "# --- NEW: get_forecast_by_city Tool ---\n", "@mcp.tool()\n", "async def get_forecast_by_city(city: str, state: str) -> str:\n", " \"\"\"\n", " Get the weather forecast for a specific US city and state by first finding its coordinates.\n", "\n", " Args:\n", " city: The name of the city (e.g., \"Los Angeles\", \"New York\").\n", " state: The two-letter US state code (e.g., CA, NY). Case-insensitive.\n", " \"\"\"\n", " # --- Input Validation ---\n", " if not city or not isinstance(city, str):\n", " return \"Invalid city name provided.\"\n", " if (\n", " not state\n", " or not isinstance(state, str)\n", " or len(state) != 2\n", " or not state.isalpha()\n", " ):\n", " return \"Invalid state code. Please provide the two-letter US state abbreviation (e.g., CA).\"\n", "\n", " city_name = city.strip()\n", " state_code = state.strip().upper()\n", " # Construct a query likely to yield a US result\n", " query = f\"{city_name}, {state_code}, USA\"\n", "\n", " # --- Geocoding ---\n", " location = None\n", " try:\n", " # Synchronous geocode call\n", " location = geolocator.geocode(query, timeout=GEOCODE_TIMEOUT)\n", "\n", " except GeocoderTimedOut:\n", " return f\"Could not get coordinates for '{city_name}, {state_code}': The location service timed out.\"\n", " except GeocoderServiceError:\n", " return f\"Could not get coordinates for '{city_name}, {state_code}': The location service returned an error.\"\n", " except Exception:\n", " # Catch any other unexpected errors during geocoding\n", " return f\"An unexpected error occurred while finding coordinates for '{city_name}, {state_code}'.\"\n", "\n", " # --- Handle Geocoding Result ---\n", " if location is None:\n", " return f\"Could not find coordinates for '{city_name}, {state_code}'. 
Please check the spelling or try a nearby city.\"\n", "\n", " latitude = location.latitude\n", " longitude = location.longitude\n", "\n", " # --- Reuse existing forecast logic with obtained coordinates ---\n", " return await get_forecast(latitude, longitude)\n", "\n", "\n", "# --- Server Execution & Shutdown ---\n", "async def shutdown_event():\n", " \"\"\"Gracefully close the httpx client.\"\"\"\n", " await http_client.aclose()\n", " # print(\"HTTP client closed.\") # Optional print statement if desired\n", "\n", "if __name__ == \"__main__\":\n", " mcp.run(transport=\"stdio\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4ebb6082d936" }, "source": [ "### Gemini agent loop\n", "\n", "Within an MCP client session, this agent loop runs a multi-turn conversation with a Gemini model, handling tool calls through the MCP server.\n", "\n", "This function orchestrates the interaction between a user prompt, a Gemini model capable of function calling, and a session object that provides and executes tools. It handles the cycle of:\n", "- Gemini gets the available tool declarations from the MCP client session.\n", "- The user prompt (and conversation history) is sent to the model.\n", "- If the model requests tool calls, it returns structured function calls that match the tool schemas, and the agent executes them via the MCP server.\n", "- The tool execution results are sent back to the model.\n", "- This repeats until the model returns a text response or the maximum number of tool execution turns is reached.\n", "- Gemini then generates the final response based on the tool responses and the original query.\n", " \n", "MCP integration with Gemini\n", "\n", "<img src=\"https://storage.googleapis.com/github-repo/generative-ai/gemini/mcp/mcp_tool_call.png\" alt=\"MCP with Gemini\" height=\"700\">" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ae7abde705ca" }, "outputs": [], "source": [ "# --- Configuration ---\n", "DEFAULT_MAX_TOOL_TURNS = 5 # Maximum consecutive turns for tool execution\n", "DEFAULT_INITIAL_TEMPERATURE = (\n", " 0.0 # Temperature for the first LLM call (more deterministic)\n", ")\n", "DEFAULT_TOOL_CALL_TEMPERATURE = (\n", " 1.0 # Temperature for LLM calls after tool use (potentially more creative)\n", ")\n", "\n", "\n", "# Make tool calls via MCP Server\n", "async def _execute_tool_calls(\n", " function_calls: list[types.FunctionCall], session: ClientSession\n", ") -> list[types.Part]:\n", " \"\"\"\n", " Executes a list of function calls requested by the Gemini model via the session.\n", "\n", " Args:\n", " function_calls: A list of FunctionCall objects from the model's response.\n", " session: The session object capable of executing tools via `call_tool`.\n", "\n", " Returns:\n", " A list of Part objects, each containing a FunctionResponse corresponding\n", " to the execution result of a requested tool call.\n", " \"\"\"\n", " tool_response_parts: list[types.Part] = []\n", " print(f\"--- Executing {len(function_calls)} tool call(s) ---\")\n", "\n", " for func_call in function_calls:\n", " tool_name = func_call.name\n", " # Ensure args is a dictionary, even if missing or not a dict type\n", " args = func_call.args if isinstance(func_call.args, dict) else {}\n", " print(f\" Attempting to call session tool: '{tool_name}' with args: {args}\")\n", "\n", " tool_result_payload: dict[str, Any]\n", " try:\n", " # Execute the tool using the provided session object\n", " # Assumes session.call_tool returns an object with attributes\n", " # like `isError` 
(bool) and `content` (list of Part-like objects).\n", " tool_result = await session.call_tool(tool_name, args)\n", " print(f\" Session tool '{tool_name}' execution finished.\")\n", "\n", " # Extract result or error message from the tool result object\n", " result_text = \"\"\n", " # Check structure carefully based on actual `session.call_tool` return type\n", " if (\n", " hasattr(tool_result, \"content\")\n", " and tool_result.content\n", " and hasattr(tool_result.content[0], \"text\")\n", " ):\n", " result_text = tool_result.content[0].text or \"\"\n", "\n", " if hasattr(tool_result, \"isError\") and tool_result.isError:\n", " error_message = (\n", " result_text\n", " or f\"Tool '{tool_name}' failed without specific error message.\"\n", " )\n", " print(f\" Tool '{tool_name}' reported an error: {error_message}\")\n", " tool_result_payload = {\"error\": error_message}\n", " else:\n", " print(\n", " f\" Tool '{tool_name}' succeeded. Result snippet: {result_text[:150]}...\"\n", " ) # Log snippet\n", " tool_result_payload = {\"result\": result_text}\n", "\n", " except Exception as e:\n", " # Catch exceptions during the tool call itself\n", " error_message = f\"Tool execution framework failed: {type(e).__name__}: {e}\"\n", " print(f\" Error executing tool '{tool_name}': {error_message}\")\n", " tool_result_payload = {\"error\": error_message}\n", "\n", " # Create a FunctionResponse Part to send back to the model\n", " tool_response_parts.append(\n", " types.Part.from_function_response(\n", " name=tool_name, response=tool_result_payload\n", " )\n", " )\n", " print(f\"--- Finished executing tool call(s) ---\")\n", " return tool_response_parts\n", "\n", "\n", "async def run_agent_loop(\n", " prompt: str,\n", " client: genai.Client,\n", " session: ClientSession,\n", " model_id: str = MODEL_ID,\n", " max_tool_turns: int = DEFAULT_MAX_TOOL_TURNS,\n", " initial_temperature: float = DEFAULT_INITIAL_TEMPERATURE,\n", " tool_call_temperature: float = DEFAULT_TOOL_CALL_TEMPERATURE,\n", ") -> types.GenerateContentResponse:\n", " \"\"\"\n", " Runs a multi-turn conversation loop with a Gemini model, handling tool calls.\n", "\n", " This function orchestrates the interaction between a user prompt, a Gemini\n", " model capable of function calling, and a session object that provides\n", " and executes tools. It handles the cycle of:\n", " 1. Sending the user prompt (and conversation history) to the model.\n", " 2. If the model requests tool calls, executing them via the `session`.\n", " 3. Sending the tool execution results back to the model.\n", " 4. 
Repeating until the model provides a text response or the maximum\n", " number of tool execution turns is reached.\n", "\n", " Args:\n", " prompt: The initial user prompt to start the conversation.\n", " client: An initialized Gemini GenerativeModel client object\n", "\n", " session: An active session object responsible for listing available tools\n", " via `list_tools()` and executing them via `call_tool(tool_name, args)`.\n", " It's also expected to have an `initialize()` method.\n", " model_id: The identifier of the Gemini model to use (e.g., \"gemini-2.0-flash\").\n", " max_tool_turns: The maximum number of consecutive turns dedicated to tool calls\n", " before forcing a final response or exiting.\n", " initial_temperature: The temperature setting for the first model call.\n", " tool_call_temperature: The temperature setting for subsequent model calls\n", " that occur after tool execution.\n", "\n", " Returns:\n", " The final Response from the Gemini model after the\n", " conversation loop concludes (either with a text response or after\n", " reaching the max tool turns).\n", "\n", " Raises:\n", " ValueError: If the session object does not provide any tools.\n", " Exception: Can potentially raise exceptions from the underlying API calls\n", " or session tool execution if not caught internally by `_execute_tool_calls`.\n", " \"\"\"\n", " print(\n", " f\"Starting agent loop with model '{model_id}' and prompt: '{prompt[:100]}...'\"\n", " )\n", "\n", " # Initialize conversation history with the user's prompt\n", " contents: list[types.Content] = [\n", " types.Content(role=\"user\", parts=[types.Part(text=prompt)])\n", " ]\n", "\n", " # Ensure the session is ready (if needed)\n", " if hasattr(session, \"initialize\") and callable(session.initialize):\n", " print(\"Initializing session...\")\n", " await session.initialize()\n", " else:\n", " print(\"Session object does not have an initialize() method, proceeding anyway.\")\n", "\n", " # --- 1. Discover Tools from Session ---\n", " print(\"Listing tools from session...\")\n", " # Assumes session.list_tools() returns an object with a 'tools' attribute (list)\n", " # Each item in the list should have 'name', 'description', and 'inputSchema' attributes.\n", " session_tool_list = await session.list_tools()\n", "\n", " if not session_tool_list or not session_tool_list.tools:\n", " raise ValueError(\"No tools provided by the session. Agent loop cannot proceed.\")\n", "\n", " # Convert session tools to the format required by the Gemini API\n", " gemini_tool_config = types.Tool(\n", " function_declarations=[\n", " types.FunctionDeclaration(\n", " name=tool.name,\n", " description=tool.description,\n", " parameters=tool.inputSchema, # Assumes inputSchema is compatible\n", " )\n", " for tool in session_tool_list.tools\n", " ]\n", " )\n", " print(\n", " f\"Configured Gemini with {len(gemini_tool_config.function_declarations)} tool(s).\"\n", " )\n", "\n", " # --- 2. 
Initial Model Call ---\n", " print(\"Making initial call to Gemini model...\")\n", " current_temperature = initial_temperature\n", " response = await client.aio.models.generate_content(\n", " model=model_id,\n", " contents=contents, # Send initial history\n", " config=types.GenerateContentConfig(\n", " temperature=current_temperature,\n", " tools=[gemini_tool_config],\n", " ), # Initial call config with the MCP tools\n", " )\n", " print(\"Initial response received.\")\n", "\n", " # Append the model's first response (potentially including function calls) to history\n", " # Need to handle potential lack of candidates or content\n", " if not response.candidates:\n", " print(\"Warning: Initial model response has no candidates.\")\n", " # Decide how to handle this - raise error or return the empty response?\n", " return response\n", " contents.append(response.candidates[0].content)\n", "\n", " # --- 3. Tool Calling Loop ---\n", " turn_count = 0\n", " # Check specifically for FunctionCall objects in the latest response part\n", " latest_content = response.candidates[0].content\n", " has_function_calls = any(part.function_call for part in latest_content.parts)\n", "\n", " while has_function_calls and turn_count < max_tool_turns:\n", " turn_count += 1\n", " print(f\"\\n--- Tool Turn {turn_count}/{max_tool_turns} ---\")\n", "\n", " # --- 3.1 Execute Pending Function Calls ---\n", " function_calls_to_execute = [\n", " part.function_call for part in latest_content.parts if part.function_call\n", " ]\n", " tool_response_parts = await _execute_tool_calls(\n", " function_calls_to_execute, session\n", " )\n", "\n", " # --- 3.2 Add Tool Responses to History ---\n", " # Send back the results for *all* function calls from the previous turn\n", " contents.append(\n", " types.Content(role=\"function\", parts=tool_response_parts)\n", " ) # Use \"function\" role\n", " print(f\"Added {len(tool_response_parts)} tool response part(s) to history.\")\n", "\n", " # --- 3.3 Make Subsequent Model Call with Tool Responses ---\n", " print(\"Making subsequent API call to Gemini with tool responses...\")\n", " current_temperature = tool_call_temperature # Use different temp for follow-up\n", " response = await client.aio.models.generate_content(\n", " model=model_id,\n", " contents=contents, # Send updated history\n", " config=types.GenerateContentConfig(\n", " temperature=current_temperature,\n", " tools=[gemini_tool_config],\n", " ),\n", " )\n", " print(\"Subsequent response received.\")\n", "\n", " # --- 3.4 Append latest model response and check for more calls ---\n", " if not response.candidates:\n", " print(\"Warning: Subsequent model response has no candidates.\")\n", " break # Exit loop if no candidates are returned\n", " latest_content = response.candidates[0].content\n", " contents.append(latest_content)\n", " has_function_calls = any(part.function_call for part in latest_content.parts)\n", " if not has_function_calls:\n", " print(\n", " \"Model response contains text, no further tool calls requested this turn.\"\n", " )\n", "\n", " # --- 4. Loop Termination Check ---\n", " if turn_count >= max_tool_turns and has_function_calls:\n", " print(\n", " f\"Maximum tool turns ({max_tool_turns}) reached. Exiting loop even though function calls might be pending.\"\n", " )\n", " elif not has_function_calls:\n", " print(\"Tool calling loop finished naturally (model provided text response).\")\n", "\n", " # --- 5. Return Final Response ---\n", " print(\"Agent loop finished. 
Returning final response.\")\n", " return response" ] }, { "cell_type": "markdown", "metadata": { "id": "41a2e0fc0dfa" }, "source": [ "## 1. Use your own MCP server\n", "### Start an MCP client session with the custom MCP server and the Gemini agent loop" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "b556fbb2be66" }, "outputs": [], "source": [ "# Create server parameters for stdio connection\n", "weather_server_params = StdioServerParameters(\n", " command=\"python\",\n", " # Update this path if your weather_server.py file is located elsewhere\n", " args=[\"./server/weather_server.py\"],\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "bd8512e5a8c2" }, "outputs": [], "source": [ "async def run():\n", " async with stdio_client(weather_server_params) as (read, write):\n", " async with ClientSession(\n", " read,\n", " write,\n", " ) as session:\n", " # Test prompt\n", " prompt = \"Tell me about weather in LA, CA\"\n", " print(f\"Running agent loop with prompt: {prompt}\")\n", " # Run agent loop\n", " res = await run_agent_loop(prompt, client, session)\n", " return res\n", "\n", "\n", "res = await run()\n", "print(res.text)" ] }, { "cell_type": "markdown", "metadata": { "id": "f07ab426ca0c" }, "source": [ "## 2. Use a pre-built MCP server\n", "\n", "There are [pre-built MCP servers](https://github.com/modelcontextprotocol/servers?tab=readme-ov-file) available for use.\n", "\n", "Here we use the [BigQuery MCP server](https://github.com/LucasHild/mcp-server-bigquery) as an example.\n", "\n", "It has three tools:\n", "\n", "- `execute-query`: Executes a SQL query using BigQuery\n", "- `list-tables`: Lists all tables in the BigQuery database\n", "- `describe-table`: Describes the schema of a specific table\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "e3f1d09b69ac" }, "outputs": [], "source": [ "# Create server parameters for stdio connection\n", "bq_server_params = StdioServerParameters(\n", " command=\"uvx\", # Executable\n", " args=[\"mcp-server-bigquery\", \"--project\", PROJECT_ID, \"--location\", LOCATION],\n", " env=None, # Optional environment variables\n", ")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "46fd766cd2f3" }, "outputs": [], "source": [ "async def run():\n", " async with stdio_client(bq_server_params) as (read, write):\n", " async with ClientSession(\n", " read,\n", " write,\n", " ) as session:\n", " # Test prompt\n", " prompt = \"Please list my BigQuery tables\"\n", " print(f\"Running agent loop with prompt: {prompt}\")\n", " # Run agent loop\n", " res = await run_agent_loop(prompt, client, session)\n", " return res\n", "\n", "\n", "res = await run()\n", "print(res.text)" ] }, { "cell_type": "markdown", "metadata": { "id": "3b513a8c9470" }, "source": [ "References:\n", "- https://modelcontextprotocol.io/introduction\n", "- https://github.com/philschmid/gemini-samples/blob/main/examples/gemini-mcp-example.ipynb\n", "- https://github.com/modelcontextprotocol/python-sdk\n" ] } ], "metadata": { "colab": { "name": "intro_to_mcp.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }