ai-ml/rag-architectures/serving/backend/app.py (184 lines of code) (raw):

import os from typing import List, Dict, Any, Optional import uvicorn from fastapi import FastAPI, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from google import genai from google.genai.types import HttpOptions from google.genai.types import EmbedContentConfig from google.cloud import aiplatform import logging # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Initialize FastAPI app app = FastAPI(title="Quantum Chatbot Backend") # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], # Update with your frontend origin in production allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Configuration model class AppConfig: def __init__(self): # GCP Project settings self.project_id = os.environ.get("PROJECT_ID") self.location = os.environ.get("GCP_REGION", "us-central1") logger.info( f"Initializing with project_id: {self.project_id}, location: {self.location}" ) # Vertex AI Vector Search settings self.index_endpoint_id = os.environ.get("VECTOR_SEARCH_INDEX_ENDPOINT_NAME") self.index_id = os.environ.get("VECTOR_SEARCH_INDEX_ID") self.deployed_index_id = os.environ.get("VECTOR_SEARCH_DEPLOYED_INDEX_ID") logger.info( f"Vector Search settings - endpoint: {self.index_endpoint_id}, index: {self.index_id}, deployed: {self.deployed_index_id}" ) # Gemini API settings self.model_name = os.environ.get("GEMINI_MODEL_NAME", "gemini-flash-2.0") self.embedding_model = os.environ.get("EMBEDDING_MODEL", "text-embedding-005") # RAG settings self.num_neighbors = int(os.environ.get("NUM_NEIGHBORS", "5")) self.max_context_length = int(os.environ.get("MAX_CONTEXT_LENGTH", "8000")) # Initialize clients self.initialize_clients() def initialize_clients(self): """Initialize Google Cloud clients""" try: # Initialize Vertex AI for Vector Search aiplatform.init(project=self.project_id, location=self.location) # Initialize Vector Search Index and Endpoint self.index = aiplatform.MatchingEngineIndex(self.index_id) logger.info(f"Successfully initialized index: {self.index_id}") self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint( index_endpoint_name=self.index_endpoint_id ) logger.info( f"Successfully initialized index endpoint: {self.index_endpoint_id}" ) # Initialize Vertex AI for Gemini os.environ["GOOGLE_CLOUD_PROJECT"] = self.project_id os.environ["GOOGLE_CLOUD_LOCATION"] = self.location os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "true" self.genai_client = genai.Client() logger.info("Successfully initialized Gemini client") except Exception as e: logger.error(f"Error initializing clients: {str(e)}") raise # Global config instance def get_config(): return app.state.config # Request/Response models class PromptRequest(BaseModel): prompt: str num_neighbors: Optional[int] = None use_context: Optional[bool] = True class PromptResponse(BaseModel): response: str context_used: bool neighbors_found: int @app.on_event("startup") async def startup_event(): """Initialize app state on startup""" logger.info("Initializing app state...") try: app.state.config = AppConfig() logger.info("App state initialized successfully") except Exception as e: logger.error(f"Failed to initialize app state: {str(e)}") raise @app.get("/") async def root(): """Health check endpoint""" return {"status": "healthy", "service": "Quantum Chatbot Backend"} @app.post("/prompt", response_model=PromptResponse) async def process_prompt( request: PromptRequest, config: AppConfig = Depends(get_config) ): """ Process a user prompt through the RAG pipeline: 1. Retrieve context from Vector Search 2. Augment the prompt with retrieved context 3. Generate a response using Gemini 2.0 Flash """ try: logger.info(f"Processing prompt: {request.prompt[:50]}...") # Determine number of neighbors to retrieve num_neighbors = request.num_neighbors or config.num_neighbors # 1. Retrieve context from Vector Search, 2. Augment prompt with context if request.use_context: logger.info("✅ Use Context was True - Retrieving context") neighbors = prompt_vector_search(request.prompt, num_neighbors, config) # 2. Augment prompt with context augmented_prompt = create_augmented_prompt(request.prompt, neighbors) else: logger.info("❌ Use Context was False - Skipping context retrieval") augmented_prompt = request.prompt neighbors = [] # 3. Generate response with Gemini response = prompt_gemini(augmented_prompt, config) logger.info("Generated response from Gemini") return PromptResponse( response=response, context_used=len(neighbors) > 0, neighbors_found=len(neighbors), ) except Exception as e: logger.error(f"Error processing prompt: {str(e)}") raise HTTPException( status_code=500, detail=f"Error processing prompt: {str(e)}" ) def extract_id_and_text(neighbor): """ Extract ID and text from a Vertex AI Vector Search "MatchNeighbor" object """ id_value = neighbor.id text_value = None if hasattr(neighbor, "restricts") and neighbor.restricts: for restrict in neighbor.restricts: if hasattr(restrict, "name") and restrict.name == "text": if hasattr(restrict, "allow_tokens") and restrict.allow_tokens: text_value = restrict.allow_tokens[0] break return {"id": id_value, "text": text_value} def prompt_vector_search( prompt: str, num_neighbors: int, config: AppConfig ) -> List[Dict[str, Any]]: """Prompt Vector Search to find relevant documents""" try: logger.info("Creating embeddings for query") # Convert prompt to embeddings response = config.genai_client.models.embed_content( model=config.embedding_model, contents=[prompt], config=EmbedContentConfig( task_type="RETRIEVAL_QUERY", output_dimensionality=768, ), ) query_embedding = response.embeddings[0].values logger.info("Query embeddings: " + str(query_embedding[:10]) + "...") # Get nearest neighbors from Vector Search logger.info( f"❓ Querying Vector Search with deployed_index_id: {config.deployed_index_id}" ) neighbors = config.index_endpoint.find_neighbors( deployed_index_id=config.deployed_index_id, queries=[query_embedding], num_neighbors=num_neighbors, return_full_datapoint=True, # Make sure this is True ) logger.info(f"Vector Search returned {len(neighbors[0])} matches") return neighbors[0] except Exception as e: logger.error(f"Error querying Vector Search: {str(e)}") raise def extract_id_and_text(neighbor): """ Extract ID and text from a Vertex AI Vector Search "MatchNeighbor" object """ id_value = neighbor.id text_value = None if hasattr(neighbor, "restricts") and neighbor.restricts: for restrict in neighbor.restricts: if hasattr(restrict, "name") and restrict.name == "text": if hasattr(restrict, "allow_tokens") and restrict.allow_tokens: text_value = restrict.allow_tokens[0] break return {"id": id_value, "text": text_value} def create_augmented_prompt(prompt: str, neighbors: list) -> str: """Create a prompt augmented with retrieved context""" if not neighbors or len(neighbors) == 0: return prompt print("Got # neighbors: " + str(len(neighbors))) augment = [] for n in neighbors: result = extract_id_and_text(n) print(f"ID: {result['id']}") print(f"Text: {result['text']}") augment.append(result["text"]) context = "\n".join(augment) final_prompt = f"""You are an expert chatbot in quantum computing. Use the provided up-to-date information to answer the user's question. Only respond on topics related to quantum computing. Answer in 3 sentences or less! Context: {context} User Prompt: {prompt} """ print("⭐ Augmented Prompt: ", final_prompt) return final_prompt def prompt_gemini(prompt: str, config: AppConfig) -> str: """Prompt Gemini model with the augmented prompt https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2#google-gen """ try: logger.info(f"Calling Gemini with model: {config.model_name}") client = genai.Client(http_options=HttpOptions(api_version="v1")) response = client.models.generate_content( model="gemini-2.0-flash-001", # Using hardcoded model name for consistency contents=[prompt], ) return response.text except Exception as e: logger.error(f"Error prompting Gemini: {str(e)}") raise if __name__ == "__main__": uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)