datastore/providers/pgvector_datastore.py (129 lines of code) (raw):

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from datetime import datetime

from loguru import logger

from services.date import to_unix_timestamp
from datastore.datastore import DataStore
from models.models import (
    DocumentChunk,
    DocumentChunkMetadata,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
    DocumentChunkWithScore,
)


# interface for Postgres client to implement pg based Datastore providers
class PGClient(ABC):
    """Async interface over a Postgres database, used by PgVectorDataStore."""

    @abstractmethod
    async def upsert(self, table: str, json: dict[str, Any]) -> None:
        """
        Upsert a single row into ``table``.

        ``json`` maps column names to values; ``PgVectorDataStore._upsert``
        sends one such dict per document chunk.
        """
        raise NotImplementedError

    @abstractmethod
    async def rpc(self, function_name: str, params: dict[str, Any]) -> Any:
        """
        Calls a stored procedure in the database with the given parameters.

        ``PgVectorDataStore._query`` iterates the result and indexes each row
        by column name.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_like(self, table: str, column: str, pattern: str) -> None:
        """
        Deletes rows in ``table`` whose ``column`` matches ``pattern``.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_in(self, table: str, column: str, ids: List[str]) -> None:
        """
        Deletes rows in ``table`` whose ``column`` value is one of ``ids``.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete_by_filters(
        self, table: str, filter: DocumentMetadataFilter
    ) -> None:
        """
        Deletes rows in ``table`` that match ``filter``.
        """
        raise NotImplementedError


def _match_params(query: QueryWithEmbedding) -> Dict[str, Any]:
    """
    Build the parameters for the ``match_page_sections`` RPC from ``query``.

    Optional parameters are only included when the corresponding query or
    filter field is truthy. Dates go through ``to_unix_timestamp`` and are
    passed as naive datetimes (``datetime.fromtimestamp`` without a tz
    argument, i.e. in the process's local timezone).
    """
    params: Dict[str, Any] = {"in_embedding": query.embedding}
    if query.top_k:
        # NOTE(review): when omitted, the match count presumably falls back
        # to the SQL function's default — confirm against its definition.
        params["in_match_count"] = query.top_k

    query_filter = query.filter
    if not query_filter:
        return params

    if query_filter.document_id:
        params["in_document_id"] = query_filter.document_id
    if query_filter.source:
        params["in_source"] = query_filter.source.value
    if query_filter.source_id:
        params["in_source_id"] = query_filter.source_id
    if query_filter.author:
        params["in_author"] = query_filter.author
    if query_filter.start_date:
        params["in_start_date"] = datetime.fromtimestamp(
            to_unix_timestamp(query_filter.start_date)
        )
    if query_filter.end_date:
        params["in_end_date"] = datetime.fromtimestamp(
            to_unix_timestamp(query_filter.end_date)
        )
    return params


def _row_to_chunk(row: Dict[str, Any]) -> DocumentChunkWithScore:
    """Convert one ``match_page_sections`` result row into a scored chunk."""
    return DocumentChunkWithScore(
        id=row["id"],
        text=row["content"],
        # TODO: add embedding to the response ?
        # embedding=row["embedding"],
        score=float(row["similarity"]),
        metadata=DocumentChunkMetadata(
            source=row["source"],
            source_id=row["source_id"],
            document_id=row["document_id"],
            url=row["url"],
            created_at=row["created_at"],
            author=row["author"],
        ),
    )


# abstract class for Postgres based Datastore providers that implements DataStore interface
class PgVectorDataStore(DataStore):
    """DataStore backed by a Postgres ``documents`` table accessed via a PGClient."""

    def __init__(self):
        # The concrete client comes from the subclass hook below.
        self.client = self.create_db_client()

    @abstractmethod
    def create_db_client(self) -> PGClient:
        """
        Create the PGClient used to access the Postgres database.

        Implementations may use different access APIs, e.g. a Supabase client
        or a psycopg2-based client.
        """
        raise NotImplementedError

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Upsert every chunk into the ``documents`` table, one row per chunk.

        Args:
            chunks: mapping of document_id to that document's chunks.

        Returns:
            The document ids (the keys of ``chunks``).
        """
        for document_id, document_chunks in chunks.items():
            for chunk in document_chunks:
                row = {
                    "id": chunk.id,
                    "content": chunk.text,
                    "embedding": chunk.embedding,
                    "document_id": document_id,
                    "source": chunk.metadata.source,
                    "source_id": chunk.metadata.source_id,
                    "url": chunk.metadata.url,
                    "author": chunk.metadata.author,
                }
                if chunk.metadata.created_at:
                    # NOTE(review): the trailing comma makes this a 1-tuple
                    # wrapping the datetime, not a bare datetime. PGClient
                    # implementations outside this file may depend on that
                    # shape — confirm them before removing the comma.
                    row["created_at"] = (
                        datetime.fromtimestamp(
                            to_unix_timestamp(chunk.metadata.created_at)
                        ),
                    )
                await self.client.upsert("documents", row)
        return list(chunks.keys())

    async def _query(self, queries: List[QueryWithEmbedding]) -> List[QueryResult]:
        """
        Run each query against the ``match_page_sections`` database function.

        Returns one QueryResult per input query, in order. A query whose RPC
        call or row conversion fails is logged and yields an empty result
        rather than aborting the batch.
        """
        query_results: List[QueryResult] = []
        for query in queries:
            # Built outside the try: errors while converting filter values
            # propagate to the caller instead of becoming an empty result.
            params = _match_params(query)
            try:
                data = await self.client.rpc("match_page_sections", params=params)
                results = [_row_to_chunk(row) for row in data]
                query_results.append(QueryResult(query=query.query, results=results))
            except Exception:
                logger.exception("match_page_sections query failed")
                query_results.append(QueryResult(query=query.query, results=[]))
        return query_results

    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """
        Removes vectors by ids, filter, or everything in the datastore.

        Several arguments may be passed, but only one mode is applied, in this
        order of precedence: ``delete_all``, then ``ids`` (matched against the
        ``document_id`` column), then ``filter``.

        Returns:
            False if the client call raised an exception, True otherwise
            (including when no argument selects anything to delete).
        """
        try:
            if delete_all:
                # "%" is the pattern used to select every document_id.
                await self.client.delete_like("documents", "document_id", "%")
            elif ids:
                await self.client.delete_in("documents", "document_id", ids)
            elif filter:
                await self.client.delete_by_filters("documents", filter)
        except Exception:
            # Catch Exception, not everything: a bare except would also
            # swallow asyncio.CancelledError / KeyboardInterrupt.
            logger.exception("failed to delete from documents")
            return False
        return True