In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Running a Gemma 2-based agentic RAG with Ollama on Vertex AI and LangGraph

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fopen-models%2Fserving%2Fvertex_ai_ollama_gemma2_rag_agent.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb">
      <img width="32px" src="https://www.svgrepo.com/download/217753/github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>

| | |
|-|-|
| Author(s) |  [Ivan Nardini](https://github.com/inardini) |

## Overview

> [**Gemma 2**](https://ai.google.dev/gemma) is a new generation of open models developed by Google. It offers pre-trained and instruction-tuned variants in three sizes (2B, 9B and 27B parameters), designed for high performance and efficiency on various hardware.  Gemma 2 models are available through platforms like Google AI Studio, Kaggle, and Hugging Face.

> [**Ollama**](https://github.com/ollama/ollama) is a tool for running open-source large language models (LLMs) locally.  It simplifies LLM usage by bundling model weights, configurations, and datasets into a single package managed by a [`Modelfile`](https://github.com/ollama/ollama/blob/main/docs/modelfile.md). Ollama supports various models like LLaMA-2, Mistral, and CodeLLaMA, and is compatible with macOS and Linux.

> [**LangGraph**](https://python.langchain.com/en/latest/modules/graphs/langgraph.html) is a framework developed by LangChain for building applications with complex workflows, including agents and multi-agent systems. It offers precise control over application flow and state, supporting cyclical graphs and advanced state management.  LangGraph enhances LangChain's capabilities, providing more flexibility for agentic applications.

> [**Google Vertex AI**](https://cloud.google.com/vertex-ai) is Google Cloud's unified machine learning (ML) platform.  It provides a comprehensive suite of tools for building, training, deploying, and managing ML models and AI applications, including large language models (LLMs). Vertex AI streamlines the entire ML workflow, from data management to prediction, and supports customization for specific business needs.

This notebook showcases how to run a Gemma 2-based Agent with Ollama on Vertex AI and LangGraph.

By the end of this notebook, you will learn how to:

- Deploy Google Gemma 2 on Vertex AI using Ollama
- Learn how to test the container using Vertex AI LocalModel class
- Implement a simple RAG agent application with Gemma 2 and Ollama using LangGraph

## Get started

### Install Vertex AI SDK and other required packages


In [None]:
%pip install --upgrade --user --quiet "huggingface_hub" \
    "google-cloud-aiplatform[prediction]" \
    "torch" \
    "etils" \
    "crcmod"

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. In Colab or Colab Enterprise, you might see an error message that says "Your session crashed for an unknown reason." This is expected. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [None]:
# import sys

# if "google.colab" in sys.modules:
#     from google.colab import auth

#     auth.authenticate_user()

### Requirements

You will need to have the following IAM roles set:

- Artifact Registry Administrator (`roles/artifactregistry.admin`)
- Cloud Build Editor (`roles/cloudbuild.builds.editor`)
- Vertex AI User (`roles/aiplatform.user`)
- Service Account User (`roles/iam.serviceAccountUser`)
- Service Usage Consumer (`roles/serviceusage.serviceUsageConsumer`)
- Storage Admin (`roles/storage.admin`)

For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).

---

You will also need to enable the following APIs (if not enabled already):

- Artifact Registry API (artifactregistry.googleapis.com)
- Vertex AI API (aiplatform.googleapis.com)
- Compute Engine API (compute.googleapis.com)

For more information about API enablement, see [Enabling APIs](https://cloud.google.com/apis/docs/getting-started#enabling_apis).

### Authenticate your Hugging Face account

As [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it) is a gated model, you need to have a Hugging Face Hub account and accept Google's usage license for Gemma. Once that's done, you need to generate a new user access token with read-only access so that the weights can be downloaded from the Hub.

> Note that the user access token can only be generated via [the Hugging Face Hub UI](https://huggingface.co/settings/tokens/new), where you can either select read-only access to your account, or follow the recommendations and generate a fine-grained token with read-only access to [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it).

Then install the `huggingface_hub` package, which provides the login helper used to authenticate with the token generated in advance, so that the token can later be safely retrieved via `huggingface_hub.get_token`.

In [None]:
from huggingface_hub import interpreter_login

interpreter_login()

Read more about [Hugging Face Security](https://huggingface.co/docs/hub/en/security), specifically about [Hugging Face User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens).

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
# Use the environment variable if the user doesn't provide Project ID.
import os

import vertexai

PROJECT_ID = "[your-project-id]"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}

if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    PROJECT_ID = str(os.environ.get("GOOGLE_CLOUD_PROJECT"))

LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

BUCKET_NAME = "[your-bucket-name]"  # @param {type: "string", placeholder: "[your-bucket-name]", isTemplate: true}

if not BUCKET_NAME or BUCKET_NAME == "[your-bucket-name]":
    BUCKET_NAME = f"{PROJECT_ID}-bucket"

BUCKET_URI = f"gs://{BUCKET_NAME}"

! gsutil mb -p $PROJECT_ID -l $LOCATION $BUCKET_URI

vertexai.init(project=PROJECT_ID, location=LOCATION, staging_bucket=BUCKET_URI)

### Set tutorial folder

Define a folder for the tutorial.

In [None]:
from etils import epath

TUTORIAL_DIR = epath.Path("ollama_on_vertex_ai_tutorial")
BUILD_DIR = TUTORIAL_DIR / "build"
MODELS_DIR = BUILD_DIR / "ollama_models"

MODELS_DIR.mkdir(exist_ok=True, parents=True)

### Import libraries

Import main libraries.

In [None]:
import gc
import json

from google.cloud import aiplatform
from google.cloud.aiplatform import Endpoint, Model
from google.cloud.aiplatform.prediction import LocalModel
from huggingface_hub import get_token, snapshot_download
import torch

### Libraries settings

In [None]:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

### Helpers

Define some helpers.

In [None]:
def get_cuda_device_names():
    """Return the indices of the available CUDA devices as strings.

    Returns:
        A list such as ``["0", "1"]`` with one entry per device counted by
        ``torch.cuda.device_count()``, or ``None`` when
        ``torch.cuda.is_available()`` reports that CUDA cannot be used.
    """
    if torch.cuda.is_available():
        return list(map(str, range(torch.cuda.device_count())))
    return None


def empty_gpu_ram():
    """Run a garbage-collection pass, then empty PyTorch's CUDA cache."""
    # Python-level collection runs before the CUDA cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

## Deploy Gemma 2 using Ollama on Vertex AI Prediction

To deploy Gemma 2 as an Ollama model on Vertex AI Prediction, a custom container with the Ollama server and the Gemma 2 model is required. You can use Cloud Build, a serverless CI/CD platform, to build the serving container image.

### Download Gemma 2 from Hugging Face Hub

Download `google-cloud-partnership/gemma-2-2b-it-lora-sql`, a LoRA adapter for Gemma 2 that lets the model handle SQL-related user requests.

In [None]:
base_model_id = "google-cloud-partnership/gemma-2-2b-it-lora-sql"
model_dir = MODELS_DIR / "gemma-2-2b-it-lora-sql"

ignore_patterns = [".gitattributes", ".gitkeep", "*.md"]

snapshot_download(
    repo_id=base_model_id,
    token=get_token(),
    local_dir=model_dir,
    local_dir_use_symlinks=False,
    ignore_patterns=ignore_patterns,
)

! rm -rf $model_dir/.cache

### Create Artifact Registry repository

To build a container, create a repository in Google Cloud Artifact Registry.

In [None]:
REPOSITORY_NAME = "ollama-gemma-on-vertex"

In [None]:
!gcloud artifacts repositories create $REPOSITORY_NAME \
      --repository-format=docker \
      --location=$LOCATION \
      --project=$PROJECT_ID

### Create a Dockerfile

Use the following Dockerfile to define the container's build steps. The Dockerfile installs Python and FastAPI, sets environment variables, copies Ollama model files, exposes ports, and runs the Ollama model and a proxy server.

> In this scenario, both Ollama and FastAPI run in the same container for simplicity.

In [None]:
# Raw string: keeps the Dockerfile's trailing "\" line continuations verbatim.
# In a regular string literal Python would swallow each backslash together with
# the newline that follows it, collapsing the multi-line instructions.
dockerfile = r"""
# Single-stage image built on top of the official Ollama image
FROM ollama/ollama:0.5.5

# Install Python and the FastAPI proxy dependencies
# (asyncio is part of the Python standard library, so it is not installed with pip)
RUN apt-get update && \
    apt-get install -y python3 python3-pip curl && \
    pip3 install fastapi uvicorn httpx

# Set build-time arguments for better flexibility
ARG OLLAMA_PORT=8079
ARG SERVING_PORT=8080

# Set environment variables
ENV OLLAMA_HOST=0.0.0.0:${OLLAMA_PORT} \
    OLLAMA_MODELS=/ollama_models \
    OLLAMA_KEEP_ALIVE=-1 \
    OLLAMA_DEBUG=false

# Copy model files
COPY ./ollama_models /ollama_models
COPY gemma-2-2b-it-lora-sql.modelfile .

# Expose ollama port
EXPOSE ${OLLAMA_PORT}

# Start a temporary Ollama server, wait a few seconds, then create the model
RUN ollama serve & \
    sleep 5 && ollama create gemma-2-2b-it-lora-sql-2b -f gemma-2-2b-it-lora-sql.modelfile

# Expose port
EXPOSE ${SERVING_PORT}

# Copy the proxy server code and entrypoint script
COPY main.py .
COPY entrypoint.sh .

# Run the proxy server
RUN chmod +x ./entrypoint.sh
ENTRYPOINT ["./entrypoint.sh"]
"""

# Write the Dockerfile into the Cloud Build context directory.
with BUILD_DIR.joinpath("Dockerfile").open("w") as f:
    f.write(dockerfile)

### Create Modelfile

Define an Ollama Modelfile which is the configuration file Ollama needs to define and use the Gemma 2 adapter model.

In [None]:
modelfile = """FROM gemma2:2b
ADAPTER ollama_models/gemma-2-2b-it-lora-sql
"""

with BUILD_DIR.joinpath("gemma-2-2b-it-lora-sql.modelfile").open("w") as f:
    f.write(modelfile)

### Serve engine proxy

This FastAPI application serves as a proxy between Vertex AI Endpoint and a local Ollama model server.

It receives prediction requests from Vertex AI, forwards them to Ollama, and returns the responses to Vertex AI in a standardized format. The application also includes health checks, request validation, error handling, and asynchronous API calls.

> In this scenario, the FastAPI application only maps the `generate` API.

In [None]:
# Source code of the FastAPI proxy that is written to `main.py` in the build
# context. It maps the Vertex AI predict/health routes onto Ollama's
# `/api/generate` API.
app_module = """
'''
FastAPI proxy for Vertex AI Endpoint running Ollama.
'''

import os
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
import httpx
from fastapi import FastAPI, HTTPException, status
from fastapi.responses import JSONResponse
import asyncio


# Request / response schemas
class PredictionRequest(BaseModel):
    '''Request model for predictions'''
    instances: List[Dict] = Field(..., description="List of prediction instances")

class PredictionResponse(BaseModel):
    '''Response model for predictions'''
    predictions: List[str] = Field(..., description="List of model responses")

# Configuration
class Config:
    '''Application configuration.'''
    HEALTH_ROUTE: str = os.environ.get('AIP_HEALTH_ROUTE', '/health')
    PREDICT_ROUTE: str = os.environ.get('AIP_PREDICT_ROUTE', '/predict')
    PORT: int = int(os.environ.get('AIP_HTTP_PORT', '8080'))
    OLLAMA_URL: str = os.environ.get('OLLAMA_URL', 'http://localhost:8079')
    MODEL_NAME: str = os.environ.get('MODEL_NAME', 'gemma-2-2b-it-lora-sql-2b')
    TIMEOUT: int = int(os.environ.get('TIMEOUT_SECONDS', '30'))

# Helper function
async def ollama_generate(prompt: str, parameters: Dict[str, Any]) -> str:
    '''
    Make a prediction using the Ollama model.

    Sends a non-streaming request to the Ollama generate API and returns the
    generated text. Raises HTTPException (502) when the call to Ollama fails.
    '''
    async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
        try:
            response = await client.post(
                f"{Config.OLLAMA_URL}/api/generate",
                json={
                    "prompt": prompt,
                    "stream": False,
                    "options": parameters,
                    "model": Config.MODEL_NAME
                }
            )
            response.raise_for_status()
            return response.json()["response"]

        except httpx.HTTPError as e:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Error calling Ollama: {str(e)}"
            )

# Application
app = FastAPI(
    title="Ollama Vertex AI Proxy",
    description="A proxy service to run Ollama models on Vertex AI"
)

@app.get(
    Config.HEALTH_ROUTE,
    response_model=Dict[str, str],
    description="Health check endpoint",
)
async def health() -> Dict[str, str]:
    '''Check if the service is healthy.'''
    return {'status': 'healthy'}

@app.post(
    Config.PREDICT_ROUTE,
    response_model=PredictionResponse,
    description="Make predictions using the Ollama model",
)
async def predict(request: PredictionRequest) -> PredictionResponse:
    '''Process predictions using the Ollama model concurrently.'''

    if not request.instances:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No instances provided in request"
        )

    try:
        # Process all prompts concurrently
        tasks = []
        for instance in request.instances:
            prompt = instance.get('inputs', '')
            parameters = instance.get('parameters', {})
            tasks.append(ollama_generate(prompt, parameters))

        # Wait for all requests to complete
        predictions = await asyncio.gather(*tasks)
        return PredictionResponse(predictions=predictions)

    except HTTPException:
        # Keep the status raised by ollama_generate (502) instead of
        # re-wrapping it as a generic 500 below.
        raise

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error processing prediction: {str(e)}"
        )

@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    '''Handle HTTP exceptions with a consistent format.'''
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": {
                "code": exc.status_code,
                "message": exc.detail
            }
        }
    )

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=Config.PORT
    )
"""

# Write the proxy module into the Cloud Build context directory.
with BUILD_DIR.joinpath("main.py").open("w") as f:
    f.write(app_module)

### Entrypoint script

Create an entrypoint script to startup FastAPI application and its Ollama service. The script launches Ollama (a local AI model server) in the background, verifies its readiness through health checks, and then initializes a FastAPI application to serve as the main interface.

In [None]:
# Bash entrypoint written to `entrypoint.sh` in the build context: it starts
# Ollama in the background, polls it until it answers, then hands the process
# over to the FastAPI proxy.
entrypoint_script = """#!/bin/bash

# Enable error handling
set -e

# Function to log messages with timestamps
log() {
    echo "[$(date +'%Y-%m-%d %H:%M:%S')] $1"
}

# Function to check if Ollama is ready
check_ollama() {
    for i in {1..30}; do
        if curl -s http://localhost:8079 >/dev/null; then
            log "✅ Ollama is ready!"
            return 0
        fi
        log "⏳ Waiting for Ollama to start... ($i/30)"
        sleep 1
    done
    log "❌ Ollama failed to start within 30 seconds"
    return 1
}

# Start Ollama in the background
log "🚀 Starting Ollama..."
ollama serve & sleep 5

# Wait for Ollama to be ready
check_ollama

# Start the FastAPI serving application
log "🚀 Starting FastAPI serving application..."
exec python3 /main.py
"""

# Write the entrypoint script into the Cloud Build context directory.
with BUILD_DIR.joinpath("entrypoint.sh").open("w") as f:
    f.write(entrypoint_script)

### Build the container image with Cloud Build

Use Cloud Build to build the container image.

> The operation will take less than 5 minutes.


In [None]:
SERVING_CONTAINER_IMAGE_URI = (
    f"{LOCATION}-docker.pkg.dev/{PROJECT_ID}/{REPOSITORY_NAME}/ollama-gemma-2-serve"
)

! gcloud auth configure-docker $LOCATION-docker.pkg.dev --quiet
! gcloud builds submit --tag $SERVING_CONTAINER_IMAGE_URI --project $PROJECT_ID --machine-type e2-highcpu-32 $BUILD_DIR

### (Optional) Testing Ollama container locally using serving container with Vertex AI LocalModel

For debugging purpose, Vertex AI provides `LocalModel` class, accessible through the Vertex AI SDK for Python. This class allows you to build and deploy your model locally, simulating the Vertex AI environment. Using LocalModel involves creating a Docker image that encapsulates your custom predictor code and the associated handler.

> **Important**: Running the LocalModel class requires a local Docker installation. This allows the model to be encapsulated within a container for consistent execution across different environments.

> If you haven't already installed Docker Engine, please refer to the official installation guide: [Install Docker Engine](https://docs.docker.com/engine/install/). This documentation provides detailed instructions for various operating systems and will guide you through the installation process. Ensure Docker is running correctly before proceeding with the LocalModel examples.


#### Create a LocalModel instance

Set up a local model by specifying the container image to use and the port it will listen on (8080).


In [None]:
local_ollama_gemma_model = LocalModel(
    serving_container_image_uri=SERVING_CONTAINER_IMAGE_URI,
    serving_container_ports=[8080],
)

#### Create a LocalEndpoint instance

Deploy the model to a local endpoint for serving. The `gpu_device_ids` sets available GPUs if present.


In [None]:
local_ollama_gemma_endpoint = local_ollama_gemma_model.deploy_to_local_endpoint(
    gpu_device_ids=get_cuda_device_names()
)

local_ollama_gemma_endpoint.serve()

#### Monitoring Your Containerized Deployment

To keep track of your container's deployment progress and identify any potential issues, you can use the following Docker commands within your terminal:

1. **List all containers:** `docker container ls -a` displays a list of all running and stopped containers. Locate the container associated with your local endpoint and copy its ID.  This ID is essential for the next step.

2. **Stream container logs:** `docker logs --follow <CONTAINER_ID>`  provides a real-time stream of your container's logs. Replace `<CONTAINER_ID>` with the ID you copied in the previous step. Monitoring these logs allows you to observe the deployment process, identify any errors or warnings, and understand the container's overall health.

#### Generate predictions

Send a prediction request to a local Vertex AI endpoint.

You convert the request data into a JSON string, send it to the endpoint, and then print the predictions from the JSON response.


In [None]:
prediction_request = {
    "instances": [
        {
            "inputs": "How to run a select all query",
            "parameters": {
                "temperature": 1.0,
            },
        },
    ]
}

In [None]:
vertex_prediction_request = json.dumps(prediction_request)
vertex_prediction_response = local_ollama_gemma_endpoint.predict(
    request=vertex_prediction_request, headers={"Content-Type": "application/json"}
)
print(vertex_prediction_response.json()["predictions"])

In [None]:
vertex_prediction_response

### Register Ollama model on Vertex AI

To serve Gemma 2 with Ollama on Vertex AI, import the model on Vertex AI Model Registry, a central repository where you can manage the lifecycle of your ML models on Vertex AI, using the `aiplatform.Model.upload` method.

Some of the main arguments of the `aiplatform.Model.upload` are:

- `display_name`: The name shown in the Vertex AI Model Registry.
- `serving_container_image_uri`: The location of the Ollama container.
- (Optional) `serving_container_ports`: The port where the Vertex AI endpoint will be exposed (default 8080).

For more information on the supported `aiplatform.Model.upload` arguments, check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_upload).

In [None]:
model = Model.upload(
    display_name="google--gemma-2-2b-it-lora-sql-ollama",
    serving_container_image_uri=SERVING_CONTAINER_IMAGE_URI,
    serving_container_ports=[8080],
)
model.wait()

### Deploy Ollama model on Vertex AI

After the model is registered on Vertex AI, deploy the model to a Vertex AI endpoint using the `aiplatform.Model.deploy` method.

Some of the main arguments of the `aiplatform.Model.deploy` are:

* (optional) **`endpoint`** : Set an endpoint for model deployment.
* **`machine_type, accelerator_type, accelerator_count`** : Define the deployment instance and accelerator configuration.

For more information on the supported `aiplatform.Model.deploy` arguments, you can check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_deploy).

> Note that the model deployment on Vertex AI can take around 15 to 25 minutes; most of the time being the allocation / reservation of the resources, setting up the network and security, and such.

In [None]:
endpoint = Endpoint.create(
    display_name="google--gemma-2-2b-it-lora-sql-ollama-endpoint"
)

deployed_model = model.deploy(
    endpoint=endpoint,
    machine_type="g2-standard-4",
    accelerator_type="NVIDIA_L4",
    accelerator_count=1,
)

### Online predictions on Vertex AI

Once the model is deployed on Vertex AI, run the online predictions using the `aiplatform.Endpoint.predict` method, which will send the requests to the running endpoint in the `/predict` route specified within the container following Vertex AI I/O payload formatting.

In [None]:
output = deployed_model.predict(
    instances=[
        {
            "inputs": "How to run a select all query",
            "parameters": {
                "temperature": 1.0,
            },
        },
    ]
)
predictions = output.predictions
print(predictions[0])

## Build an agentic RAG application using Ollama model on Vertex AI with LangGraph

After deploying the Ollama model on Vertex AI, consume the model to build an agentic RAG application using Ollama model on Vertex AI with LangGraph.

### Install additional libraries

Install langgraph libraries.

In [None]:
%pip install --upgrade --user --quiet "langchain-community" \
    "langchainhub" \
    "langchain_google_vertexai" \
    "langgraph" \
    "faiss-gpu"

### Import additional libraries

Import additional libraries to build the agent.

In [None]:
from typing import Any, TypedDict

from IPython.display import Image, display
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.runnables.graph import MermaidDrawMethod
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph

### Define some helpers

Define a `CustomVertexAIModel` to handle all the endpoint formatting and parameter management in a way to make the model compatible with a LangGraph agentic workflow.

In [None]:
class CustomVertexAIModel:
    """
    Minimal prompt-in / text-out client for a model served on a Vertex AI Endpoint.

    Default generation parameters are captured at construction time and can be
    overridden on each call to ``invoke``.
    """

    def __init__(
        self,
        project: str,
        location: str,
        endpoint_id: str,
        **model_params: Any,
    ):
        """
        Bind the wrapper to an existing Vertex AI endpoint.

        Args:
            project: Google Cloud project ID
            location: Endpoint location (e.g., "us-central1")
            endpoint_id: Vertex AI endpoint ID
            **model_params: Default parameters sent with every request
                (e.g., temperature)
        """
        resource_name = (
            f"projects/{project}/locations/{location}/endpoints/{endpoint_id}"
        )
        self.endpoint = aiplatform.Endpoint(endpoint_name=resource_name)
        self.model_params = model_params

    def invoke(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """
        Send a single prompt to the endpoint and return the first prediction.

        Args:
            prompt: The input text prompt
            **kwargs: Per-call parameters; they take precedence over the
                defaults given to the constructor

        Returns:
            The first element of the endpoint's predictions
        """
        # Start from a copy of the defaults so they are never mutated, then
        # let the call-specific values win.
        request_params = dict(self.model_params)
        request_params.update(kwargs)

        payload = {"inputs": prompt, "parameters": request_params}
        prediction_result = self.endpoint.predict([payload])
        return prediction_result.predictions[0]

### Build the LangGraph agent

#### Initialize Vertex AI components

Initialize the Google's embedding model and the LLM model hosted on the Vertex AI Endpoint.


In [None]:
embeddings = VertexAIEmbeddings(model_name="text-embedding-005", project=PROJECT_ID)

llm = CustomVertexAIModel(
    endpoint_id=endpoint.name,
    temperature=1.0,
    project=PROJECT_ID,
    location=LOCATION,
)

#### Define agent components

According to LangGraph documentation, define an agent state and the following functions to build the agent:

1. `create_vectorstore_from_urls`: Loads web pages, splits them into chunks, and creates a searchable vector database using FAISS embeddings.
2. `retrieve`: Finds the 3 most similar document chunks from the vector store based on the user's query and adds them to the state context.
3. `generate_response`: Takes the retrieved context and query, sends them to the LLM for processing, and updates the state with the response and conversation history.
4. `should_rewrite`: Checks if the generated response is in proper SQL format by looking for SQL keywords.
5. `rewrite_response`: Asks the LLM to reformat the response into a proper SQL query with comments and proper syntax

In [None]:
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the RAG agent."""

    # The user's question.
    query: str
    # Messages produced while answering; `generate_response` appends to it.
    messages: list[BaseMessage]
    # Retrieved document text, filled in by `retrieve`.
    context: str
    # Latest model answer (replaced by `rewrite_response` when rewriting).
    response: str
    # Alternating human / AI turns recorded by `generate_response`.
    chat_history: list[BaseMessage]
    # Routing decision written by `should_rewrite`: "rewrite" or "end".
    # Declared here so the key is part of the state schema instead of being an
    # undeclared extra key.
    next: str

In [None]:
def create_vectorstore_from_urls(urls: list[str]) -> FAISS:
    """Load the given web pages, chunk them, and index the chunks in FAISS.

    Args:
        urls: Web page URLs handed to ``WebBaseLoader``.

    Returns:
        A FAISS vector store built from the chunks with the module-level
        ``embeddings`` model.
    """
    pages = WebBaseLoader(urls).load()
    # 1000-character chunks with a 200-character overlap between neighbours.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents(pages)
    return FAISS.from_documents(chunks, embeddings)


def retrieve(state: AgentState, vectorstore: FAISS) -> AgentState:
    """Look up the chunks most similar to the query and store them as context.

    Args:
        state: Agent state; ``state["query"]`` is read and ``state["context"]``
            is overwritten in place.
        vectorstore: Store queried through ``similarity_search`` with ``k=3``.

    Returns:
        The same ``state`` object, with ``context`` set to the retrieved page
        contents joined by newlines.
    """
    matches = vectorstore.similarity_search(state["query"], k=3)
    passages = [match.page_content for match in matches]
    state["context"] = "\n".join(passages)
    return state


def generate_response(state: AgentState) -> AgentState:
    """Ask the module-level ``llm`` to answer the query from the retrieved context.

    The answer is stored in ``state["response"]``, appended to
    ``state["messages"]`` as an ``AIMessage``, and recorded in
    ``state["chat_history"]`` as a human/AI message pair. The state is updated
    in place and returned.
    """
    prompt = f"""Context: {state["context"]}
                 Question: {state["query"]}
                 Please provide a helpful response based on the context above."""

    answer = llm.invoke(prompt)
    state["response"] = answer

    state["messages"].append(AIMessage(content=answer))

    # Record the exchange as a (human, AI) pair of turns.
    user_turn = HumanMessage(content=state["query"])
    model_turn = AIMessage(content=answer)
    state["chat_history"].extend([user_turn, model_turn])
    return state


def should_rewrite(state: AgentState) -> AgentState:
    """Set ``state["next"]`` to ``"rewrite"`` or ``"end"``.

    A rewrite is requested only when the query mentions a SQL-related keyword
    and the response neither starts with a SQL opener nor contains "select"
    within its first 100 characters. All comparisons are case-insensitive.
    """
    sql_topic_markers = (
        "sql",
        "query",
        "select",
        "table",
        "database",
        "bigquery",
        "join",
        "where",
    )
    sql_openers = ("select", "with", "create", "/*")

    # Non-SQL questions are never rewritten.
    question = state["query"].lower()
    if not any(marker in question for marker in sql_topic_markers):
        state["next"] = "end"
        return state

    answer = state["response"].lower()
    looks_like_sql = (
        answer.strip().startswith(sql_openers) or "select" in answer[:100]
    )
    state["next"] = "end" if looks_like_sql else "rewrite"
    return state


def rewrite_response(state: AgentState) -> AgentState:
    """Ask the LLM to restate the current response as a BigQuery SQL query.

    The rewritten text replaces ``state["response"]`` as well as the last
    entry of both ``state["messages"]`` and ``state["chat_history"]``. The
    state is mutated in place and returned.
    """
    rewrite_prompt = f"""
    Original question: {state["query"]}
    Previous response: {state["response"]}

    Rewrite the above as a proper SQL query following these rules:
    - Start with SQL keywords (SELECT, WITH, CREATE, etc.)
    - Include comments explaining the logic
    - Format the query properly
    - Use BigQuery SQL syntax
    """

    rewritten = llm.invoke(rewrite_prompt)
    state["response"] = rewritten

    # Overwrite the most recent AI turn in each log with its own message object.
    for log_key in ("messages", "chat_history"):
        state[log_key][-1] = AIMessage(content=rewritten)

    return state

#### Assemble the agent

Create a simple agentic RAG that first builds a searchable database from URLs, then sets up a sequence of steps (retrieve → generate → check for rewrite → either end or rewrite and loop back) to process queries and generate SQL responses.

In [None]:
def create_rag_agent(urls: list[str]) -> Any:
    """Build and compile the agentic RAG workflow.

    Flow: retrieve -> generate -> should_rewrite -> (END | rewrite), with
    rewrite looping back to should_rewrite.

    Args:
        urls: Web pages to index into the vector store that grounds the agent.

    Returns:
        The compiled workflow; run it with ``.invoke(state)``.
    """
    vectorstore = create_vectorstore_from_urls(urls)

    workflow = StateGraph(AgentState)

    # Nodes. `retrieve` needs the vector store, so it is bound through a closure.
    workflow.add_node("retrieve", lambda state: retrieve(state, vectorstore))
    workflow.add_node("generate", generate_response)
    workflow.add_node("should_rewrite", should_rewrite)
    workflow.add_node("rewrite", rewrite_response)

    workflow.set_entry_point("retrieve")

    # Linear part of the flow.
    workflow.add_edge("retrieve", "generate")
    workflow.add_edge("generate", "should_rewrite")

    # Route on the decision that `should_rewrite` stored in the state.
    workflow.add_conditional_edges(
        "should_rewrite",
        lambda state: state["next"],
        {"rewrite": "rewrite", "end": END},
    )

    # Loop back to the check rather than to retrieval. Returning to `retrieve`
    # would run `generate` again, which overwrites the rewritten response, so
    # the rewrite would never reach the caller and every cycle would cost an
    # extra retrieval and LLM call.
    workflow.add_edge("rewrite", "should_rewrite")

    return workflow.compile()

#### Initialize the agent

Initialize the agent by passing the BigQuery documentation to ground the agent.

In [None]:
# BigQuery SQL reference pages that are loaded and indexed to ground the agent.
urls = [
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/dml-syntax",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/arrays",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate_functions",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/subqueries",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/joins",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/operators",
    "https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference",
]


# Building the agent loads and indexes every page listed above before compiling the graph.
agent = create_rag_agent(urls)

#### Visualize the agent

Plot the agentic workflow.


In [None]:
# Draw the compiled graph as a Mermaid PNG and show it inline.
# NOTE(review): MermaidDrawMethod.API presumably renders through a remote
# Mermaid service, so this cell likely needs internet access — confirm.
display(
    Image(
        agent.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)

### Query the agent

Use the agent to answer user requests for SQL code generation.

In [None]:
def query_rag_agent(agent, query: str) -> dict[str, Any]:
    """Run one query through the agent, starting from a fresh, empty state.

    Args:
        agent: Object exposing ``invoke(state)``, such as the compiled workflow.
        query: The user's question.

    Returns:
        Whatever ``agent.invoke`` returns for the initial state.
    """
    initial_state = dict(
        query=query,
        messages=[],
        context="",
        response="",
        chat_history=[],
        next=None,  # routing decision, filled in by the workflow
    )
    return agent.invoke(initial_state)


# Sample SQL-generation requests used to exercise the agent end to end.
questions = [
    "Write a SQL query to calculate the total sales per month and include a 3-month moving average",
    "Create a query that finds the top 5 customers by revenue in each region, including their total spend and number of orders",
    "Write a SQL query to analyze user engagement: calculate daily active users (DAU), weekly active users (WAU), and the DAU/WAU ratio for the last 30 days",
]

# Run every question through the agent from a fresh state and print the final
# response, followed by a separator line between questions.
for question in questions:
    result = query_rag_agent(agent, question)
    print(f"\nQuestion: {question}")
    print("\nResponse:")
    print(result["response"])
    print("\n" + "=" * 80)

## Cleaning up

In [None]:
delete_endpoint = False
delete_model = False
delete_artifact_registry = False
delete_tutorial_folder = False

if delete_endpoint:
    deployed_model.undeploy_all()
    deployed_model.delete()

if delete_model:
    delete_model.delete()

if delete_artifact_registry:
    ! gcloud artifacts repositories delete $REPOSITORY_NAME \
          --repository-format=docker \
          --location=$LOCATION \
          --project=$PROJECT_ID

if delete_tutorial_folder:
    import shutil

    shutil.rmtree(tutorial_folder)