skills/classification/evaluation/vectordb.py (64 lines of code) (raw):

import os
import numpy as np
import voyageai
import pickle
import json


class VectorDB:
    """In-memory vector store backed by Voyage embeddings and a pickle file.

    Documents are embedded with the ``voyage-2`` model and kept in memory as
    parallel ``embeddings`` / ``metadata`` lists. The store is persisted to
    ``db_path`` so later runs can skip re-embedding. Query embeddings are
    memoized in ``query_cache`` to avoid repeated API calls for the same query.
    """

    def __init__(self, api_key=None, db_path="../data/vector_db.pkl"):
        """Create the Voyage client and an empty store.

        Args:
            api_key: Voyage API key. Falls back to the ``VOYAGE_API_KEY``
                environment variable when ``None``.
            db_path: Location of the pickle file used by ``save_db`` and
                ``load_db``. Relative paths resolve against the current
                working directory.
        """
        if api_key is None:
            api_key = os.getenv("VOYAGE_API_KEY")
        self.client = voyageai.Client(api_key=api_key)
        self.embeddings = []   # one vector per document, same order as metadata
        self.metadata = []     # the original data items, as passed to load_data
        self.query_cache = {}  # query text -> embedding vector
        self.db_path = db_path

    def load_data(self, data):
        """Populate the store from ``data``, reusing prior work when possible.

        Order of precedence:

        1. If the store is already populated in memory, do nothing.
        2. If ``db_path`` exists, load it with ``load_db``.
        3. Otherwise, embed every item's ``"text"`` field and persist the
           result with ``save_db``.

        Args:
            data: Iterable of dicts, each with at least a ``"text"`` key.
                The whole dict is stored as that document's metadata.
        """
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        # Materialize once: iterating ``data`` twice would silently produce
        # empty metadata if the caller passed a generator.
        items = list(data)
        texts = [item["text"] for item in items]
        # Embed in batches of 128. This is presumably the Voyage per-request
        # input limit; confirm against the current API docs.
        batch_size = 128
        self.embeddings = [
            embedding
            for start in range(0, len(texts), batch_size)
            for embedding in self.client.embed(
                texts[start : start + batch_size], model="voyage-2"
            ).embeddings
        ]
        self.metadata = items
        # Persist so the next run can take the on-disk fast path above.
        self.save_db()
        print("Vector database loaded and saved.")

    def search(self, query, k=5, similarity_threshold=0.85):
        """Return up to ``k`` stored documents most similar to ``query``.

        The score is the dot product between the query embedding and each
        stored embedding. This equals cosine similarity only if the vectors
        are unit-normalized (assumed for Voyage embeddings; TODO confirm).
        Only documents scoring at least ``similarity_threshold`` are returned.

        Args:
            query: Query text. Its embedding is cached in ``query_cache``.
            k: Maximum number of results. ``k <= 0`` yields an empty list.
            similarity_threshold: Minimum score for a document to qualify.

        Returns:
            List of ``{"metadata": ..., "similarity": ...}`` dicts, ordered
            from highest to lowest similarity.

        Raises:
            ValueError: If no documents have been loaded.
        """
        # Check before embedding so an empty store costs no API call.
        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")
        if k <= 0:
            return []

        query_embedding = self.query_cache.get(query)
        if query_embedding is None:
            query_embedding = self.client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding

        similarities = np.dot(self.embeddings, query_embedding)
        # Filter by threshold first, then rank only the survivors, best first.
        candidates = np.flatnonzero(similarities >= similarity_threshold)
        ranked = candidates[np.argsort(similarities[candidates])[::-1]][:k]
        return [
            {"metadata": self.metadata[idx], "similarity": similarities[idx]}
            for idx in ranked
        ]

    def save_db(self):
        """Write embeddings, metadata and the query cache to ``db_path``.

        The query cache is stored as a JSON string, which is the format
        ``load_db`` decodes. Missing parent directories are created.
        """
        payload = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        directory = os.path.dirname(self.db_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(payload, file)

    def load_db(self):
        """Load a store previously written by ``save_db`` from ``db_path``.

        Raises:
            ValueError: If ``db_path`` does not exist.
        """
        if not os.path.exists(self.db_path):
            raise ValueError(
                "Vector database file not found. Use load_data to create a new database."
            )
        # SECURITY: pickle.load can execute arbitrary code. Only load files
        # written by this class or another trusted source.
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])