datastore/providers/weaviate_datastore.py (307 lines of code) (raw):

import asyncio import os import re import uuid from typing import Dict, List, Optional import weaviate from loguru import logger from weaviate import Client from weaviate.util import generate_uuid5 from datastore.datastore import DataStore from models.models import ( DocumentChunk, DocumentChunkMetadata, DocumentChunkWithScore, DocumentMetadataFilter, QueryResult, QueryWithEmbedding, Source, ) WEAVIATE_URL_DEFAULT = "http://localhost:8080" WEAVIATE_CLASS = os.environ.get("WEAVIATE_CLASS", "OpenAIDocument") WEAVIATE_BATCH_SIZE = int(os.environ.get("WEAVIATE_BATCH_SIZE", 20)) WEAVIATE_BATCH_DYNAMIC = os.environ.get("WEAVIATE_BATCH_DYNAMIC", False) WEAVIATE_BATCH_TIMEOUT_RETRIES = int(os.environ.get("WEAVIATE_TIMEOUT_RETRIES", 3)) WEAVIATE_BATCH_NUM_WORKERS = int(os.environ.get("WEAVIATE_BATCH_NUM_WORKERS", 1)) SCHEMA = { "class": WEAVIATE_CLASS, "description": "The main class", "properties": [ { "name": "chunk_id", "dataType": ["string"], "description": "The chunk id", }, { "name": "document_id", "dataType": ["string"], "description": "The document id", }, { "name": "text", "dataType": ["text"], "description": "The chunk's text", }, { "name": "source", "dataType": ["string"], "description": "The source of the data", }, { "name": "source_id", "dataType": ["string"], "description": "The source id", }, { "name": "url", "dataType": ["string"], "description": "The source url", }, { "name": "created_at", "dataType": ["date"], "description": "Creation date of document", }, { "name": "author", "dataType": ["string"], "description": "Document author", }, ], } def extract_schema_properties(schema): properties = schema["properties"] return {property["name"] for property in properties} class WeaviateDataStore(DataStore): def handle_errors(self, results: Optional[List[dict]]) -> List[str]: if not self or not results: return [] error_messages = [] for result in results: if ( "result" not in result or "errors" not in result["result"] or "error" not in result["result"]["errors"] ): continue for message in result["result"]["errors"]["error"]: error_messages.append(message["message"]) logger.error(message["message"]) return error_messages def __init__(self): auth_credentials = self._build_auth_credentials() url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT) logger.debug( f"Connecting to weaviate instance at {url} with credential type {type(auth_credentials).__name__}" ) self.client = Client(url, auth_client_secret=auth_credentials) self.client.batch.configure( batch_size=WEAVIATE_BATCH_SIZE, dynamic=WEAVIATE_BATCH_DYNAMIC, # type: ignore callback=self.handle_errors, # type: ignore timeout_retries=WEAVIATE_BATCH_TIMEOUT_RETRIES, num_workers=WEAVIATE_BATCH_NUM_WORKERS, ) if self.client.schema.contains(SCHEMA): current_schema = self.client.schema.get(WEAVIATE_CLASS) current_schema_properties = extract_schema_properties(current_schema) logger.debug( f"Found index {WEAVIATE_CLASS} with properties {current_schema_properties}" ) logger.debug("Will reuse this schema") else: new_schema_properties = extract_schema_properties(SCHEMA) logger.debug( f"Creating collection {WEAVIATE_CLASS} with properties {new_schema_properties}" ) self.client.schema.create_class(SCHEMA) @staticmethod def _build_auth_credentials(): url = os.environ.get("WEAVIATE_URL", WEAVIATE_URL_DEFAULT) if WeaviateDataStore._is_wcs_domain(url): api_key = os.environ.get("WEAVIATE_API_KEY") if api_key is not None: return weaviate.auth.AuthApiKey(api_key=api_key) else: raise ValueError("WEAVIATE_API_KEY environment variable is not set") else: return None async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: """ Takes in a list of list of document chunks and inserts them into the database. Return a list of document ids. """ doc_ids = [] with self.client.batch as batch: for doc_id, doc_chunks in chunks.items(): logger.debug(f"Upserting {doc_id} with {len(doc_chunks)} chunks") for doc_chunk in doc_chunks: # we generate a uuid regardless of the format of the document_id because # weaviate needs a uuid to store each document chunk and # a document chunk cannot share the same uuid doc_uuid = generate_uuid5(doc_chunk, WEAVIATE_CLASS) metadata = doc_chunk.metadata doc_chunk_dict = doc_chunk.dict() doc_chunk_dict.pop("metadata") for key, value in metadata.dict().items(): doc_chunk_dict[key] = value doc_chunk_dict["chunk_id"] = doc_chunk_dict.pop("id") doc_chunk_dict["source"] = ( doc_chunk_dict.pop("source").value if doc_chunk_dict["source"] else None ) embedding = doc_chunk_dict.pop("embedding") batch.add_data_object( uuid=doc_uuid, data_object=doc_chunk_dict, class_name=WEAVIATE_CLASS, vector=embedding, ) doc_ids.append(doc_id) batch.flush() return doc_ids async def _query( self, queries: List[QueryWithEmbedding], ) -> List[QueryResult]: """ Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. """ async def _single_query(query: QueryWithEmbedding) -> QueryResult: logger.debug(f"Query: {query.query}") if not hasattr(query, "filter") or not query.filter: result = ( self.client.query.get( WEAVIATE_CLASS, [ "chunk_id", "document_id", "text", "source", "source_id", "url", "created_at", "author", ], ) .with_hybrid(query=query.query, alpha=0.5, vector=query.embedding) .with_limit(query.top_k) # type: ignore .with_additional(["score", "vector"]) .do() ) else: filters_ = self.build_filters(query.filter) result = ( self.client.query.get( WEAVIATE_CLASS, [ "chunk_id", "document_id", "text", "source", "source_id", "url", "created_at", "author", ], ) .with_hybrid(query=query.query, alpha=0.5, vector=query.embedding) .with_where(filters_) .with_limit(query.top_k) # type: ignore .with_additional(["score", "vector"]) .do() ) query_results: List[DocumentChunkWithScore] = [] response = result["data"]["Get"][WEAVIATE_CLASS] for resp in response: result = DocumentChunkWithScore( id=resp["chunk_id"], text=resp["text"], # embedding=resp["_additional"]["vector"], score=resp["_additional"]["score"], metadata=DocumentChunkMetadata( document_id=resp["document_id"] if resp["document_id"] else "", source=Source(resp["source"]) if resp["source"] else None, source_id=resp["source_id"], url=resp["url"], created_at=resp["created_at"], author=resp["author"], ), ) query_results.append(result) return QueryResult(query=query.query, results=query_results) return await asyncio.gather(*[_single_query(query) for query in queries]) async def delete( self, ids: Optional[List[str]] = None, filter: Optional[DocumentMetadataFilter] = None, delete_all: Optional[bool] = None, ) -> bool: # TODO """ Removes vectors by ids, filter, or everything in the datastore. Returns whether the operation was successful. """ if delete_all: logger.debug(f"Deleting all vectors in index {WEAVIATE_CLASS}") self.client.schema.delete_all() return True if ids: operands = [ {"path": ["document_id"], "operator": "Equal", "valueString": id} for id in ids ] where_clause = {"operator": "Or", "operands": operands} logger.debug(f"Deleting vectors from index {WEAVIATE_CLASS} with ids {ids}") result = self.client.batch.delete_objects( class_name=WEAVIATE_CLASS, where=where_clause, output="verbose" ) if not bool(result["results"]["successful"]): logger.debug( f"Failed to delete the following objects: {result['results']['objects']}" ) if filter: where_clause = self.build_filters(filter) logger.debug( f"Deleting vectors from index {WEAVIATE_CLASS} with filter {where_clause}" ) result = self.client.batch.delete_objects( class_name=WEAVIATE_CLASS, where=where_clause ) if not bool(result["results"]["successful"]): logger.debug( f"Failed to delete the following objects: {result['results']['objects']}" ) return True @staticmethod def build_filters(filter): if filter.source: filter.source = filter.source.value operands = [] filter_conditions = { "source": { "operator": "Equal", "value": "query.filter.source.value", "value_key": "valueString", }, "start_date": {"operator": "GreaterThanEqual", "value_key": "valueDate"}, "end_date": {"operator": "LessThanEqual", "value_key": "valueDate"}, "default": {"operator": "Equal", "value_key": "valueString"}, } for attr, value in filter.__dict__.items(): if value is not None: filter_condition = filter_conditions.get( attr, filter_conditions["default"] ) value_key = filter_condition["value_key"] operand = { "path": [ attr if not (attr == "start_date" or attr == "end_date") else "created_at" ], "operator": filter_condition["operator"], value_key: value, } operands.append(operand) return {"operator": "And", "operands": operands} @staticmethod def _is_valid_weaviate_id(candidate_id: str) -> bool: """ Check if candidate_id is a valid UUID for weaviate's use Weaviate supports UUIDs of version 3, 4 and 5. This function checks if the candidate_id is a valid UUID of one of these versions. See https://weaviate.io/developers/weaviate/more-resources/faq#q-are-there-restrictions-on-uuid-formatting-do-i-have-to-adhere-to-any-standards for more information. """ acceptable_version = [3, 4, 5] try: result = uuid.UUID(candidate_id) if result.version not in acceptable_version: return False else: return True except ValueError: return False @staticmethod def _is_wcs_domain(url: str) -> bool: """ Check if the given URL ends with ".weaviate.network" or ".weaviate.network/". Args: url (str): The URL to check. Returns: bool: True if the URL ends with the specified strings, False otherwise. """ pattern = r"\.(weaviate\.cloud|weaviate\.network)(/)?$" return bool(re.search(pattern, url))