# datastore/providers/postgres_datastore.py

import os
from datetime import datetime
from typing import Any, List

import numpy as np
from pgvector.psycopg2 import register_vector
from psycopg2 import connect
from psycopg2.extras import DictCursor

from datastore.providers.pgvector_datastore import PGClient, PgVectorDataStore
from models.models import (
    DocumentMetadataFilter,
)
from services.date import to_unix_timestamp

# Connection settings, overridable through environment variables.
PG_HOST = os.environ.get("PG_HOST", "localhost")
PG_PORT = int(os.environ.get("PG_PORT", 5432))
PG_DB = os.environ.get("PG_DB", "postgres")
PG_USER = os.environ.get("PG_USER", "postgres")
PG_PASSWORD = os.environ.get("PG_PASSWORD", "postgres")


# class that implements the DataStore interface for Postgres Datastore provider
class PostgresDataStore(PgVectorDataStore):
    def create_db_client(self):
        """Return a new low-level Postgres client for this datastore."""
        return PostgresClient()


class PostgresClient(PGClient):
    """Thin psycopg2 wrapper used by PgVectorDataStore.

    NOTE(review): table and column names received by the methods below are
    interpolated into SQL with f-strings; they must come from trusted
    configuration, never from user input. All *values* are parameterized.
    """

    def __init__(self) -> None:
        super().__init__()
        self.client = connect(
            dbname=PG_DB,
            user=PG_USER,
            password=PG_PASSWORD,
            host=PG_HOST,
            port=PG_PORT,
        )
        # Teach psycopg2 to adapt numpy arrays to the pgvector column type.
        register_vector(self.client)

    def __del__(self):
        # Close the connection when the client is destroyed. getattr guards
        # against connect() having raised in __init__, in which case
        # self.client was never assigned and __del__ would otherwise crash.
        client = getattr(self, "client", None)
        if client is not None:
            client.close()

    async def upsert(self, table: str, json: dict[str, Any]):
        """Insert one document chunk into ``table``; update it when the id exists.

        Args:
            table: target table name (trusted, interpolated into SQL).
            json: mapping with the column values; ``created_at`` defaults to
                now, ``embedding`` is converted to a numpy array for pgvector.
        """
        if not json.get("created_at"):
            json["created_at"] = datetime.now()
        json["embedding"] = np.array(json["embedding"])
        columns = (
            "id",
            "content",
            "embedding",
            "document_id",
            "source",
            "source_id",
            "url",
            "author",
            "created_at",
        )
        # EXCLUDED refers to the row proposed for insertion, so the UPDATE
        # branch reuses the same bound values instead of repeating every
        # parameter twice as the previous version did.
        assignments = ", ".join(f"{col} = EXCLUDED.{col}" for col in columns[1:])
        placeholders = ", ".join(["%s"] * len(columns))
        with self.client.cursor() as cur:
            cur.execute(
                f"INSERT INTO {table} ({', '.join(columns)}) "
                f"VALUES ({placeholders}) "
                f"ON CONFLICT (id) DO UPDATE SET {assignments}",
                tuple(json[col] for col in columns),
            )
        self.client.commit()

    async def rpc(self, function_name: str, params: dict[str, Any]):
        """Call a stored procedure and return its rows as plain dicts.

        ``created_at`` is converted to a unix timestamp in each returned row.
        """
        data = []
        params["in_embedding"] = np.array(params["in_embedding"])
        with self.client.cursor(cursor_factory=DictCursor) as cur:
            cur.callproc(function_name, params)
            rows = cur.fetchall()
            self.client.commit()
            for row in rows:
                row["created_at"] = to_unix_timestamp(row["created_at"])
                data.append(dict(row))
        return data

    async def delete_like(self, table: str, column: str, pattern: str):
        """Delete rows where ``column`` contains ``pattern`` (substring match)."""
        with self.client.cursor() as cur:
            cur.execute(
                f"DELETE FROM {table} WHERE {column} LIKE %s",
                (f"%{pattern}%",),
            )
        self.client.commit()

    async def delete_in(self, table: str, column: str, ids: List[str]):
        """Delete rows whose ``column`` value is one of ``ids``.

        A no-op for an empty list: ``IN ()`` is invalid PostgreSQL and the
        previous version raised a syntax error in that case.
        """
        if not ids:
            return
        with self.client.cursor() as cur:
            cur.execute(
                f"DELETE FROM {table} WHERE {column} IN %s",
                (tuple(ids),),
            )
        self.client.commit()

    async def delete_by_filters(self, table: str, filter: DocumentMetadataFilter):
        """Delete rows matching every set field of ``filter``.

        Filter values are bound as query parameters (the previous version
        f-string-interpolated them into the SQL, an injection vector).

        Raises:
            ValueError: when no filter field is set, to avoid emitting a
                malformed statement (and to guard against deleting all rows).
        """
        conditions: List[str] = []
        params: List[Any] = []
        if filter.document_id:
            conditions.append("document_id = %s")
            params.append(filter.document_id)
        if filter.source:
            conditions.append("source = %s")
            params.append(filter.source)
        if filter.source_id:
            conditions.append("source_id = %s")
            params.append(filter.source_id)
        if filter.author:
            conditions.append("author = %s")
            params.append(filter.author)
        if filter.start_date:
            conditions.append("created_at >= %s")
            params.append(filter.start_date)
        if filter.end_date:
            conditions.append("created_at <= %s")
            params.append(filter.end_date)
        if not conditions:
            raise ValueError("delete_by_filters requires at least one filter field")
        with self.client.cursor() as cur:
            cur.execute(
                f"DELETE FROM {table} WHERE {' AND '.join(conditions)}",
                tuple(params),
            )
        self.client.commit()