datastore/providers/chroma_datastore.py
"""
Chroma datastore support for the ChatGPT retrieval plugin.
Consult the Chroma docs and GitHub repo for more information:
- https://docs.trychroma.com/usage-guide?lang=py
- https://github.com/chroma-core/chroma
- https://www.trychroma.com/
"""
import os
from datetime import datetime
from typing import Dict, List, Optional

import chromadb

from datastore.datastore import DataStore
from models.models import (
Document,
DocumentChunk,
DocumentChunkMetadata,
DocumentChunkWithScore,
DocumentMetadataFilter,
QueryResult,
QueryWithEmbedding,
Source,
)
from services.chunks import get_document_chunks

# CHROMA_IN_MEMORY is parsed into a real bool so that values such as
# CHROMA_IN_MEMORY=False are not treated as truthy non-empty strings.
CHROMA_IN_MEMORY = os.environ.get("CHROMA_IN_MEMORY", "True").lower() != "false"
CHROMA_PERSISTENCE_DIR = os.environ.get("CHROMA_PERSISTENCE_DIR", "openai")
CHROMA_HOST = os.environ.get("CHROMA_HOST", "http://127.0.0.1")
CHROMA_PORT = os.environ.get("CHROMA_PORT", "8000")
CHROMA_COLLECTION = os.environ.get("CHROMA_COLLECTION", "openaiembeddings")


class ChromaDataStore(DataStore):
def __init__(
self,
        in_memory: bool = CHROMA_IN_MEMORY,
persistence_dir: Optional[str] = CHROMA_PERSISTENCE_DIR,
collection_name: str = CHROMA_COLLECTION,
host: str = CHROMA_HOST,
port: str = CHROMA_PORT,
client: Optional[chromadb.Client] = None,
):
if client:
self._client = client
else:
if in_memory:
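                # A local, in-process Chroma client; data is optionally
                # persisted to disk (parquet via DuckDB) when a persistence
                # directory is given.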
settings = (
chromadb.config.Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=persistence_dir,
)
if persistence_dir
else chromadb.config.Settings()
)
self._client = chromadb.Client(settings=settings)
else:
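                # A client for a standalone Chroma server, reached over its
                # REST API at the configured host and port.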
self._client = chromadb.Client(
settings=chromadb.config.Settings(
chroma_api_impl="rest",
chroma_server_host=host,
chroma_server_http_port=port,
)
)
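        # The plugin computes embeddings itself (via services.chunks) before
        # upserting, so no embedding function is attached to the collection.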
self._collection = self._client.get_or_create_collection(
name=collection_name,
embedding_function=None,
        )

    async def upsert(
self, documents: List[Document], chunk_token_size: Optional[int] = None
) -> List[str]:
"""
Takes in a list of documents and inserts them into the database. If an id already exists, the document is updated.
Return a list of document ids.
"""
chunks = get_document_chunks(documents, chunk_token_size)
# Chroma has a true upsert, so we don't need to delete first
        return await self._upsert(chunks)

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
"""
Takes in a list of list of document chunks and inserts them into the database.
Return a list of document ids.
"""
self._collection.upsert(
ids=[chunk.id for chunk_list in chunks.values() for chunk in chunk_list],
embeddings=[
chunk.embedding
for chunk_list in chunks.values()
for chunk in chunk_list
],
documents=[
chunk.text for chunk_list in chunks.values() for chunk in chunk_list
],
metadatas=[
self._process_metadata_for_storage(chunk.metadata)
for chunk_list in chunks.values()
for chunk in chunk_list
],
)
        return list(chunks.keys())

    def _where_from_query_filter(self, query_filter: DocumentMetadataFilter) -> Dict:
output = {
k: v
for (k, v) in query_filter.dict().items()
            if v is not None and k not in ("start_date", "end_date", "source")
}
if query_filter.source:
output["source"] = query_filter.source.value
if query_filter.start_date and query_filter.end_date:
output["$and"] = [
{
"created_at": {
"$gte": int(
datetime.fromisoformat(query_filter.start_date).timestamp()
)
}
},
{
"created_at": {
"$lte": int(
datetime.fromisoformat(query_filter.end_date).timestamp()
)
}
},
]
elif query_filter.start_date:
output["created_at"] = {
"$gte": int(datetime.fromisoformat(query_filter.start_date).timestamp())
}
elif query_filter.end_date:
output["created_at"] = {
"$lte": int(datetime.fromisoformat(query_filter.end_date).timestamp())
}
        return output

    def _process_metadata_for_storage(self, metadata: DocumentChunkMetadata) -> Dict:
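        # Chroma metadata values must be scalars, so drop unset fields and
        # store created_at as an integer timestamp.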
stored_metadata = {}
if metadata.source:
stored_metadata["source"] = metadata.source.value
if metadata.source_id:
stored_metadata["source_id"] = metadata.source_id
if metadata.url:
stored_metadata["url"] = metadata.url
if metadata.created_at:
stored_metadata["created_at"] = int(
datetime.fromisoformat(metadata.created_at).timestamp()
)
if metadata.author:
stored_metadata["author"] = metadata.author
if metadata.document_id:
stored_metadata["document_id"] = metadata.document_id
        return stored_metadata

    def _process_metadata_from_storage(self, metadata: Dict) -> DocumentChunkMetadata:
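        # Inverse of _process_metadata_for_storage: rebuild the pydantic model,
        # converting the stored epoch seconds back to an ISO 8601 string.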
return DocumentChunkMetadata(
source=Source(metadata["source"]) if "source" in metadata else None,
source_id=metadata.get("source_id", None),
url=metadata.get("url", None),
created_at=datetime.fromtimestamp(metadata["created_at"]).isoformat()
if "created_at" in metadata
else None,
author=metadata.get("author", None),
document_id=metadata.get("document_id", None),
        )

    async def _query(self, queries: List[QueryWithEmbedding]) -> List[QueryResult]:
"""
Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores.
"""
results = [
self._collection.query(
query_embeddings=[query.embedding],
                include=["documents", "distances", "metadatas"],  # embeddings omitted, see note below
n_results=min(query.top_k, self._collection.count()), # type: ignore
where=(
self._where_from_query_filter(query.filter) if query.filter else {}
),
)
for query in queries
]
output = []
for query, result in zip(queries, results):
inner_results = []
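            # Chroma returns results column-wise, one list per query embedding;
            # we queried with a single embedding, so unpack the singleton lists.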
(ids,) = result["ids"]
# (embeddings,) = result["embeddings"]
(documents,) = result["documents"]
(metadatas,) = result["metadatas"]
(distances,) = result["distances"]
for id_, text, metadata, distance in zip(
ids,
documents,
metadatas,
distances, # embeddings (https://github.com/openai/chatgpt-retrieval-plugin/pull/59#discussion_r1154985153)
):
inner_results.append(
DocumentChunkWithScore(
id=id_,
text=text,
metadata=self._process_metadata_from_storage(metadata),
# embedding=embedding,
score=distance,
)
)
output.append(QueryResult(query=query.query, results=inner_results))
        return output

    async def delete(
self,
ids: Optional[List[str]] = None,
filter: Optional[DocumentMetadataFilter] = None,
delete_all: Optional[bool] = None,
) -> bool:
"""
Removes vectors by ids, filter, or everything in the datastore.
Multiple parameters can be used at once.
Returns whether the operation was successful.
"""
if delete_all:
self._collection.delete()
return True
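        # The ids passed to delete() are document ids, not chunk ids, so match
        # them against the document_id metadata field.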
if ids and len(ids) > 0:
if len(ids) > 1:
where_clause = {"$or": [{"document_id": id_} for id_ in ids]}
else:
(id_,) = ids
where_clause = {"document_id": id_}
if filter:
where_clause = {
"$and": [self._where_from_query_filter(filter), where_clause]
}
elif filter:
where_clause = self._where_from_query_filter(filter)
self._collection.delete(where=where_clause)
return True
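

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the plugin itself: it assumes
    # OPENAI_API_KEY is set in the environment, since get_document_chunks()
    # calls the OpenAI embeddings API, and it uses an in-memory Chroma client
    # with no persistence. The document id and text below are made up.
    import asyncio

    async def _demo() -> None:
        store = ChromaDataStore(in_memory=True, persistence_dir=None)
        doc = Document(
            id="demo-doc",
            text="Chroma is an open-source embedding database.",
        )
        ids = await store.upsert([doc])
        print(f"Upserted document ids: {ids}")
        await store.delete(delete_all=True)

    asyncio.run(_demo())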