datastore/providers/milvus_datastore.py (413 lines of code) (raw):

import json import os import asyncio from loguru import logger from typing import Dict, List, Optional from pymilvus import ( Collection, connections, utility, FieldSchema, DataType, CollectionSchema, MilvusException, ) from uuid import uuid4 from services.date import to_unix_timestamp from datastore.datastore import DataStore from models.models import ( DocumentChunk, DocumentChunkMetadata, Source, DocumentMetadataFilter, QueryResult, QueryWithEmbedding, DocumentChunkWithScore, ) MILVUS_COLLECTION = os.environ.get("MILVUS_COLLECTION") or "c" + uuid4().hex MILVUS_HOST = os.environ.get("MILVUS_HOST") or "localhost" MILVUS_PORT = os.environ.get("MILVUS_PORT") or 19530 MILVUS_USER = os.environ.get("MILVUS_USER") MILVUS_PASSWORD = os.environ.get("MILVUS_PASSWORD") MILVUS_USE_SECURITY = False if MILVUS_PASSWORD is None else True MILVUS_INDEX_PARAMS = os.environ.get("MILVUS_INDEX_PARAMS") MILVUS_SEARCH_PARAMS = os.environ.get("MILVUS_SEARCH_PARAMS") MILVUS_CONSISTENCY_LEVEL = os.environ.get("MILVUS_CONSISTENCY_LEVEL") UPSERT_BATCH_SIZE = 100 OUTPUT_DIM = int(os.environ.get("EMBEDDING_DIMENSION", 256)) EMBEDDING_FIELD = "embedding" class Required: pass # The fields names that we are going to be storing within Milvus, the field declaration for schema creation, and the default value SCHEMA_V1 = [ ( "pk", FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True), Required, ), ( EMBEDDING_FIELD, FieldSchema(name=EMBEDDING_FIELD, dtype=DataType.FLOAT_VECTOR, dim=OUTPUT_DIM), Required, ), ( "text", FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), Required, ), ( "document_id", FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=65535), "", ), ( "source_id", FieldSchema(name="source_id", dtype=DataType.VARCHAR, max_length=65535), "", ), ( "id", FieldSchema( name="id", dtype=DataType.VARCHAR, max_length=65535, ), "", ), ( "source", FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=65535), "", ), ("url", FieldSchema(name="url", dtype=DataType.VARCHAR, max_length=65535), ""), ("created_at", FieldSchema(name="created_at", dtype=DataType.INT64), -1), ( "author", FieldSchema(name="author", dtype=DataType.VARCHAR, max_length=65535), "", ), ] # V2 schema, remomve the "pk" field SCHEMA_V2 = SCHEMA_V1[1:] SCHEMA_V2[4][1].is_primary = True class MilvusDataStore(DataStore): def __init__( self, create_new: Optional[bool] = False, consistency_level: str = "Bounded", ): """Create a Milvus DataStore. The Milvus Datastore allows for storing your indexes and metadata within a Milvus instance. Args: create_new (Optional[bool], optional): Whether to overwrite if collection already exists. Defaults to True. consistency_level(str, optional): Specify the collection consistency level. Defaults to "Bounded" for search performance. Set to "Strong" in test cases for result validation. """ # Overwrite the default consistency level by MILVUS_CONSISTENCY_LEVEL self._consistency_level = MILVUS_CONSISTENCY_LEVEL or consistency_level self._create_connection() self._create_collection(MILVUS_COLLECTION, create_new) # type: ignore self._create_index() def _get_schema(self): return SCHEMA_V1 if self._schema_ver == "V1" else SCHEMA_V2 def _create_connection(self): try: self.alias = "" # Check if the connection already exists for x in connections.list_connections(): addr = connections.get_connection_addr(x[0]) if ( x[1] and ("address" in addr) and (addr["address"] == "{}:{}".format(MILVUS_HOST, MILVUS_PORT)) ): self.alias = x[0] logger.info( "Reuse connection to Milvus server '{}:{}' with alias '{:s}'".format( MILVUS_HOST, MILVUS_PORT, self.alias ) ) break # Connect to the Milvus instance using the passed in Environment variables if len(self.alias) == 0: self.alias = uuid4().hex connections.connect( alias=self.alias, host=MILVUS_HOST, port=MILVUS_PORT, user=MILVUS_USER, # type: ignore password=MILVUS_PASSWORD, # type: ignore secure=MILVUS_USE_SECURITY, ) logger.info( "Create connection to Milvus server '{}:{}' with alias '{:s}'".format( MILVUS_HOST, MILVUS_PORT, self.alias ) ) except Exception as e: logger.error( "Failed to create connection to Milvus server '{}:{}', error: {}".format( MILVUS_HOST, MILVUS_PORT, e ) ) def _create_collection(self, collection_name, create_new: bool) -> None: """Create a collection based on environment and passed in variables. Args: create_new (bool): Whether to overwrite if collection already exists. """ try: self._schema_ver = "V1" # If the collection exists and create_new is True, drop the existing collection if utility.has_collection(collection_name, using=self.alias) and create_new: utility.drop_collection(collection_name, using=self.alias) # Check if the collection doesnt exist if utility.has_collection(collection_name, using=self.alias) is False: # If it doesnt exist use the field params from init to create a new schem schema = [field[1] for field in SCHEMA_V2] schema = CollectionSchema(schema) # Use the schema to create a new collection self.col = Collection( collection_name, schema=schema, using=self.alias, consistency_level=self._consistency_level, ) self._schema_ver = "V2" logger.info( "Create Milvus collection '{}' with schema {} and consistency level {}".format( collection_name, self._schema_ver, self._consistency_level ) ) else: # If the collection exists, point to it self.col = Collection(collection_name, using=self.alias) # type: ignore # Which sechma is used for field in self.col.schema.fields: if field.name == "id" and field.is_primary: self._schema_ver = "V2" break logger.info( "Milvus collection '{}' already exists with schema {}".format( collection_name, self._schema_ver ) ) except Exception as e: logger.error( "Failed to create collection '{}', error: {}".format(collection_name, e) ) def _create_index(self): # TODO: verify index/search params passed by os.environ self.index_params = MILVUS_INDEX_PARAMS or None self.search_params = MILVUS_SEARCH_PARAMS or None try: # If no index on the collection, create one if len(self.col.indexes) == 0: if self.index_params is not None: # Convert the string format to JSON format parameters passed by MILVUS_INDEX_PARAMS self.index_params = json.loads(self.index_params) logger.info("Create Milvus index: {}".format(self.index_params)) # Create an index on the 'embedding' field with the index params found in init self.col.create_index( EMBEDDING_FIELD, index_params=self.index_params ) else: # If no index param supplied, to first create an HNSW index for Milvus try: i_p = { "metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}, } logger.info( "Attempting creation of Milvus '{}' index".format( i_p["index_type"] ) ) self.col.create_index(EMBEDDING_FIELD, index_params=i_p) self.index_params = i_p logger.info( "Creation of Milvus '{}' index successful".format( i_p["index_type"] ) ) # If create fails, most likely due to being Zilliz Cloud instance, try to create an AutoIndex except MilvusException: logger.info("Attempting creation of Milvus default index") i_p = { "metric_type": "IP", "index_type": "AUTOINDEX", "params": {}, } self.col.create_index(EMBEDDING_FIELD, index_params=i_p) self.index_params = i_p logger.info("Creation of Milvus default index successful") # If an index already exists, grab its params else: # How about if the first index is not vector index? for index in self.col.indexes: idx = index.to_dict() if idx["field"] == EMBEDDING_FIELD: logger.info("Index already exists: {}".format(idx)) self.index_params = idx["index_param"] break self.col.load() if self.search_params is not None: # Convert the string format to JSON format parameters passed by MILVUS_SEARCH_PARAMS self.search_params = json.loads(self.search_params) else: # The default search params metric_type = "IP" if "metric_type" in self.index_params: metric_type = self.index_params["metric_type"] default_search_params = { "IVF_FLAT": {"metric_type": metric_type, "params": {"nprobe": 10}}, "IVF_SQ8": {"metric_type": metric_type, "params": {"nprobe": 10}}, "IVF_PQ": {"metric_type": metric_type, "params": {"nprobe": 10}}, "HNSW": {"metric_type": metric_type, "params": {"ef": 10}}, "RHNSW_FLAT": {"metric_type": metric_type, "params": {"ef": 10}}, "RHNSW_SQ": {"metric_type": metric_type, "params": {"ef": 10}}, "RHNSW_PQ": {"metric_type": metric_type, "params": {"ef": 10}}, "IVF_HNSW": { "metric_type": metric_type, "params": {"nprobe": 10, "ef": 10}, }, "ANNOY": {"metric_type": metric_type, "params": {"search_k": 10}}, "AUTOINDEX": {"metric_type": metric_type, "params": {}}, } # Set the search params self.search_params = default_search_params[ self.index_params["index_type"] ] logger.info("Milvus search parameters: {}".format(self.search_params)) except Exception as e: logger.error("Failed to create index, error: {}".format(e)) async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: """Upsert chunks into the datastore. Args: chunks (Dict[str, List[DocumentChunk]]): A list of DocumentChunks to insert Raises: e: Error in upserting data. Returns: List[str]: The document_id's that were inserted. """ try: # The doc id's to return for the upsert doc_ids: List[str] = [] # List to collect all the insert data, skip the "pk" for schema V1 offset = 1 if self._schema_ver == "V1" else 0 insert_data = [[] for _ in range(len(self._get_schema()) - offset)] # Go through each document chunklist and grab the data for doc_id, chunk_list in chunks.items(): # Append the doc_id to the list we are returning doc_ids.append(doc_id) # Examine each chunk in the chunklist for chunk in chunk_list: # Extract data from the chunk list_of_data = self._get_values(chunk) # Check if the data is valid if list_of_data is not None: # Append each field to the insert_data for x in range(len(insert_data)): insert_data[x].append(list_of_data[x]) # Slice up our insert data into batches batches = [ insert_data[i : i + UPSERT_BATCH_SIZE] for i in range(0, len(insert_data), UPSERT_BATCH_SIZE) ] # Attempt to insert each batch into our collection # batch data can work with both V1 and V2 schema for batch in batches: if len(batch[0]) != 0: try: logger.info(f"Upserting batch of size {len(batch[0])}") self.col.insert(batch) logger.info(f"Upserted batch successfully") except Exception as e: logger.error(f"Failed to insert batch records, error: {e}") raise e # This setting perfoms flushes after insert. Small insert == bad to use # self.col.flush() return doc_ids except Exception as e: logger.error("Failed to insert records, error: {}".format(e)) return [] def _get_values(self, chunk: DocumentChunk) -> List[any] | None: # type: ignore """Convert the chunk into a list of values to insert whose indexes align with fields. Args: chunk (DocumentChunk): The chunk to convert. Returns: List (any): The values to insert. """ # Convert DocumentChunk and its sub models to dict values = chunk.dict() # Unpack the metadata into the same dict meta = values.pop("metadata") values.update(meta) # Convert date to int timestamp form if values["created_at"]: values["created_at"] = to_unix_timestamp(values["created_at"]) # If source exists, change from Source object to the string value it holds if values["source"]: values["source"] = values["source"].value # List to collect data we will return ret = [] # Grab data responding to each field, excluding the hidden auto pk field for schema V1 offset = 1 if self._schema_ver == "V1" else 0 for key, _, default in self._get_schema()[offset:]: # Grab the data at the key and default to our defaults set in init x = values.get(key) or default # If one of our required fields is missing, ignore the entire entry if x is Required: logger.info("Chunk " + values["id"] + " missing " + key + " skipping") return None # Add the corresponding value if it passes the tests ret.append(x) return ret async def _query( self, queries: List[QueryWithEmbedding], ) -> List[QueryResult]: """Query the QueryWithEmbedding against the MilvusDocumentSearch Search the embedding and its filter in the collection. Args: queries (List[QueryWithEmbedding]): The list of searches to perform. Returns: List[QueryResult]: Results for each search. """ # Async to perform the query, adapted from pinecone implementation async def _single_query(query: QueryWithEmbedding) -> QueryResult: try: filter = None # Set the filter to expression that is valid for Milvus if query.filter is not None: # Either a valid filter or None will be returned filter = self._get_filter(query.filter) # Perform our search return_from = 2 if self._schema_ver == "V1" else 1 res = self.col.search( data=[query.embedding], anns_field=EMBEDDING_FIELD, param=self.search_params, limit=query.top_k, expr=filter, output_fields=[ field[0] for field in self._get_schema()[return_from:] ], # Ignoring pk, embedding ) # Results that will hold our DocumentChunkWithScores results = [] # Parse every result for our search for hit in res[0]: # type: ignore # The distance score for the search result, falls under DocumentChunkWithScore score = hit.score # Our metadata info, falls under DocumentChunkMetadata metadata = {} # Grab the values that correspond to our fields, ignore pk and embedding. for x in [field[0] for field in self._get_schema()[return_from:]]: metadata[x] = hit.entity.get(x) # If the source isn't valid, convert to None if metadata["source"] not in Source.__members__: metadata["source"] = None # Text falls under the DocumentChunk text = metadata.pop("text") # Id falls under the DocumentChunk ids = metadata.pop("id") chunk = DocumentChunkWithScore( id=ids, score=score, text=text, metadata=DocumentChunkMetadata(**metadata), ) results.append(chunk) # TODO: decide on doing queries to grab the embedding itself, slows down performance as double query occurs return QueryResult(query=query.query, results=results) except Exception as e: logger.error("Failed to query, error: {}".format(e)) return QueryResult(query=query.query, results=[]) results: List[QueryResult] = await asyncio.gather( *[_single_query(query) for query in queries] ) return results async def delete( self, ids: Optional[List[str]] = None, filter: Optional[DocumentMetadataFilter] = None, delete_all: Optional[bool] = None, ) -> bool: """Delete the entities based either on the chunk_id of the vector, Args: ids (Optional[List[str]], optional): The document_ids to delete. Defaults to None. filter (Optional[DocumentMetadataFilter], optional): The filter to delete by. Defaults to None. delete_all (Optional[bool], optional): Whether to drop the collection and recreate it. Defaults to None. """ # If deleting all, drop and create the new collection if delete_all: coll_name = self.col.name logger.info( "Delete the entire collection {} and create new one".format(coll_name) ) # Release the collection from memory self.col.release() # Drop the collection self.col.drop() # Recreate the new collection self._create_collection(coll_name, True) self._create_index() return True # Keep track of how many we have deleted for later printing delete_count = 0 batch_size = 100 pk_name = "pk" if self._schema_ver == "V1" else "id" try: # According to the api design, the ids is a list of document_id, # document_id is not primary key, use query+delete to workaround, # in future version we can delete by expression if (ids is not None) and len(ids) > 0: # Add quotation marks around the string format id ids = ['"' + str(id) + '"' for id in ids] # Query for the pk's of entries that match id's ids = self.col.query(f"document_id in [{','.join(ids)}]") # Convert to list of pks pks = [str(entry[pk_name]) for entry in ids] # type: ignore # for schema V2, the "id" is varchar, rewrite the expression if self._schema_ver != "V1": pks = ['"' + pk + '"' for pk in pks] # Delete by ids batch by batch(avoid too long expression) logger.info( "Apply {:d} deletions to schema {:s}".format( len(pks), self._schema_ver ) ) while len(pks) > 0: batch_pks = pks[:batch_size] pks = pks[batch_size:] # Delete the entries batch by batch res = self.col.delete(f"{pk_name} in [{','.join(batch_pks)}]") # Increment our deleted count delete_count += int(res.delete_count) # type: ignore except Exception as e: logger.error("Failed to delete by ids, error: {}".format(e)) try: # Check if empty filter if filter is not None: # Convert filter to milvus expression filter = self._get_filter(filter) # type: ignore # Check if there is anything to filter if len(filter) != 0: # type: ignore # Query for the pk's of entries that match filter res = self.col.query(filter) # type: ignore # Convert to list of pks pks = [str(entry[pk_name]) for entry in res] # type: ignore # for schema V2, the "id" is varchar, rewrite the expression if self._schema_ver != "V1": pks = ['"' + pk + '"' for pk in pks] # Check to see if there are valid pk's to delete, delete batch by batch(avoid too long expression) while len(pks) > 0: # type: ignore batch_pks = pks[:batch_size] pks = pks[batch_size:] # Delete the entries batch by batch res = self.col.delete(f"{pk_name} in [{','.join(batch_pks)}]") # type: ignore # Increment our delete count delete_count += int(res.delete_count) # type: ignore except Exception as e: logger.error("Failed to delete by filter, error: {}".format(e)) logger.info("{:d} records deleted".format(delete_count)) # This setting performs flushes after delete. Small delete == bad to use # self.col.flush() return True def _get_filter(self, filter: DocumentMetadataFilter) -> Optional[str]: """Converts a DocumentMetdataFilter to the expression that Milvus takes. Args: filter (DocumentMetadataFilter): The Filter to convert to Milvus expression. Returns: Optional[str]: The filter if valid, otherwise None. """ filters = [] # Go through all the fields and their values for field, value in filter.dict().items(): # Check if the Value is empty if value is not None: # Convert start_date to int and add greater than or equal logic if field == "start_date": filters.append( "(created_at >= " + str(to_unix_timestamp(value)) + ")" ) # Convert end_date to int and add less than or equal logic elif field == "end_date": filters.append( "(created_at <= " + str(to_unix_timestamp(value)) + ")" ) # Convert Source to its string value and check equivalency elif field == "source": filters.append("(" + field + ' == "' + str(value.value) + '")') # Check equivalency of rest of string fields else: filters.append("(" + field + ' == "' + str(value) + '")') # Join all our expressions with `and`` return " and ".join(filters)