datastore/providers/llama_datastore.py (160 lines of code) (raw):

import json
import os
from typing import Dict, List, Optional, Type
from loguru import logger
from datastore.datastore import DataStore
from models.models import (
    DocumentChunk,
    DocumentChunkMetadata,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    Query,
    QueryResult,
    QueryWithEmbedding,
)
from llama_index.indices.base import BaseGPTIndex
from llama_index.indices.vector_store.base import GPTVectorStoreIndex
from llama_index.indices.query.schema import QueryBundle
from llama_index.response.schema import Response
from llama_index.data_structs.node_v2 import Node, DocumentRelationship, NodeWithScore
from llama_index.indices.registry import INDEX_STRUCT_TYPE_TO_INDEX_CLASS
from llama_index.data_structs.struct_type import IndexStructType
from llama_index.indices.response.builder import ResponseMode

# Module-level configuration, read once from the environment at import time.
INDEX_STRUCT_TYPE_STR = os.environ.get(
    "LLAMA_INDEX_TYPE", IndexStructType.SIMPLE_DICT.value
)
INDEX_JSON_PATH = os.environ.get("LLAMA_INDEX_JSON_PATH", None)
QUERY_KWARGS_JSON_PATH = os.environ.get("LLAMA_QUERY_KWARGS_JSON_PATH", None)
RESPONSE_MODE = os.environ.get("LLAMA_RESPONSE_MODE", ResponseMode.NO_TEXT.value)

# Index struct types that `_create_or_load_index` refuses to build.
EXTERNAL_VECTOR_STORE_INDEX_STRUCT_TYPES = [
    IndexStructType.DICT,
    IndexStructType.WEAVIATE,
    IndexStructType.PINECONE,
    IndexStructType.QDRANT,
    IndexStructType.CHROMA,
    IndexStructType.VECTOR_STORE,
]


def _create_or_load_index(
    index_type_str: Optional[str] = None,
    index_json_path: Optional[str] = None,
    index_type_to_index_cls: Optional[Dict[str, Type[BaseGPTIndex]]] = None,
) -> BaseGPTIndex:
    """Create an empty index, or load one from a JSON file on disk.

    Args:
        index_type_str: Value of an ``IndexStructType`` member. Falls back to
            ``INDEX_STRUCT_TYPE_STR`` when falsy.
        index_json_path: Path of a saved index. Falls back to
            ``INDEX_JSON_PATH`` when falsy; if both are unset an empty index
            is created.
        index_type_to_index_cls: Mapping from index struct type to index
            class. Falls back to ``INDEX_STRUCT_TYPE_TO_INDEX_CLASS``.

    Returns:
        The newly created or loaded index.

    Raises:
        ValueError: If the index type is not in the mapping, or if it is one
            of ``EXTERNAL_VECTOR_STORE_INDEX_STRUCT_TYPES``.
    """
    index_json_path = index_json_path or INDEX_JSON_PATH
    index_type_to_index_cls = (
        index_type_to_index_cls or INDEX_STRUCT_TYPE_TO_INDEX_CLASS
    )
    index_type_str = index_type_str or INDEX_STRUCT_TYPE_STR
    index_type = IndexStructType(index_type_str)

    if index_type not in index_type_to_index_cls:
        raise ValueError(f"Unknown index type: {index_type}")
    if index_type in EXTERNAL_VECTOR_STORE_INDEX_STRUCT_TYPES:
        raise ValueError("Please use vector store directly.")

    index_cls = index_type_to_index_cls[index_type]
    if index_json_path is None:
        return index_cls(nodes=[])  # Create empty index
    return index_cls.load_from_disk(index_json_path)  # Load index from disk


