gemini/sample-apps/llamaindex-rag/backend/rag/async_extensions.py (144 lines of code) (raw):

"""Extensions to Llamaindex Base classes to allow for asynchronous execution"""

from collections.abc import Sequence
import inspect
import logging

from llama_index.core.base.response.schema import RESPONSE_TYPE
from llama_index.core.callbacks import CallbackManager
from llama_index.core.indices.query.query_transform.base import BaseQueryTransform
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.default_prompts import DEFAULT_HYDE_PROMPT
from llama_index.core.prompts.mixin import PromptDictType, PromptMixinType
from llama_index.core.query_engine import BaseQueryEngine, RetrieverQueryEngine
from llama_index.core.schema import NodeWithScore, QueryBundle, QueryType
from llama_index.core.service_context_elements.llm_predictor import LLMPredictorType
from llama_index.core.settings import Settings
from pydantic import Field

logging.basicConfig(level=logging.INFO)  # Set the desired logging level
logger = logging.getLogger(__name__)


class AsyncTransformQueryEngine(BaseQueryEngine):
    """Transform query engine.

    Applies a query transform to a query bundle before passing
    it to a query engine.

    The async entry points await the transform's ``_arun`` coroutine when the
    transform defines one, and fall back to its synchronous ``run`` method
    otherwise, so any ``BaseQueryTransform`` can be plugged in.

    Args:
        query_engine (BaseQueryEngine): A query engine object.
        query_transform (BaseQueryTransform): A query transform object.
        transform_metadata (Optional[dict]): metadata to pass to the
            query transform.
        callback_manager (Optional[CallbackManager]): A callback manager.
    """

    callback_manager: CallbackManager = Field(
        default_factory=lambda: CallbackManager([]), exclude=True
    )

    def __init__(
        self,
        query_engine: BaseQueryEngine,
        query_transform: BaseQueryTransform,
        transform_metadata: dict | None = None,
        callback_manager: CallbackManager | None = None,
    ) -> None:
        self._query_engine = query_engine
        self._query_transform = query_transform
        self._transform_metadata = transform_metadata
        super().__init__(callback_manager)

    def _get_prompt_modules(self) -> PromptMixinType:
        """Get prompt sub-modules."""
        return {
            "query_transform": self._query_transform,
            "query_engine": self._query_engine,
        }

    async def _atransform(
        self, query_bundle: QueryBundle, metadata: dict | None
    ) -> QueryBundle:
        """Apply the query transform, asynchronously when supported.

        Args:
            query_bundle: The query to transform.
            metadata: Metadata forwarded to the transform.

        Returns:
            The transformed query bundle.
        """
        arun = getattr(self._query_transform, "_arun", None)
        if arun is None:
            # Plain BaseQueryTransform: only the synchronous API is available.
            return self._query_transform.run(query_bundle, metadata=metadata)
        return await arun(query_bundle, metadata=metadata)

    async def aretrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        """Transform the query, then retrieve nodes from the wrapped engine."""
        query_bundle = await self._atransform(query_bundle, self._transform_metadata)
        return await self._query_engine.aretrieve(query_bundle)

    def synthesize(
        self,
        query_bundle: QueryBundle,
        nodes: list[NodeWithScore],
        additional_source_nodes: Sequence[NodeWithScore] | None = None,
    ) -> RESPONSE_TYPE:
        """Transform the query, then synthesize a response from ``nodes``."""
        query_bundle = self._query_transform.run(
            query_bundle, metadata=self._transform_metadata
        )
        return self._query_engine.synthesize(
            query_bundle=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    async def arun(
        self,
        query_bundle_or_str: QueryType,
        metadata: dict | None = None,
    ) -> QueryBundle:
        """Run query transform.

        Args:
            query_bundle_or_str: A query string or an existing ``QueryBundle``.
                A string is wrapped in a bundle that also uses it as the
                embedding string.
            metadata: Metadata forwarded to the transform (defaults to ``{}``).

        Returns:
            The transformed query bundle.
        """
        metadata = metadata or {}
        if isinstance(query_bundle_or_str, str):
            query_bundle = QueryBundle(
                query_str=query_bundle_or_str,
                custom_embedding_strs=[query_bundle_or_str],
            )
        else:
            query_bundle = query_bundle_or_str
        return await self._atransform(query_bundle, metadata)

    async def asynthesize(
        self,
        query_bundle: QueryBundle,
        nodes: list[NodeWithScore],
        additional_source_nodes: Sequence[NodeWithScore] | None = None,
    ) -> RESPONSE_TYPE:
        """Transform the query, then asynchronously synthesize a response."""
        query_bundle = await self._atransform(query_bundle, self._transform_metadata)
        return await self._query_engine.asynthesize(
            query_bundle=query_bundle,
            nodes=nodes,
            additional_source_nodes=additional_source_nodes,
        )

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        query_bundle = self._query_transform.run(
            query_bundle, metadata=self._transform_metadata
        )
        return self._query_engine.query(query_bundle)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query."""
        query_bundle = await self._atransform(query_bundle, self._transform_metadata)
        return await self._query_engine.aquery(query_bundle)


class AsyncHyDEQueryTransform(BaseQueryTransform):
    """Hypothetical Document Embeddings (HyDE) query transform.

    It uses an LLM to generate hypothetical answer(s) to a given query,
    and use the resulting documents as embedding strings.

    As described in `[Precise Zero-Shot Dense Retrieval without Relevance Labels]
    (https://arxiv.org/abs/2212.10496)`
    """

    def __init__(
        self,
        llm: LLMPredictorType | None = None,
        hyde_prompt: BasePromptTemplate | None = None,
        include_original: bool = True,
    ) -> None:
        """Initialize HyDEQueryTransform.

        Args:
            llm (Optional[LLMPredictorType]): LLM for generating
                hypothetical documents. Falls back to ``Settings.llm``.
            hyde_prompt (Optional[BasePromptTemplate]): Custom prompt for HyDE.
                Falls back to ``DEFAULT_HYDE_PROMPT``.
            include_original (bool): Whether to include original query
                string as one of the embedding strings
        """
        super().__init__()

        self._llm = llm or Settings.llm
        self._hyde_prompt = hyde_prompt or DEFAULT_HYDE_PROMPT
        self._include_original = include_original

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {"hyde_prompt": self._hyde_prompt}

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "hyde_prompt" in prompts:
            self._hyde_prompt = prompts["hyde_prompt"]

    def _build_query_bundle(
        self, query_bundle: QueryBundle, hypothetical_doc: str
    ) -> QueryBundle:
        """Build the output bundle from the generated hypothetical document.

        The hypothetical document is the first embedding string; the original
        bundle's embedding strings are appended when ``include_original`` is set.
        """
        embedding_strs = [hypothetical_doc]
        if self._include_original:
            embedding_strs.extend(query_bundle.embedding_strs)
        return QueryBundle(
            query_str=query_bundle.query_str,
            custom_embedding_strs=embedding_strs,
        )

    def _run(self, query_bundle: QueryBundle, metadata: dict) -> QueryBundle:
        """Run query transform."""
        # TODO: support generating multiple hypothetical docs
        hypothetical_doc = self._llm.predict(
            self._hyde_prompt, context_str=query_bundle.query_str
        )
        return self._build_query_bundle(query_bundle, hypothetical_doc)

    async def _arun(
        self, query_bundle: QueryBundle, metadata: dict | None = None
    ) -> QueryBundle:
        """Run query transform asynchronously.

        ``metadata`` is accepted (and unused) so this matches ``_run`` and the
        ``_arun(query_bundle, metadata=...)`` calls made by
        ``AsyncTransformQueryEngine``.
        """
        # TODO: support generating multiple hypothetical docs
        hypothetical_doc = await self._llm.apredict(
            self._hyde_prompt, context_str=query_bundle.query_str
        )
        return self._build_query_bundle(query_bundle, hypothetical_doc)


class AsyncRetrieverQueryEngine(RetrieverQueryEngine):
    """Async extension of the RetrieverQueryEngine to allow for asynchronous post-processing"""

    async def _apply_node_postprocessors(
        self, nodes: list[NodeWithScore], query_bundle: QueryBundle
    ) -> list[NodeWithScore]:
        """Apply node postprocessors in order.

        A postprocessor's ``postprocess_nodes`` result is awaited when it is
        awaitable (async postprocessors) and used directly otherwise, so
        synchronous and asynchronous postprocessors can be mixed.

        NOTE(review): this replaces the parent's method with a coroutine
        function; if inherited synchronous paths (e.g. ``retrieve``) call it,
        they will get a coroutine instead of a node list — verify and prefer
        the async APIs on this engine.
        """
        for node_postprocessor in self._node_postprocessors:
            result = node_postprocessor.postprocess_nodes(
                nodes, query_bundle=query_bundle
            )
            nodes = await result if inspect.isawaitable(result) else result
        return nodes

    async def aretrieve(self, query_bundle: QueryBundle) -> list[NodeWithScore]:
        """Retrieve nodes asynchronously, then apply the node postprocessors."""
        nodes = await self._retriever.aretrieve(query_bundle)
        logger.info("Total nodes retrieved %d", len(nodes))
        return await self._apply_node_postprocessors(nodes, query_bundle=query_bundle)