# datastore/providers/qdrant_datastore.py
import os
import uuid
from typing import Dict, List, Optional

import qdrant_client
from grpc._channel import _InactiveRpcError
from qdrant_client.http import models as rest
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType

from datastore.datastore import DataStore
from models.models import (
    DocumentChunk,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
)
from services.date import to_unix_timestamp
# Qdrant connection and collection settings, read from the environment
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost")
QDRANT_PORT = os.environ.get("QDRANT_PORT", "6333")
QDRANT_GRPC_PORT = os.environ.get("QDRANT_GRPC_PORT", "6334")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
QDRANT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "document_chunks")
EMBEDDING_DIMENSION = int(os.environ.get("EMBEDDING_DIMENSION", 256))
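
# A configuration sketch (shell): pointing the datastore at a remote Qdrant
# instance before starting the app. The host and key below are made-up
# placeholder values, not defaults shipped with this module.
#
#   export QDRANT_URL="https://my-qdrant.example.com"
#   export QDRANT_API_KEY="my-secret-key"
#   export QDRANT_COLLECTION="document_chunks"
#   export EMBEDDING_DIMENSION="256"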

class QdrantDataStore(DataStore):
    """A DataStore implementation backed by a Qdrant collection."""

    UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d")

def __init__(
self,
collection_name: Optional[str] = None,
vector_size: int = EMBEDDING_DIMENSION,
distance: str = "Cosine",
recreate_collection: bool = False,
):
"""
Args:
collection_name: Name of the collection to be used
vector_size: Size of the embedding stored in a collection
distance:
Any of "Cosine" / "Euclid" / "Dot". Distance function to measure
similarity
"""
self.client = qdrant_client.QdrantClient(
url=QDRANT_URL,
port=int(QDRANT_PORT),
grpc_port=int(QDRANT_GRPC_PORT),
api_key=QDRANT_API_KEY,
prefer_grpc=True,
timeout=10,
)
self.collection_name = collection_name or QDRANT_COLLECTION
# Set up the collection so the points might be inserted or queried
self._set_up_collection(vector_size, distance, recreate_collection)
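
    # A usage sketch (not executed here): with the settings above pointing at
    # a reachable Qdrant instance, the datastore is created like this;
    # "custom_chunks" is a made-up collection name.
    #
    #     datastore = QdrantDataStore(
    #         collection_name="custom_chunks",
    #         vector_size=EMBEDDING_DIMENSION,
    #         distance="Cosine",
    #     )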

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a dict mapping document ids to their lists of chunks and
        inserts them into the database. Returns a list of document ids.
        """
        points = [
            self._convert_document_chunk_to_point(chunk)
            for document_chunks in chunks.values()
            for chunk in document_chunks
        ]
self.client.upsert(
collection_name=self.collection_name,
points=points, # type: ignore
wait=True,
)
return list(chunks.keys())
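
    # A sketch of the input shape expected by _upsert. The ids, text, and the
    # zero vector are illustrative only, and DocumentChunkMetadata is assumed
    # to be the chunk metadata model from models.models:
    #
    #     chunks = {
    #         "doc1": [
    #             DocumentChunk(
    #                 id="doc1_0",
    #                 text="first chunk of the document",
    #                 embedding=[0.0] * EMBEDDING_DIMENSION,
    #                 metadata=DocumentChunkMetadata(document_id="doc1"),
    #             ),
    #         ],
    #     }
    #     ids = await datastore._upsert(chunks)  # -> ["doc1"]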

    async def _query(
        self,
        queries: List[QueryWithEmbedding],
    ) -> List[QueryResult]:
        """
        Takes in a list of queries with embeddings and filters and returns
        a list of query results with matching document chunks and scores.
        """
search_requests = [
self._convert_query_to_search_request(query) for query in queries
]
results = self.client.search_batch(
collection_name=self.collection_name,
requests=search_requests,
)
return [
QueryResult(
query=query.query,
results=[
self._convert_scored_point_to_document_chunk_with_score(point)
for point in result
],
)
for query, result in zip(queries, results)
]
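
    # A query sketch: searching with a precomputed embedding. The query text,
    # the zero vector, and top_k are placeholder values; QueryWithEmbedding is
    # the model imported above.
    #
    #     query = QueryWithEmbedding(
    #         query="what is qdrant?",
    #         embedding=[0.0] * EMBEDDING_DIMENSION,
    #         top_k=3,
    #     )
    #     results = await datastore._query([query])
    #     for chunk in results[0].results:
    #         print(chunk.score, chunk.text)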
async def delete(
self,
ids: Optional[List[str]] = None,
filter: Optional[DocumentMetadataFilter] = None,
delete_all: Optional[bool] = None,
) -> bool:
"""
Removes vectors by ids, filter, or everything in the datastore.
Returns whether the operation was successful.
"""
if ids is None and filter is None and delete_all is None:
raise ValueError(
"Please provide one of the parameters: ids, filter or delete_all."
)
        if delete_all:
            # An empty filter matches all the points in the collection
            points_selector = rest.Filter()
else:
points_selector = self._convert_metadata_filter_to_qdrant_filter(
filter, ids
)
response = self.client.delete(
collection_name=self.collection_name,
points_selector=points_selector, # type: ignore
)
return "COMPLETED" == response.status

    def _convert_document_chunk_to_point(
        self, document_chunk: DocumentChunk
    ) -> rest.PointStruct:
        """Converts a DocumentChunk into a Qdrant point, with the chunk
        attributes kept in the payload."""
        # The created_at date is additionally stored as a Unix timestamp at
        # the top level of the payload, so it can be used in range filters
        created_at = (
            to_unix_timestamp(document_chunk.metadata.created_at)
            if document_chunk.metadata.created_at is not None
            else None
        )
return rest.PointStruct(
id=self._create_document_chunk_id(document_chunk.id),
vector=document_chunk.embedding, # type: ignore
payload={
"id": document_chunk.id,
"text": document_chunk.text,
"metadata": document_chunk.metadata.dict(),
"created_at": created_at,
},
)

    def _create_document_chunk_id(self, external_id: Optional[str]) -> str:
        """Derives a deterministic UUIDv5-based point id from the external
        chunk id, so upserting the same chunk twice updates a single point.
        Falls back to a random UUIDv4 if no external id is given."""
        if external_id is None:
            return uuid.uuid4().hex
        return uuid.uuid5(self.UUID_NAMESPACE, external_id).hex
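
    # The UUIDv5 scheme above is deterministic: the same external id always
    # maps to the same point id, which keeps repeated upserts idempotent.
    #
    #     uuid.uuid5(QdrantDataStore.UUID_NAMESPACE, "doc1_0").hex
    #     # -> the same hex string on every call for "doc1_0"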
def _convert_query_to_search_request(
self, query: QueryWithEmbedding
) -> rest.SearchRequest:
return rest.SearchRequest(
vector=query.embedding,
filter=self._convert_metadata_filter_to_qdrant_filter(query.filter),
limit=query.top_k, # type: ignore
with_payload=True,
with_vector=False,
)

    def _convert_metadata_filter_to_qdrant_filter(
        self,
        metadata_filter: Optional[DocumentMetadataFilter] = None,
        ids: Optional[List[str]] = None,
    ) -> Optional[rest.Filter]:
        """Converts the generic DocumentMetadataFilter, and optionally a list
        of document ids, into a Qdrant filter object."""
        if metadata_filter is None and ids is None:
            return None

        must_conditions, should_conditions = [], []
        # Filtering by document ids: a point matches if it belongs to any of
        # the given documents ("should" conditions are combined with OR)
        if ids:
for document_id in ids:
should_conditions.append(
rest.FieldCondition(
key="metadata.document_id",
match=rest.MatchValue(value=document_id),
)
)
# Equality filters for the payload attributes
if metadata_filter:
meta_attributes_keys = {
"document_id": "metadata.document_id",
"source": "metadata.source",
"source_id": "metadata.source_id",
"author": "metadata.author",
}
for meta_attr_name, payload_key in meta_attributes_keys.items():
attr_value = getattr(metadata_filter, meta_attr_name)
if attr_value is None:
continue
must_conditions.append(
rest.FieldCondition(
key=payload_key, match=rest.MatchValue(value=attr_value)
)
)
# Date filters use range filtering
start_date = metadata_filter.start_date
end_date = metadata_filter.end_date
if start_date or end_date:
gte_filter = (
to_unix_timestamp(start_date) if start_date is not None else None
)
lte_filter = (
to_unix_timestamp(end_date) if end_date is not None else None
)
must_conditions.append(
rest.FieldCondition(
key="created_at",
range=rest.Range(
gte=gte_filter,
lte=lte_filter,
),
)
)
        if not must_conditions and not should_conditions:
            return None

        return rest.Filter(must=must_conditions, should=should_conditions)
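
    # For illustration, a DocumentMetadataFilter with document_id="doc1" and a
    # start_date (values made up) is converted into roughly this Qdrant filter:
    #
    #     rest.Filter(
    #         must=[
    #             rest.FieldCondition(
    #                 key="metadata.document_id",
    #                 match=rest.MatchValue(value="doc1"),
    #             ),
    #             rest.FieldCondition(
    #                 key="created_at",
    #                 range=rest.Range(gte=1672531200, lte=None),
    #             ),
    #         ],
    #         should=[],
    #     )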

    def _convert_scored_point_to_document_chunk_with_score(
        self, scored_point: rest.ScoredPoint
    ) -> DocumentChunkWithScore:
        payload = scored_point.payload or {}
        return DocumentChunkWithScore(
            id=payload.get("id"),
            text=payload.get("text"),
            metadata=payload.get("metadata"),
            embedding=scored_point.vector,  # type: ignore
            score=scored_point.score,
        )

    def _set_up_collection(
        self, vector_size: int, distance: str, recreate_collection: bool
    ):
        """Makes sure the collection exists and is configured with the
        requested vector size and distance function."""
        distance = rest.Distance[distance.upper()]

        if recreate_collection:
            self._recreate_collection(distance, vector_size)

        try:
            collection_info = self.client.get_collection(self.collection_name)
            current_distance = collection_info.config.params.vectors.distance  # type: ignore
            current_vector_size = collection_info.config.params.vectors.size  # type: ignore

            if current_distance != distance:
                raise ValueError(
                    f"Collection '{self.collection_name}' already exists in Qdrant, "
                    f"but it is configured with a similarity '{current_distance.name}'. "
                    f"If you want to use that collection with a different "
                    f"similarity, please pass `recreate_collection=True`."
                )

            if current_vector_size != vector_size:
                raise ValueError(
                    f"Collection '{self.collection_name}' already exists in Qdrant, "
                    f"but it is configured with a vector size '{current_vector_size}'. "
                    f"If you want to use that collection with a different "
                    f"vector size, please pass `recreate_collection=True`."
                )
        except (UnexpectedResponse, _InactiveRpcError):
            # The collection does not exist yet, so it has to be created
            self._recreate_collection(distance, vector_size)

    def _recreate_collection(self, distance: rest.Distance, vector_size: int):
        """Creates the collection from scratch, dropping any existing data."""
        self.client.recreate_collection(
            self.collection_name,
            vectors_config=rest.VectorParams(
                size=vector_size,
                distance=distance,
            ),
        )
# Create the payload index for the document_id metadata attribute, as it is
# used to delete the document related entries
self.client.create_payload_index(
self.collection_name,
field_name="metadata.document_id",
            field_schema=PayloadSchemaType.KEYWORD,
)
# Create the payload index for the created_at attribute, to make the lookup
# by range filters faster
self.client.create_payload_index(
self.collection_name,
field_name="created_at",
field_schema=PayloadSchemaType.INTEGER,
)
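
# A quick verification sketch, assuming a locally running Qdrant instance on
# the default ports (attribute names follow the qdrant_client collection info
# models):
#
#     store = QdrantDataStore()
#     info = store.client.get_collection(store.collection_name)
#     print(info.config.params.vectors)  # the vector size and distance
#     print(info.payload_schema)         # the two payload indexes above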