datastore/providers/pinecone_datastore.py

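"""Pinecone implementation of the DataStore interface.

Stores document chunks as vectors in a Pinecone index, supporting batched
upserts, filtered similarity queries, and deletion by ids, metadata filter,
or delete_all. Configuration is read from the PINECONE_API_KEY,
PINECONE_ENVIRONMENT, PINECONE_INDEX, and optional EMBEDDING_DIMENSION
environment variables.
"""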
import os
from typing import Any, Dict, List, Optional
import pinecone
from tenacity import retry, wait_random_exponential, stop_after_attempt
import asyncio
from loguru import logger

from datastore.datastore import DataStore
from models.models import (
    DocumentChunk,
    DocumentChunkMetadata,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
    Source,
)
from services.date import to_unix_timestamp

# Read environment variables for Pinecone configuration
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT")
PINECONE_INDEX = os.environ.get("PINECONE_INDEX")
assert PINECONE_API_KEY is not None
assert PINECONE_ENVIRONMENT is not None
assert PINECONE_INDEX is not None

# Initialize Pinecone with the API key and environment
pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENVIRONMENT)

# Set the batch size for upserting vectors to Pinecone
UPSERT_BATCH_SIZE = 100
EMBEDDING_DIMENSION = int(os.environ.get("EMBEDDING_DIMENSION", 256))


class PineconeDataStore(DataStore):
    def __init__(self):
        # Check if the index name is specified and exists in Pinecone
        if PINECONE_INDEX and PINECONE_INDEX not in pinecone.list_indexes():
            # Get all fields in the metadata object in a list
            fields_to_index = list(DocumentChunkMetadata.__fields__.keys())

            # Create a new index with the specified name, dimension, and metadata configuration
            try:
                logger.info(
                    f"Creating index {PINECONE_INDEX} with metadata config {fields_to_index}"
                )
                pinecone.create_index(
                    PINECONE_INDEX,
                    dimension=EMBEDDING_DIMENSION,
                    metadata_config={"indexed": fields_to_index},
                )
                self.index = pinecone.Index(PINECONE_INDEX)
                logger.info(f"Index {PINECONE_INDEX} created successfully")
            except Exception as e:
                logger.error(f"Error creating index {PINECONE_INDEX}: {e}")
                raise e
        elif PINECONE_INDEX and PINECONE_INDEX in pinecone.list_indexes():
            # Connect to an existing index with the specified name
            try:
                logger.info(f"Connecting to existing index {PINECONE_INDEX}")
                self.index = pinecone.Index(PINECONE_INDEX)
                logger.info(f"Connected to index {PINECONE_INDEX} successfully")
            except Exception as e:
                logger.error(f"Error connecting to index {PINECONE_INDEX}: {e}")
                raise e

    @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3))
    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a dict from document id to list of document chunks and inserts them into the index.
        Return a list of document ids.
""" # Initialize a list of ids to return doc_ids: List[str] = [] # Initialize a list of vectors to upsert vectors = [] # Loop through the dict items for doc_id, chunk_list in chunks.items(): # Append the id to the ids list doc_ids.append(doc_id) logger.info(f"Upserting document_id: {doc_id}") for chunk in chunk_list: # Create a vector tuple of (id, embedding, metadata) # Convert the metadata object to a dict with unix timestamps for dates pinecone_metadata = self._get_pinecone_metadata(chunk.metadata) # Add the text and document id to the metadata dict pinecone_metadata["text"] = chunk.text pinecone_metadata["document_id"] = doc_id vector = (chunk.id, chunk.embedding, pinecone_metadata) vectors.append(vector) # Split the vectors list into batches of the specified size batches = [ vectors[i : i + UPSERT_BATCH_SIZE] for i in range(0, len(vectors), UPSERT_BATCH_SIZE) ] # Upsert each batch to Pinecone for batch in batches: try: logger.info(f"Upserting batch of size {len(batch)}") self.index.upsert(vectors=batch) logger.info(f"Upserted batch successfully") except Exception as e: logger.error(f"Error upserting batch: {e}") raise e return doc_ids @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3)) async def _query( self, queries: List[QueryWithEmbedding], ) -> List[QueryResult]: """ Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. """ # Define a helper coroutine that performs a single query and returns a QueryResult async def _single_query(query: QueryWithEmbedding) -> QueryResult: logger.debug(f"Query: {query.query}") # Convert the metadata filter object to a dict with pinecone filter expressions pinecone_filter = self._get_pinecone_filter(query.filter) try: # Query the index with the query embedding, filter, and top_k query_response = self.index.query( # namespace=namespace, top_k=query.top_k, vector=query.embedding, filter=pinecone_filter, include_metadata=True, ) except Exception as e: logger.error(f"Error querying index: {e}") raise e query_results: List[DocumentChunkWithScore] = [] for result in query_response.matches: score = result.score metadata = result.metadata # Remove document id and text from metadata and store it in a new variable metadata_without_text = ( {key: value for key, value in metadata.items() if key != "text"} if metadata else None ) # If the source is not a valid Source in the Source enum, set it to None if ( metadata_without_text and "source" in metadata_without_text and metadata_without_text["source"] not in Source.__members__ ): metadata_without_text["source"] = None # Create a document chunk with score object with the result data result = DocumentChunkWithScore( id=result.id, score=score, text=str(metadata["text"]) if metadata and "text" in metadata else "", metadata=metadata_without_text, ) query_results.append(result) return QueryResult(query=query.query, results=query_results) # Use asyncio.gather to run multiple _single_query coroutines concurrently and collect their results results: List[QueryResult] = await asyncio.gather( *[_single_query(query) for query in queries] ) return results @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(3)) async def delete( self, ids: Optional[List[str]] = None, filter: Optional[DocumentMetadataFilter] = None, delete_all: Optional[bool] = None, ) -> bool: """ Removes vectors by ids, filter, or everything from the index. 
""" # Delete all vectors from the index if delete_all is True if delete_all: try: logger.info(f"Deleting all vectors from index") self.index.delete(delete_all=True) logger.info(f"Deleted all vectors successfully") return True except Exception as e: logger.error(f"Error deleting all vectors: {e}") raise e # Convert the metadata filter object to a dict with pinecone filter expressions pinecone_filter = self._get_pinecone_filter(filter) # Delete vectors that match the filter from the index if the filter is not empty if pinecone_filter != {}: try: logger.info(f"Deleting vectors with filter {pinecone_filter}") self.index.delete(filter=pinecone_filter) logger.info(f"Deleted vectors with filter successfully") except Exception as e: logger.error(f"Error deleting vectors with filter: {e}") raise e # Delete vectors that match the document ids from the index if the ids list is not empty if ids is not None and len(ids) > 0: try: logger.info(f"Deleting vectors with ids {ids}") pinecone_filter = {"document_id": {"$in": ids}} self.index.delete(filter=pinecone_filter) # type: ignore logger.info(f"Deleted vectors with ids successfully") except Exception as e: logger.error(f"Error deleting vectors with ids: {e}") raise e return True def _get_pinecone_filter( self, filter: Optional[DocumentMetadataFilter] = None ) -> Dict[str, Any]: if filter is None: return {} pinecone_filter = {} # For each field in the MetadataFilter, check if it has a value and add the corresponding pinecone filter expression # For start_date and end_date, uses the $gte and $lte operators respectively # For other fields, uses the $eq operator for field, value in filter.dict().items(): if value is not None: if field == "start_date": pinecone_filter["created_at"] = pinecone_filter.get( "created_at", {} ) pinecone_filter["created_at"]["$gte"] = to_unix_timestamp(value) elif field == "end_date": pinecone_filter["created_at"] = pinecone_filter.get( "created_at", {} ) pinecone_filter["created_at"]["$lte"] = to_unix_timestamp(value) else: pinecone_filter[field] = value return pinecone_filter def _get_pinecone_metadata( self, metadata: Optional[DocumentChunkMetadata] = None ) -> Dict[str, Any]: if metadata is None: return {} pinecone_metadata = {} # For each field in the Metadata, check if it has a value and add it to the pinecone metadata dict # For fields that are dates, convert them to unix timestamps for field, value in metadata.dict().items(): if value is not None: if field in ["created_at"]: pinecone_metadata[field] = to_unix_timestamp(value) else: pinecone_metadata[field] = value return pinecone_metadata