def _create_or_load_query_kwargs(
    query_kwargs_json_path: Optional[str] = None,
) -> Optional[dict]:
    """Load query keyword arguments from a JSON file, if one is configured.

    Args:
        query_kwargs_json_path: Path of a JSON file holding the query kwargs.
            Falls back to ``QUERY_KWARGS_JSON_PATH`` when falsy.

    Returns:
        The parsed JSON content, or ``None`` when no path is configured.
    """
    query_kwargs_json_path = query_kwargs_json_path or QUERY_KWARGS_JSON_PATH
    if query_kwargs_json_path is None:
        return None
    # Read the query-kwargs file itself; the index JSON path is a different
    # file and may not even be set.
    with open(query_kwargs_json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _doc_chunk_to_node(doc_chunk: DocumentChunk, source_doc_id: str) -> Node:
    """Convert a document chunk to a Node linked to its source document."""
    return Node(
        doc_id=doc_chunk.id,
        text=doc_chunk.text,
        embedding=doc_chunk.embedding,
        extra_info=doc_chunk.metadata.dict(),
        relationships={DocumentRelationship.SOURCE: source_doc_id},
    )


def _query_with_embedding_to_query_bundle(query: QueryWithEmbedding) -> QueryBundle:
    """Build a QueryBundle from the query's text and embedding."""
    return QueryBundle(
        query_str=query.query,
        embedding=query.embedding,
    )


def _source_node_to_doc_chunk_with_score(
    node_with_score: NodeWithScore,
) -> DocumentChunkWithScore:
    """Convert a scored source node to a scored document chunk.

    A missing score is reported as ``1.0``; missing ``extra_info`` yields
    default (empty) chunk metadata.
    """
    node = node_with_score.node
    if node.extra_info is not None:
        metadata = DocumentChunkMetadata(**node.extra_info)
    else:
        metadata = DocumentChunkMetadata()

    return DocumentChunkWithScore(
        id=node.doc_id,
        text=node.text,
        score=node_with_score.score if node_with_score.score is not None else 1.0,
        metadata=metadata,
    )


def _response_to_query_result(
    response: Response, query: QueryWithEmbedding
) -> QueryResult:
    """Wrap the response's source nodes in a QueryResult for ``query``."""
    results = [
        _source_node_to_doc_chunk_with_score(node) for node in response.source_nodes
    ]
    return QueryResult(
        query=query.query,
        results=results,
    )


class LlamaDataStore(DataStore):
    """DataStore implementation backed by a llama_index index."""

    def __init__(
        self, index: Optional[BaseGPTIndex] = None, query_kwargs: Optional[dict] = None
    ):
        """Initialise the datastore.

        Args:
            index: Index to use. When falsy, one is created or loaded via
                ``_create_or_load_index()``.
            query_kwargs: Extra keyword arguments passed to every index
                query. When falsy, they are loaded via
                ``_create_or_load_query_kwargs()`` (which may return None).
        """
        self._index = index or _create_or_load_index()
        self._query_kwargs = query_kwargs or _create_or_load_query_kwargs()

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """Insert document chunks into the index.

        Args:
            chunks: Mapping from document id to that document's chunks.

        Returns:
            The document ids, in the iteration order of ``chunks``.
        """
        doc_ids = []
        for doc_id, doc_chunks in chunks.items():
            logger.debug(f"Upserting {doc_id} with {len(doc_chunks)} chunks")

            nodes = [
                _doc_chunk_to_node(doc_chunk=doc_chunk, source_doc_id=doc_id)
                for doc_chunk in doc_chunks
            ]

            self._index.insert_nodes(nodes)
            doc_ids.append(doc_id)
        return doc_ids

    async def _query(
        self,
        queries: List[QueryWithEmbedding],
    ) -> List[QueryResult]:
        """Run each query against the index.

        Filters on a query are not applied; a warning is logged instead.

        Args:
            queries: Queries with embeddings (and optionally filters).

        Returns:
            One QueryResult per input query, in the same order.
        """
        query_result_all = []
        for query in queries:
            if query.filter is not None:
                logger.warning("Filters are not supported yet, ignoring for now.")

            query_bundle = _query_with_embedding_to_query_bundle(query)

            # Setup query kwargs. Work on a per-query copy so that adding
            # `similarity_top_k` below never mutates the stored (possibly
            # caller-supplied) dict.
            if self._query_kwargs is not None:
                query_kwargs = dict(self._query_kwargs)
            else:
                query_kwargs = {}

            # TODO: support top_k for other indices
            if isinstance(self._index, GPTVectorStoreIndex):
                query_kwargs["similarity_top_k"] = query.top_k

            response = await self._index.aquery(
                query_bundle, response_mode=RESPONSE_MODE, **query_kwargs
            )

            query_result = _response_to_query_result(response, query)
            query_result_all.append(query_result)

        return query_result_all

    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """Remove entries from the index by id.

        Deleting everything and deleting by filter are not supported; both
        log a warning and return False without touching the index.

        Args:
            ids: Ids to delete from the index.
            filter: Unsupported; any non-None value makes the call fail.
            delete_all: Unsupported; any truthy value makes the call fail.

        Returns:
            False if an unsupported mode was requested or the index raised
            ``NotImplementedError`` on delete (ids deleted before that point
            stay deleted); True otherwise.
        """
        if delete_all:
            logger.warning("Delete all not supported yet.")
            return False

        if filter is not None:
            logger.warning("Filters are not supported yet.")
            return False

        if ids is not None:
            for id_ in ids:
                try:
                    self._index.delete(id_)
                except NotImplementedError:
                    # NOTE: some indices do not support delete yet.
                    logger.warning(f"{type(self._index)} does not support delete yet.")
                    return False

        return True