# gemini/sample-apps/llamaindex-rag/backend/rag/index_manager.py
"""Main state management class for indices and prompts for
experimentation UI"""
import logging
import Stemmer
from backend.rag.async_extensions import (
AsyncHyDEQueryTransform,
AsyncRetrieverQueryEngine,
AsyncTransformQueryEngine,
)
from backend.rag.claude_vertex import ClaudeVertexLLM
from backend.rag.node_reranker import CustomLLMRerank
from backend.rag.parent_retriever import ParentRetriever
from backend.rag.prompts import Prompts
from backend.rag.qa_followup_retriever import QAFollowupRetriever, QARetriever
from google.cloud import aiplatform
from llama_index.core import (
PromptTemplate,
Settings,
StorageContext,
VectorStoreIndex,
get_response_synthesizer,
)
from llama_index.core.agent import ReActAgent
from llama_index.core.retrievers import AutoMergingRetriever, QueryFusionRetriever
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.embeddings.vertex import VertexTextEmbedding
from llama_index.llms.vertex import Vertex
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.storage.docstore.firestore import FirestoreDocumentStore
from llama_index.vector_stores.vertexaivectorsearch import VertexAIVectorStore

logging.basicConfig(level=logging.INFO)  # Set the desired logging level
logger = logging.getLogger(__name__)


class IndexManager:
"""
    Manages state for vector indexes, docstores, query engines, and chat
    engines across the app's lifecycle (e.g. through UI manipulations).

    An instantiated IndexManager is injected into every API call that needs
    to read or mutate this state. That includes:
- Switching out vector indices or docstores
- Changing retrieval parameters (e.g. temperature, llm model, etc.)
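
    Illustrative construction (all resource names are placeholders, not
    values shipped with this sample):

        index_manager = IndexManager(
            project_id="my-project",
            location="us-central1",
            base_index_name="my-base-index",
            base_endpoint_name="my-base-endpoint",
            qa_index_name=None,
            qa_endpoint_name=None,
            embeddings_model_name="text-embedding-004",
            firestore_db_name=None,
            firestore_namespace=None,
            vs_bucket_name="my-vector-search-bucket",
        )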
"""

    def __init__(
self,
project_id: str,
location: str,
base_index_name: str,
base_endpoint_name: str,
qa_index_name: str | None,
qa_endpoint_name: str | None,
embeddings_model_name: str,
firestore_db_name: str | None,
firestore_namespace: str | None,
vs_bucket_name: str,
):
self.project_id = project_id
self.location = location
self.embeddings_model_name = embeddings_model_name
self.base_index_name = base_index_name
self.base_endpoint_name = base_endpoint_name
self.qa_index_name = qa_index_name
self.qa_endpoint_name = qa_endpoint_name
self.firestore_db_name = firestore_db_name
self.firestore_namespace = firestore_namespace
self.vs_bucket_name = vs_bucket_name
self.embed_model = VertexTextEmbedding(
model_name=self.embeddings_model_name,
project=self.project_id,
location=self.location,
)
self.base_index = self.get_vector_index(
index_name=self.base_index_name,
endpoint_name=self.base_endpoint_name,
firestore_db_name=self.firestore_db_name,
firestore_namespace=self.firestore_namespace,
)
if self.qa_endpoint_name and self.qa_index_name:
self.qa_index = self.get_vector_index(
index_name=self.qa_index_name,
endpoint_name=self.qa_endpoint_name,
firestore_db_name=self.firestore_db_name,
firestore_namespace=self.firestore_namespace,
)
else:
self.qa_index = None

    def get_current_index_info(self) -> dict:
"""Return the indices currently being used"""
return {
"base_index_name": self.base_index_name,
"base_endpoint_name": self.base_endpoint_name,
"qa_index_name": self.qa_index_name,
"qa_endpoint_name": self.qa_endpoint_name,
"firestore_db_name": self.firestore_db_name,
"firestore_namespace": self.firestore_namespace,
}

    def get_vertex_llm(
        self, llm_name: str, temperature: float, system_prompt: str
    ) -> Vertex | ClaudeVertexLLM:
        """Construct the requested LLM and set it as the global default."""
if "gemini" in llm_name:
llm = Vertex(
model=llm_name,
max_tokens=3000,
temperature=temperature,
system_prompt=system_prompt,
)
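        # NOTE: the Claude branch pins region and model version below;
        # llm_name only selects the branch, and temperature is not passed
        # through to ClaudeVertexLLM.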
elif "claude" in llm_name:
llm = ClaudeVertexLLM(
project_id=self.project_id,
region="us-east5",
model_name="claude-3-5-sonnet@20240620",
max_tokens=3000,
system_prompt=system_prompt,
)
Settings.llm = llm
return llm

    def set_current_indexes(
        self,
        base_index_name: str,
        base_endpoint_name: str,
qa_index_name: str | None,
qa_endpoint_name: str | None,
firestore_db_name: str | None,
firestore_namespace: str | None,
) -> None:
"""Set the current indices to be used for the RAG"""
self.base_index_name = base_index_name
self.base_endpoint_name = base_endpoint_name
self.qa_index_name = qa_index_name
self.qa_endpoint_name = qa_endpoint_name
self.firestore_db_name = firestore_db_name
self.firestore_namespace = firestore_namespace
self.base_index = self.get_vector_index(
index_name=self.base_index_name,
endpoint_name=self.base_endpoint_name,
firestore_db_name=self.firestore_db_name,
firestore_namespace=self.firestore_namespace,
)
if self.qa_endpoint_name and self.qa_index_name:
self.qa_index = self.get_vector_index(
index_name=self.qa_index_name,
endpoint_name=self.qa_endpoint_name,
firestore_db_name=self.firestore_db_name,
firestore_namespace=self.firestore_namespace,
)
else:
self.qa_index = None

    def get_vector_index(
self,
index_name: str,
endpoint_name: str,
firestore_db_name: str | None,
firestore_namespace: str | None,
) -> VectorStoreIndex:
"""
        Return a LlamaIndex VectorStoreIndex backed by a Vertex AI Vector
        Search index/endpoint, with an optional Firestore-backed document
        store attached via the storage context.
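
        Illustrative call (display names and database are placeholders):

            index = self.get_vector_index(
                index_name="my-base-index",
                endpoint_name="my-base-endpoint",
                firestore_db_name="(default)",
                firestore_namespace="my-docstore-namespace",
            )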
"""
# Initialize Vertex AI
aiplatform.init(project=self.project_id, location=self.location)
# Get the Vector Search index
indexes = aiplatform.MatchingEngineIndex.list(
filter=f'display_name="{index_name}"'
)
if not indexes:
raise ValueError(f"No index found with display name: {index_name}")
vs_index = indexes[0]
# Get the Vector Search endpoint
endpoints = aiplatform.MatchingEngineIndexEndpoint.list(
filter=f'display_name="{endpoint_name}"'
)
if not endpoints:
raise ValueError(f"No endpoint found with display name: {endpoint_name}")
vs_endpoint = endpoints[0]
# Create the vector store
vector_store = VertexAIVectorStore(
project_id=self.project_id,
region=self.location,
index_id=vs_index.resource_name.split("/")[-1],
endpoint_id=vs_endpoint.resource_name.split("/")[-1],
gcs_bucket_name=self.vs_bucket_name,
)
if firestore_db_name and firestore_namespace:
docstore = FirestoreDocumentStore.from_database(
project=self.project_id,
database=firestore_db_name,
namespace=firestore_namespace,
)
else:
docstore = None
# Create storage context
storage_context = StorageContext.from_defaults(
vector_store=vector_store, docstore=docstore
)
# Create and return the index
vector_store_index = VectorStoreIndex(
nodes=[], storage_context=storage_context, embed_model=self.embed_model
)
return vector_store_index

    def get_query_engine(
self,
prompts: Prompts,
llm_name: str = "gemini-2.0-flash",
temperature: float = 0.0,
similarity_top_k: int = 5,
retrieval_strategy: str = "auto_merging",
use_hyde: bool = True,
use_refine: bool = True,
use_node_rerank: bool = False,
qa_followup: bool = True,
hybrid_retrieval: bool = True,
) -> AsyncRetrieverQueryEngine:
"""
        Build a LlamaIndex query engine over the current indexes from the
        given prompts and retrieval hyperparameters.
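
        Illustrative call (flag values are examples, not recommendations):

            engine = index_manager.get_query_engine(
                prompts=prompts,
                retrieval_strategy="baseline",
                use_hyde=False,
                use_node_rerank=False,
                qa_followup=False,
                hybrid_retrieval=False,
            )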
"""
        llm = self.get_vertex_llm(
            llm_name=llm_name,
            temperature=temperature,
            system_prompt=prompts.system_prompt,
        )
qa_prompt = PromptTemplate(prompts.qa_prompt_tmpl)
refine_prompt = PromptTemplate(prompts.refine_prompt_tmpl)
if use_refine:
synth = get_response_synthesizer(
text_qa_template=qa_prompt,
refine_template=refine_prompt,
response_mode="compact",
use_async=True,
)
else:
synth = get_response_synthesizer(
text_qa_template=qa_prompt, response_mode="compact", use_async=True
)
base_retriever = self.base_index.as_retriever(similarity_top_k=similarity_top_k)
if self.qa_index:
qa_vector_retriever = self.qa_index.as_retriever(
similarity_top_k=similarity_top_k
)
else:
qa_vector_retriever = None
        # Choose among retrieval strategies and configurations.
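        # - "auto_merging": merge retrieved leaf chunks back into their
        #   parent nodes when enough siblings match (needs a docstore).
        # - "parent": swap each retrieved chunk for its parent node.
        # - "baseline": plain top-k dense similarity retrieval.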
if retrieval_strategy == "auto_merging":
            logger.info(
                "Auto-merging docstore: %s",
                self.base_index.storage_context.docstore,
            )
retriever = AutoMergingRetriever(
base_retriever, self.base_index.storage_context, verbose=True
)
elif retrieval_strategy == "parent":
retriever = ParentRetriever(
base_retriever, docstore=self.base_index.docstore
)
elif retrieval_strategy == "baseline":
retriever = base_retriever

        if qa_followup and self.qa_index:
            qa_retriever = QARetriever(
                qa_vector_retriever=qa_vector_retriever,
                docstore=self.qa_index.docstore,
            )
retriever = QAFollowupRetriever(
qa_retriever=qa_retriever, base_retriever=retriever
)
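
        # Hybrid retrieval pairs the dense retriever with a keyword-based
        # BM25 retriever and fuses the two ranked lists via reciprocal-rank
        # fusion.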
if hybrid_retrieval:
bm25_retriever = BM25Retriever.from_defaults(
docstore=self.base_index.docstore,
similarity_top_k=similarity_top_k,
stemmer=Stemmer.Stemmer("english"),
language="english",
)
retriever = QueryFusionRetriever(
[retriever, bm25_retriever],
similarity_top_k=similarity_top_k,
num_queries=1, # set this to 1 to disable query generation
mode="reciprocal_rerank",
use_async=True,
verbose=True,
# query_gen_prompt="...", # we could override the
# query generation prompt here
)
if use_node_rerank:
reranker_llm = Vertex(
model="gemini-2.0-flash",
max_tokens=8192,
temperature=temperature,
system_prompt=prompts.system_prompt,
)
choice_select_prompt = PromptTemplate(prompts.choice_select_prompt_tmpl)
llm_reranker = CustomLLMRerank(
choice_batch_size=10,
top_n=5,
choice_select_prompt=choice_select_prompt,
llm=reranker_llm,
)
else:
llm_reranker = None
query_engine = AsyncRetrieverQueryEngine.from_args(
retriever,
response_synthesizer=synth,
node_postprocessors=[llm_reranker] if llm_reranker else None,
)
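
        # HyDE (Hypothetical Document Embeddings) asks the LLM to draft a
        # hypothetical answer and retrieves against that text (plus the
        # original query, since include_original=True).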
if use_hyde:
hyde_prompt = PromptTemplate(prompts.hyde_prompt_tmpl)
hyde = AsyncHyDEQueryTransform(
include_original=True, hyde_prompt=hyde_prompt
)
query_engine = AsyncTransformQueryEngine(
query_engine=query_engine, query_transform=hyde
)
self.query_engine = query_engine
return query_engine

    def get_react_agent(
self,
prompts: Prompts,
llm_name: str = "gemini-2.0-flash",
temperature: float = 0.2,
) -> ReActAgent:
"""
        Create a ReAct agent over the most recently built query engine
        (requires a prior call to get_query_engine).
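
        Illustrative call (assumes get_query_engine was called first so
        self.query_engine is populated):

            agent = index_manager.get_react_agent(prompts=prompts)
            answer = agent.chat("Summarize the latest quarterly revenue.")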
"""
query_engine_tools = [
QueryEngineTool(
query_engine=self.query_engine,
metadata=ToolMetadata(
name="google_financials",
description=(
"Provides information about Google financials. "
"Use a detailed plain text question as input to the tool."
),
),
)
]
        llm = self.get_vertex_llm(
            llm_name=llm_name,
            temperature=temperature,
            system_prompt=prompts.system_prompt,
        )
agent = ReActAgent.from_tools(
query_engine_tools, llm=llm, verbose=True, context=prompts.system_prompt
)
return agent
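

# Minimal usage sketch (illustrative, not part of the original module): all
# resource names are placeholders, and it assumes Prompts() can be built with
# defaults. Flags are chosen so neither a docstore nor a QA index is needed.
if __name__ == "__main__":
    manager = IndexManager(
        project_id="my-project",
        location="us-central1",
        base_index_name="my-base-index",
        base_endpoint_name="my-base-endpoint",
        qa_index_name=None,
        qa_endpoint_name=None,
        embeddings_model_name="text-embedding-004",
        firestore_db_name=None,
        firestore_namespace=None,
        vs_bucket_name="my-vector-search-bucket",
    )
    engine = manager.get_query_engine(
        prompts=Prompts(),
        retrieval_strategy="baseline",
        use_hyde=False,
        use_refine=False,
        use_node_rerank=False,
        qa_followup=False,
        hybrid_retrieval=False,
    )
    print(engine.query("What were total revenues last quarter?"))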