# datastore/providers/redis_datastore.py
import asyncio
import os
import re
import json
import redis.asyncio as redis
import numpy as np
from redis.commands.search.query import Query as RediSearchQuery
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.field import (
TagField,
TextField,
NumericField,
VectorField,
)
from loguru import logger
from typing import Dict, List, Optional
from datastore.datastore import DataStore
from models.models import (
DocumentChunk,
DocumentMetadataFilter,
DocumentChunkWithScore,
QueryResult,
QueryWithEmbedding,
)
from services.date import to_unix_timestamp
# Read environment variables for Redis
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")
REDIS_INDEX_NAME = os.environ.get("REDIS_INDEX_NAME", "index")
REDIS_DOC_PREFIX = os.environ.get("REDIS_DOC_PREFIX", "doc")
REDIS_DISTANCE_METRIC = os.environ.get("REDIS_DISTANCE_METRIC", "COSINE")
REDIS_INDEX_TYPE = os.environ.get("REDIS_INDEX_TYPE", "FLAT")
assert REDIS_INDEX_TYPE in ("FLAT", "HNSW")
# OpenAI embeddings dimension (set EMBEDDING_DIMENSION to match your embedding model)
VECTOR_DIMENSION = int(os.environ.get("EMBEDDING_DIMENSION", 256))
# RediSearch constants
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20600},
{"name": "ReJSON", "ver": 20404},
]
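# Module versions are encoded as integers (major * 10000 + minor * 100 + patch),
# so 20600 corresponds to RediSearch 2.6.0 and 20404 to ReJSON 2.4.4.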
REDIS_DEFAULT_ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&()\-+=~\/ ]")
# Helper functions
def unpack_schema(d: dict):
for v in d.values():
if isinstance(v, dict):
yield from unpack_schema(v)
else:
yield v
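# For illustration: given {"metadata": {"author": TextField(...)}, "embedding": VectorField(...)},
# unpack_schema yields the TextField and then the VectorField, flattening the nested
# schema dict into the flat field list that create_index expects.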
async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
installed_modules = (await client.info()).get("modules", [])
installed_modules = {module["name"]: module for module in installed_modules}
for module in modules:
if module["name"] not in installed_modules or int(
installed_modules[module["name"]]["ver"]
) < int(module["ver"]):
error_message = (
"You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
)
logger.error(error_message)
raise AttributeError(error_message)
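# Note: redis-py parses the INFO reply's module section into a list of dicts,
# e.g. [{"name": "search", "ver": 20600, ...}], which is what the check above reads.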
class RedisDataStore(DataStore):
def __init__(self, client: redis.Redis, redisearch_schema: dict):
self.client = client
self._schema = redisearch_schema
        # Initialize default metadata with sentinel values in case a document is written without metadata
self._default_metadata = {
field: (0 if field == "created_at" else "_null_")
for field in redisearch_schema["metadata"]
}
### Redis Helper Methods ###
@classmethod
async def init(cls, **kwargs):
"""
        Set up the index if it does not exist.
"""
try:
# Connect to the Redis Client
logger.info("Connecting to Redis")
client = redis.Redis(
host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD
)
except Exception as e:
logger.error(f"Error setting up Redis: {e}")
raise e
await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES)
dim = kwargs.get("dim", VECTOR_DIMENSION)
redisearch_schema = {
"metadata": {
"document_id": TagField(
"$.metadata.document_id", as_name="document_id"
),
"source_id": TagField("$.metadata.source_id", as_name="source_id"),
"source": TagField("$.metadata.source", as_name="source"),
"author": TextField("$.metadata.author", as_name="author"),
"created_at": NumericField(
"$.metadata.created_at", as_name="created_at"
),
},
"embedding": VectorField(
"$.embedding",
REDIS_INDEX_TYPE,
{
"TYPE": "FLOAT64",
"DIM": dim,
"DISTANCE_METRIC": REDIS_DISTANCE_METRIC,
},
as_name="embedding",
),
}
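        # Each field above indexes a JSONPath into the stored chunk and is aliased via
        # as_name, so filters can reference e.g. @document_id instead of the full path.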
try:
# Check for existence of RediSearch Index
await client.ft(REDIS_INDEX_NAME).info()
logger.info(f"RediSearch index {REDIS_INDEX_NAME} already exists")
        except Exception:
# Create the RediSearch Index
logger.info(f"Creating new RediSearch index {REDIS_INDEX_NAME}")
definition = IndexDefinition(
prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON
)
fields = list(unpack_schema(redisearch_schema))
logger.info(f"Creating index with fields: {fields}")
await client.ft(REDIS_INDEX_NAME).create_index(
fields=fields, definition=definition
)
return cls(client, redisearch_schema)
@staticmethod
def _redis_key(document_id: str, chunk_id: str) -> str:
"""
Create the JSON key for document chunks in Redis.
Args:
document_id (str): Document Identifier
chunk_id (str): Chunk Identifier
Returns:
str: JSON key string.
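        Example (with the default REDIS_DOC_PREFIX "doc", illustrative ids):
            _redis_key("abc123", "abc123_0") -> "doc:abc123:chunk:abc123_0"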
"""
return f"doc:{document_id}:chunk:{chunk_id}"
@staticmethod
def _escape(value: str) -> str:
"""
Escape filter value.
Args:
value (str): Value to escape.
Returns:
str: Escaped filter value for RediSearch.
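        Example:
            A value like "file-name.txt" is escaped so that "-" and "." are
            treated literally rather than as RediSearch query syntax.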
"""
def escape_symbol(match) -> str:
value = match.group(0)
return f"\\{value}"
return REDIS_DEFAULT_ESCAPED_CHARS.sub(escape_symbol, value)
def _get_redis_chunk(self, chunk: DocumentChunk) -> dict:
"""
Convert DocumentChunk into a JSON object for storage
in Redis.
Args:
chunk (DocumentChunk): Chunk of a Document.
Returns:
dict: JSON object for storage in Redis.
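        Example shape (illustrative):
            {"chunk_id": "...", "text": "...", "embedding": [...],
             "metadata": {"document_id": "...", "created_at": 0, ...}}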
"""
# Convert chunk -> dict
data = chunk.__dict__
metadata = chunk.metadata.__dict__
data["chunk_id"] = data.pop("id")
# Prep Redis Metadata
redis_metadata = dict(self._default_metadata)
if metadata:
for field, value in metadata.items():
if value:
if field == "created_at":
redis_metadata[field] = to_unix_timestamp(value) # type: ignore
else:
redis_metadata[field] = value
data["metadata"] = redis_metadata
return data
def _get_redis_query(self, query: QueryWithEmbedding) -> RediSearchQuery:
"""
Convert a QueryWithEmbedding into a RediSearchQuery.
Args:
query (QueryWithEmbedding): Search query.
Returns:
RediSearchQuery: Query for RediSearch.
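        Example (illustrative): a filter on source="file" with top_k=10 produces
            (@source:{file})=>[KNN 10 @embedding $embedding as score]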
"""
filter_str: str = ""
# RediSearch field type to query string
        def _typ_to_str(typ, field, value) -> str:
            if isinstance(typ, TagField):
                return f"@{field}:{{{self._escape(value)}}} "
            elif isinstance(typ, TextField):
                return f"@{field}:{value} "
            elif isinstance(typ, NumericField):
                num = to_unix_timestamp(value)
                match field:
                    case "start_date":
                        return f"@{field}:[{num} +inf] "
                    case "end_date":
                        return f"@{field}:[-inf {num}] "
            # Unhandled field/type combinations contribute no filter clause
            return ""
# Build filter
if query.filter:
redisearch_schema = self._schema
for field, value in query.filter.__dict__.items():
if not value:
continue
if field in redisearch_schema:
filter_str += _typ_to_str(redisearch_schema[field], field, value)
elif field in redisearch_schema["metadata"]:
if field == "source": # handle the enum
value = value.value
filter_str += _typ_to_str(
redisearch_schema["metadata"][field], field, value
)
elif field in ["start_date", "end_date"]:
filter_str += _typ_to_str(
redisearch_schema["metadata"]["created_at"], field, value
)
# Postprocess filter string
filter_str = filter_str.strip()
filter_str = filter_str if filter_str else "*"
# Prepare query string
query_str = (
f"({filter_str})=>[KNN {query.top_k} @embedding $embedding as score]"
)
return (
RediSearchQuery(query_str)
.sort_by("score")
.paging(0, query.top_k)
.dialect(2)
)
async def _redis_delete(self, keys: List[str]):
"""
Delete a list of keys from Redis.
Args:
keys (List[str]): List of keys to delete.
"""
# Delete the keys
await asyncio.gather(*[self.client.delete(key) for key in keys])
#######
async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
"""
        Takes in a dict mapping document ids to lists of document chunks
        and inserts them into the database. Returns a list of document ids.
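
        Example (illustrative):
            await datastore._upsert({"doc1": [chunk_a, chunk_b]})  # -> ["doc1"]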
"""
# Initialize a list of ids to return
doc_ids: List[str] = []
# Loop through the dict items
for doc_id, chunk_list in chunks.items():
# Append the id to the ids list
doc_ids.append(doc_id)
            # Write chunks in a pipeline
async with self.client.pipeline(transaction=False) as pipe:
for chunk in chunk_list:
key = self._redis_key(doc_id, chunk.id)
data = self._get_redis_chunk(chunk)
await pipe.json().set(key, "$", data)
await pipe.execute()
return doc_ids
async def _query(
self,
queries: List[QueryWithEmbedding],
) -> List[QueryResult]:
"""
Takes in a list of queries with embeddings and filters and
returns a list of query results with matching document chunks and scores.
"""
# Prepare query responses and results object
results: List[QueryResult] = []
# Gather query results in a pipeline
logger.info(f"Gathering {len(queries)} query results")
for query in queries:
logger.debug(f"Query: {query.query}")
query_results: List[DocumentChunkWithScore] = []
# Extract Redis query
redis_query: RediSearchQuery = self._get_redis_query(query)
embedding = np.array(query.embedding, dtype=np.float64).tobytes()
# Perform vector search
query_response = await self.client.ft(REDIS_INDEX_NAME).search(
redis_query, {"embedding": embedding}
)
# Iterate through the most similar documents
for doc in query_response.docs:
# Load JSON data
doc_json = json.loads(doc.json)
# Create document chunk object with score
result = DocumentChunkWithScore(
id=doc_json["metadata"]["document_id"],
score=doc.score,
text=doc_json["text"],
metadata=doc_json["metadata"],
)
query_results.append(result)
# Add to overall results
results.append(QueryResult(query=query.query, results=query_results))
return results
async def _find_keys(self, pattern: str) -> List[str]:
return [key async for key in self.client.scan_iter(pattern)]
async def delete(
self,
ids: Optional[List[str]] = None,
filter: Optional[DocumentMetadataFilter] = None,
delete_all: Optional[bool] = None,
) -> bool:
"""
Removes vectors by ids, filter, or everything in the datastore.
Returns whether the operation was successful.
"""
# Delete all vectors from the index if delete_all is True
if delete_all:
try:
logger.info(f"Deleting all documents from index")
await self.client.ft(REDIS_INDEX_NAME).dropindex(True)
logger.info(f"Deleted all documents successfully")
return True
except Exception as e:
logger.error(f"Error deleting all documents: {e}")
raise e
# Delete by filter
if filter:
# TODO - extend this to work with other metadata filters?
if filter.document_id:
try:
keys = await self._find_keys(
f"{REDIS_DOC_PREFIX}:{filter.document_id}:*"
)
await self._redis_delete(keys)
logger.info(f"Deleted document {filter.document_id} successfully")
except Exception as e:
logger.error(f"Error deleting document {filter.document_id}: {e}")
raise e
        # Delete by explicit document ids
if ids:
try:
logger.info(f"Deleting document ids {ids}")
keys = []
# find all keys associated with the document ids
for document_id in ids:
doc_keys = await self._find_keys(
pattern=f"{REDIS_DOC_PREFIX}:{document_id}:*"
)
keys.extend(doc_keys)
# delete all keys
logger.info(f"Deleting {len(keys)} keys from Redis")
await self._redis_delete(keys)
except Exception as e:
logger.error(f"Error deleting ids: {e}")
raise e
return True
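# Minimal usage sketch (hypothetical; assumes a local Redis Stack instance and an
# embedding whose length matches VECTOR_DIMENSION). _query is called directly here
# to skip embedding generation:
#
# async def _example():
#     datastore = await RedisDataStore.init()
#     results = await datastore._query(
#         [QueryWithEmbedding(query="hello", embedding=[0.0] * VECTOR_DIMENSION, top_k=3)]
#     )
#     print(results)
#
# if __name__ == "__main__":
#     asyncio.run(_example())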