datastore/providers/redis_datastore.py

import asyncio
import os
import re
import json
import redis.asyncio as redis
import numpy as np
from redis.commands.search.query import Query as RediSearchQuery
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.field import (
    TagField,
    TextField,
    NumericField,
    VectorField,
)
from loguru import logger
from typing import Dict, List, Optional
from datastore.datastore import DataStore
from models.models import (
    DocumentChunk,
    DocumentMetadataFilter,
    DocumentChunkWithScore,
    QueryResult,
    QueryWithEmbedding,
)
from services.date import to_unix_timestamp

# Read environment variables for Redis
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")
REDIS_INDEX_NAME = os.environ.get("REDIS_INDEX_NAME", "index")
REDIS_DOC_PREFIX = os.environ.get("REDIS_DOC_PREFIX", "doc")
REDIS_DISTANCE_METRIC = os.environ.get("REDIS_DISTANCE_METRIC", "COSINE")
REDIS_INDEX_TYPE = os.environ.get("REDIS_INDEX_TYPE", "FLAT")
assert REDIS_INDEX_TYPE in ("FLAT", "HNSW")

# OpenAI Embeddings Dimension
VECTOR_DIMENSION = int(os.environ.get("EMBEDDING_DIMENSION", 256))

# RediSearch constants
REDIS_REQUIRED_MODULES = [
    {"name": "search", "ver": 20600},
    {"name": "ReJSON", "ver": 20404},
]
REDIS_DEFAULT_ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&()\-+=~\/ ]")


# Helper functions
def unpack_schema(d: dict):
    """Flatten the nested schema dict into a stream of RediSearch fields."""
    for v in d.values():
        if isinstance(v, dict):
            yield from unpack_schema(v)
        else:
            yield v


async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
    """Raise if the required Redis Stack modules are missing or too old."""
    installed_modules = (await client.info()).get("modules", [])
    installed_modules = {module["name"]: module for module in installed_modules}
    for module in modules:
        if module["name"] not in installed_modules or int(
            installed_modules[module["name"]]["ver"]
        ) < int(module["ver"]):
            error_message = (
                "You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. "
                "Please refer to Redis Stack docs: https://redis.io/docs/stack/"
            )
            logger.error(error_message)
            raise AttributeError(error_message)


class RedisDataStore(DataStore):
    def __init__(self, client: redis.Redis, redisearch_schema: dict):
        self.client = client
        self._schema = redisearch_schema
        # Init default metadata with sentinel values in case the document
        # written has no metadata
        self._default_metadata = {
            field: (0 if field == "created_at" else "_null_")
            for field in redisearch_schema["metadata"]
        }

    ### Redis Helper Methods ###

    @classmethod
    async def init(cls, **kwargs):
        """
        Set up the index if it does not exist.
        """
        try:
            # Connect to the Redis client
            logger.info("Connecting to Redis")
            client = redis.Redis(
                host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD
            )
        except Exception as e:
            logger.error(f"Error setting up Redis: {e}")
            raise e

        await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES)

        dim = kwargs.get("dim", VECTOR_DIMENSION)
        redisearch_schema = {
            "metadata": {
                "document_id": TagField(
                    "$.metadata.document_id", as_name="document_id"
                ),
                "source_id": TagField("$.metadata.source_id", as_name="source_id"),
                "source": TagField("$.metadata.source", as_name="source"),
                "author": TextField("$.metadata.author", as_name="author"),
                "created_at": NumericField(
                    "$.metadata.created_at", as_name="created_at"
                ),
            },
            "embedding": VectorField(
                "$.embedding",
                REDIS_INDEX_TYPE,
                {
                    "TYPE": "FLOAT64",
                    "DIM": dim,
                    "DISTANCE_METRIC": REDIS_DISTANCE_METRIC,
                },
                as_name="embedding",
            ),
        }

        try:
            # Check for existence of the RediSearch index
            await client.ft(REDIS_INDEX_NAME).info()
            logger.info(f"RediSearch index {REDIS_INDEX_NAME} already exists")
        except Exception:
            # Create the RediSearch index
            logger.info(f"Creating new RediSearch index {REDIS_INDEX_NAME}")
            definition = IndexDefinition(
                prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON
            )
            fields = list(unpack_schema(redisearch_schema))
            logger.info(f"Creating index with fields: {fields}")
            await client.ft(REDIS_INDEX_NAME).create_index(
                fields=fields, definition=definition
            )
        return cls(client, redisearch_schema)
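
    # Illustrative note (not part of the original source): the JSONPath
    # expressions in the schema above resolve against documents shaped like
    # the one _get_redis_chunk() below produces, roughly:
    #
    #   {
    #       "chunk_id": "<chunk id>",
    #       "text": "<chunk text>",
    #       "embedding": [0.1, 0.2, ...],   # FLOAT64 vector of length `dim`
    #       "metadata": {
    #           "document_id": "<doc id>",
    #           "source_id": "_null_",      # sentinel when metadata is absent
    #           "source": "_null_",
    #           "author": "_null_",
    #           "created_at": 0
    #       }
    #   }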
""" try: # Connect to the Redis Client logger.info("Connecting to Redis") client = redis.Redis( host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD ) except Exception as e: logger.error(f"Error setting up Redis: {e}") raise e await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES) dim = kwargs.get("dim", VECTOR_DIMENSION) redisearch_schema = { "metadata": { "document_id": TagField( "$.metadata.document_id", as_name="document_id" ), "source_id": TagField("$.metadata.source_id", as_name="source_id"), "source": TagField("$.metadata.source", as_name="source"), "author": TextField("$.metadata.author", as_name="author"), "created_at": NumericField( "$.metadata.created_at", as_name="created_at" ), }, "embedding": VectorField( "$.embedding", REDIS_INDEX_TYPE, { "TYPE": "FLOAT64", "DIM": dim, "DISTANCE_METRIC": REDIS_DISTANCE_METRIC, }, as_name="embedding", ), } try: # Check for existence of RediSearch Index await client.ft(REDIS_INDEX_NAME).info() logger.info(f"RediSearch index {REDIS_INDEX_NAME} already exists") except: # Create the RediSearch Index logger.info(f"Creating new RediSearch index {REDIS_INDEX_NAME}") definition = IndexDefinition( prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON ) fields = list(unpack_schema(redisearch_schema)) logger.info(f"Creating index with fields: {fields}") await client.ft(REDIS_INDEX_NAME).create_index( fields=fields, definition=definition ) return cls(client, redisearch_schema) @staticmethod def _redis_key(document_id: str, chunk_id: str) -> str: """ Create the JSON key for document chunks in Redis. Args: document_id (str): Document Identifier chunk_id (str): Chunk Identifier Returns: str: JSON key string. """ return f"doc:{document_id}:chunk:{chunk_id}" @staticmethod def _escape(value: str) -> str: """ Escape filter value. Args: value (str): Value to escape. Returns: str: Escaped filter value for RediSearch. """ def escape_symbol(match) -> str: value = match.group(0) return f"\\{value}" return REDIS_DEFAULT_ESCAPED_CHARS.sub(escape_symbol, value) def _get_redis_chunk(self, chunk: DocumentChunk) -> dict: """ Convert DocumentChunk into a JSON object for storage in Redis. Args: chunk (DocumentChunk): Chunk of a Document. Returns: dict: JSON object for storage in Redis. """ # Convert chunk -> dict data = chunk.__dict__ metadata = chunk.metadata.__dict__ data["chunk_id"] = data.pop("id") # Prep Redis Metadata redis_metadata = dict(self._default_metadata) if metadata: for field, value in metadata.items(): if value: if field == "created_at": redis_metadata[field] = to_unix_timestamp(value) # type: ignore else: redis_metadata[field] = value data["metadata"] = redis_metadata return data def _get_redis_query(self, query: QueryWithEmbedding) -> RediSearchQuery: """ Convert a QueryWithEmbedding into a RediSearchQuery. Args: query (QueryWithEmbedding): Search query. Returns: RediSearchQuery: Query for RediSearch. 
""" filter_str: str = "" # RediSearch field type to query string def _typ_to_str(typ, field, value) -> str: # type: ignore if isinstance(typ, TagField): return f"@{field}:{{{self._escape(value)}}} " elif isinstance(typ, TextField): return f"@{field}:{value} " elif isinstance(typ, NumericField): num = to_unix_timestamp(value) match field: case "start_date": return f"@{field}:[{num} +inf] " case "end_date": return f"@{field}:[-inf {num}] " # Build filter if query.filter: redisearch_schema = self._schema for field, value in query.filter.__dict__.items(): if not value: continue if field in redisearch_schema: filter_str += _typ_to_str(redisearch_schema[field], field, value) elif field in redisearch_schema["metadata"]: if field == "source": # handle the enum value = value.value filter_str += _typ_to_str( redisearch_schema["metadata"][field], field, value ) elif field in ["start_date", "end_date"]: filter_str += _typ_to_str( redisearch_schema["metadata"]["created_at"], field, value ) # Postprocess filter string filter_str = filter_str.strip() filter_str = filter_str if filter_str else "*" # Prepare query string query_str = ( f"({filter_str})=>[KNN {query.top_k} @embedding $embedding as score]" ) return ( RediSearchQuery(query_str) .sort_by("score") .paging(0, query.top_k) .dialect(2) ) async def _redis_delete(self, keys: List[str]): """ Delete a list of keys from Redis. Args: keys (List[str]): List of keys to delete. """ # Delete the keys await asyncio.gather(*[self.client.delete(key) for key in keys]) ####### async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]: """ Takes in a list of list of document chunks and inserts them into the database. Return a list of document ids. """ # Initialize a list of ids to return doc_ids: List[str] = [] # Loop through the dict items for doc_id, chunk_list in chunks.items(): # Append the id to the ids list doc_ids.append(doc_id) # Write chunks in a pipelines async with self.client.pipeline(transaction=False) as pipe: for chunk in chunk_list: key = self._redis_key(doc_id, chunk.id) data = self._get_redis_chunk(chunk) await pipe.json().set(key, "$", data) await pipe.execute() return doc_ids async def _query( self, queries: List[QueryWithEmbedding], ) -> List[QueryResult]: """ Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores. 
""" # Prepare query responses and results object results: List[QueryResult] = [] # Gather query results in a pipeline logger.info(f"Gathering {len(queries)} query results") for query in queries: logger.debug(f"Query: {query.query}") query_results: List[DocumentChunkWithScore] = [] # Extract Redis query redis_query: RediSearchQuery = self._get_redis_query(query) embedding = np.array(query.embedding, dtype=np.float64).tobytes() # Perform vector search query_response = await self.client.ft(REDIS_INDEX_NAME).search( redis_query, {"embedding": embedding} ) # Iterate through the most similar documents for doc in query_response.docs: # Load JSON data doc_json = json.loads(doc.json) # Create document chunk object with score result = DocumentChunkWithScore( id=doc_json["metadata"]["document_id"], score=doc.score, text=doc_json["text"], metadata=doc_json["metadata"], ) query_results.append(result) # Add to overall results results.append(QueryResult(query=query.query, results=query_results)) return results async def _find_keys(self, pattern: str) -> List[str]: return [key async for key in self.client.scan_iter(pattern)] async def delete( self, ids: Optional[List[str]] = None, filter: Optional[DocumentMetadataFilter] = None, delete_all: Optional[bool] = None, ) -> bool: """ Removes vectors by ids, filter, or everything in the datastore. Returns whether the operation was successful. """ # Delete all vectors from the index if delete_all is True if delete_all: try: logger.info(f"Deleting all documents from index") await self.client.ft(REDIS_INDEX_NAME).dropindex(True) logger.info(f"Deleted all documents successfully") return True except Exception as e: logger.error(f"Error deleting all documents: {e}") raise e # Delete by filter if filter: # TODO - extend this to work with other metadata filters? if filter.document_id: try: keys = await self._find_keys( f"{REDIS_DOC_PREFIX}:{filter.document_id}:*" ) await self._redis_delete(keys) logger.info(f"Deleted document {filter.document_id} successfully") except Exception as e: logger.error(f"Error deleting document {filter.document_id}: {e}") raise e # Delete by explicit ids (Redis keys) if ids: try: logger.info(f"Deleting document ids {ids}") keys = [] # find all keys associated with the document ids for document_id in ids: doc_keys = await self._find_keys( pattern=f"{REDIS_DOC_PREFIX}:{document_id}:*" ) keys.extend(doc_keys) # delete all keys logger.info(f"Deleting {len(keys)} keys from Redis") await self._redis_delete(keys) except Exception as e: logger.error(f"Error deleting ids: {e}") raise e return True