gemini/sample-apps/llamaindex-rag/backend/rag/node_reranker.py (173 lines of code) (raw):
"""Node Re-ranker class for async execution"""
from collections.abc import Callable
import logging
import google.auth
import google.auth.transport.requests
from llama_index.core import QueryBundle, Settings
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.indices.utils import (
default_format_node_batch_fn,
default_parse_choice_select_answer_fn,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.core.service_context import ServiceContext
from llama_index.core.settings import llm_from_settings_or_context
from llama_index.llms.vertex import Vertex
import requests
# Import-time side effect: configures the root logger for the whole process.
logging.basicConfig(level=logging.INFO) # Set the desired logging level
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
# Initialize the LLM and set it in the Settings
# temperature=0.0 is passed to the Vertex model constructor below.
llm = Vertex(model="gemini-2.0-flash", temperature=0.0)
# CustomLLMRerank hands `Settings` to llm_from_settings_or_context() when it is
# constructed without an explicit `llm`, so this assignment presumably supplies
# its default model — confirm against the llama_index version in use.
Settings.llm = llm
def authenticate_google(quota_project_id: str = "pr-sbx-vertex-genai") -> str:
    """Authenticate using Google credentials and return the access token.

    Args:
        quota_project_id: Project passed to ``google.auth.default`` as the
            quota project. Defaults to the project this module was
            originally hard-wired to.

    Returns:
        The ``token`` attribute of the credentials after they have been
        refreshed.
    """
    # google.auth.default() returns (credentials, project_id); the detected
    # project ID is not needed here.
    credentials, _ = google.auth.default(quota_project_id=quota_project_id)
    # Refresh before reading the token, exactly as the original code did.
    credentials.refresh(google.auth.transport.requests.Request())
    return credentials.token
def call_reranker(
    query: str,
    records: list[dict],
    google_token: str,
    project_id: str = "pr-sbx-vertex-genai",
    model_name: str = "semantic-ranker-512@latest",
    timeout: float | None = 30.0,
) -> dict:
    """Call the Discovery Engine ranking API with the given query and records.

    Args:
        query: The search query.
        records: A list of dictionaries, where each dictionary represents a
            record with "id", "title", and "content" fields.
        google_token: OAuth access token sent as the bearer credential.
        project_id: Google Cloud project used both in the request URL and in
            the ``X-Goog-User-Project`` header.
        model_name: Ranking model identifier sent in the request body.
        timeout: Seconds to wait for the HTTP call before giving up. ``None``
            restores the previous wait-forever behaviour.

    Returns:
        The API response decoded from JSON.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If no response arrives within ``timeout``.
    """
    url = f"https://discoveryengine.googleapis.com/v1alpha/projects/{project_id}/locations/global/rankingConfigs/default_ranking_config:rank"
    headers = {
        "Authorization": f"Bearer {google_token}",
        "Content-Type": "application/json",
        "X-Goog-User-Project": project_id,
    }
    payload = {
        "model": model_name,
        "query": query,
        "records": records,
    }
    # Without an explicit timeout, requests waits indefinitely on a stalled
    # connection, which would block the whole reranking step.
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    # Report through the module logger instead of print() so the output
    # respects the application's logging configuration.
    logger.info("Ranking API responded with HTTP %s", response.status_code)
    response.raise_for_status()  # Raise an error if the request failed
    return response.json()
class GoogleReRankerSecretSauce(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes through ``call_reranker``.

    Each node is sent as an ``id`` / ``title`` / ``content`` record and the
    scores from the response replace the nodes' scores.
    """

    def _postprocess_nodes(
        self, nodes: list[NodeWithScore], query_bundle: QueryBundle | None
    ) -> list[NodeWithScore]:
        """Rerank ``nodes`` against the query and return them best-first.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the ranker.

        Returns:
            One ``NodeWithScore`` per record in the API response, sorted by
            descending score.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            # Nothing to rank: skip authentication and the network round trip.
            return []

        google_token = authenticate_google()
        records = [
            {
                "id": scored.node.id_,
                # Nodes without a "title" metadata entry used to raise
                # KeyError; send an empty title instead.
                "title": scored.node.metadata.get("title", ""),
                "content": scored.node.text,
            }
            for scored in nodes
        ]
        response_json = call_reranker(query_bundle.query_str, records, google_token)

        # Reuse the original node objects so their metadata is kept; building
        # fresh TextNodes from the response kept only the id and the text.
        originals = {scored.node.id_: scored.node for scored in nodes}
        reranked: list[NodeWithScore] = []
        for record in response_json.get("records", []):
            node = originals.get(record["id"])
            if node is None:
                # Unknown id in the response: fall back to the old behaviour.
                node = TextNode(id_=record["id"], text=record["content"])
            reranked.append(NodeWithScore(node=node, score=record["score"]))

        return sorted(reranked, key=lambda x: x.score or 0.0, reverse=True)
class CustomLLMRerank(BaseNodePostprocessor):
    """LLM-based reranker whose postprocessing entry points are coroutines.

    Nodes are presented to the LLM in batches of ``choice_batch_size``; the
    parsed choices and relevance scores from every batch are merged and the
    ``top_n`` highest-scoring nodes are returned.
    """

    top_n: int = Field(description="Top N nodes to return.")
    choice_select_prompt: BasePromptTemplate = Field(
        description="Choice select prompt."
    )
    choice_batch_size: int = Field(description="Batch size for choice select.")
    llm: LLM = Field(description="The LLM to rerank with.")

    # Turns a batch of nodes into the context string placed in the prompt.
    _format_node_batch_fn: Callable = PrivateAttr()
    # Parses the raw LLM answer into (choices, relevances).
    _parse_choice_select_answer_fn: Callable = PrivateAttr()

    def __init__(
        self,
        llm: LLM | None = None,
        choice_select_prompt: BasePromptTemplate | None = None,
        choice_batch_size: int = 10,
        format_node_batch_fn: Callable | None = None,
        parse_choice_select_answer_fn: Callable | None = None,
        service_context: ServiceContext | None = None,
        top_n: int = 10,
    ) -> None:
        """Create the reranker, filling in library defaults for omitted parts.

        Args:
            llm: Model used for ranking; when omitted it is resolved with
                ``llm_from_settings_or_context(Settings, service_context)``.
            choice_select_prompt: Prompt template; defaults to
                ``DEFAULT_CHOICE_SELECT_PROMPT``.
            choice_batch_size: Number of nodes sent to the LLM per call.
            format_node_batch_fn: Formatter for a node batch; defaults to
                ``default_format_node_batch_fn``.
            parse_choice_select_answer_fn: Parser for the LLM answer; defaults
                to ``default_parse_choice_select_answer_fn``.
            service_context: Forwarded to the LLM lookup and the base class.
            top_n: Maximum number of nodes returned after reranking.
        """
        choice_select_prompt = choice_select_prompt or DEFAULT_CHOICE_SELECT_PROMPT
        llm = llm or llm_from_settings_or_context(Settings, service_context)

        super().__init__(
            llm=llm,
            choice_select_prompt=choice_select_prompt,
            choice_batch_size=choice_batch_size,
            service_context=service_context,
            top_n=top_n,
        )
        # NOTE(review): private attributes are assigned only after the model
        # is initialised. pydantic v2 is understood to reject assignment
        # before __init__, and assigning afterwards also works with
        # pydantic v1 — confirm against the installed version.
        self._format_node_batch_fn = (
            format_node_batch_fn or default_format_node_batch_fn
        )
        self._parse_choice_select_answer_fn = (
            parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
        )

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"choice_select_prompt": self.choice_select_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "choice_select_prompt" in prompts:
            self.choice_select_prompt = prompts["choice_select_prompt"]

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this postprocessor."""
        return "LLMRerank"

    async def postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: QueryBundle | None = None,
        query_str: str | None = None,
    ) -> list[NodeWithScore]:
        """Postprocess nodes.

        This is a coroutine and must be awaited. Exactly one of
        ``query_bundle`` and ``query_str`` should be given.

        Raises:
            ValueError: If both ``query_str`` and ``query_bundle`` are given,
                or (from ``_postprocess_nodes``) if neither is.
        """
        if query_str is not None and query_bundle is not None:
            raise ValueError("Cannot specify both query_str and query_bundle")
        if query_str is not None:
            query_bundle = QueryBundle(query_str)
        return await self._postprocess_nodes(nodes, query_bundle)

    async def _apredict_batch(self, context_str: str, query_str: str) -> str:
        """Ask the LLM to rank one formatted batch and return its raw answer."""
        raw_response = await self.llm.apredict(
            self.choice_select_prompt,
            context_str=context_str,
            query_str=query_str,
        )
        # Use the module logger rather than the root logger.
        logger.info("LLM rerank response: %s", raw_response)
        return raw_response

    async def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: QueryBundle | None = None,
    ) -> list[NodeWithScore]:
        """Rerank ``nodes`` batch by batch and return the ``top_n`` best.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")
        if not nodes:
            return []

        query_str = query_bundle.query_str
        batch_size = self.choice_batch_size
        initial_results: list[NodeWithScore] = []

        for start in range(0, len(nodes), batch_size):
            nodes_batch = [item.node for item in nodes[start : start + batch_size]]
            fmt_batch_str = self._format_node_batch_fn(nodes_batch)

            # call each batch independently
            raw_response = await self._apredict_batch(fmt_batch_str, query_str)
            try:
                raw_choices, relevances = self._parse_choice_select_answer_fn(
                    raw_response, len(nodes_batch)
                )
            except (IndexError, ValueError):
                # A malformed answer can fail on a missing token (IndexError)
                # or a non-numeric one (ValueError). Ask the LLM once more;
                # a second parse failure propagates to the caller.
                logger.warning("Could not parse rerank response; retrying batch")
                raw_response = await self._apredict_batch(fmt_batch_str, query_str)
                raw_choices, relevances = self._parse_choice_select_answer_fn(
                    raw_response, len(nodes_batch)
                )

            # Parsers may return no relevances; treat every choice as 1.0 then.
            relevances = relevances or [1.0] * len(raw_choices)
            for choice, relevance in zip(raw_choices, relevances):
                position = int(choice) - 1  # LLM choices are 1-based
                if not 0 <= position < len(nodes_batch):
                    # A choice of 0 would wrap around to the last node and a
                    # too-large one would raise; skip both instead.
                    logger.warning(
                        "Ignoring out-of-range rerank choice %s for batch of %d",
                        choice,
                        len(nodes_batch),
                    )
                    continue
                initial_results.append(
                    NodeWithScore(node=nodes_batch[position], score=relevance)
                )

        ranked = sorted(initial_results, key=lambda x: x.score or 0.0, reverse=True)
        return ranked[: self.top_n]