# skills/classification/evaluation/vectordb.py
import os
import numpy as np
import voyageai
import pickle
import json
class VectorDB:
    """Simple on-disk vector store backed by Voyage AI embeddings.

    Embeddings and their source items are kept in two parallel lists;
    query embeddings are memoized in ``query_cache``. The whole store
    round-trips to disk as a single pickle at ``db_path``.
    """

    def __init__(self, api_key=None):
        """Create the Voyage client; falls back to the VOYAGE_API_KEY env var."""
        if api_key is None:
            api_key = os.getenv("VOYAGE_API_KEY")
        self.client = voyageai.Client(api_key=api_key)
        self.embeddings = []   # embedding vectors, parallel to self.metadata
        self.metadata = []     # original data items, one per embedding
        self.query_cache = {}  # query text -> embedding; avoids re-embedding repeats
        self.db_path = "../data/vector_db.pkl"

    def load_data(self, data):
        """Embed ``data`` (list of dicts with a "text" key) and persist to disk.

        No-op if embeddings are already in memory; if the pickle already
        exists on disk it is loaded instead of re-embedding.
        """
        # Check if the vector database is already loaded
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        # Check if vector_db.pkl exists
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return
        texts = [item["text"] for item in data]
        # The embed endpoint caps a request at 128 documents, so batch.
        batch_size = 128
        result = [
            self.client.embed(
                texts[i : i + batch_size],
                model="voyage-2"
            ).embeddings
            for i in range(0, len(texts), batch_size)
        ]
        # Flatten the per-batch embedding lists into one flat list.
        self.embeddings = [embedding for batch in result for embedding in batch]
        self.metadata = list(data)
        # BUG FIX: the database was never actually written to disk — the
        # comment and print claimed a save, but no save_db existed, so
        # load_db() could never find the file it expects. Persist now.
        self.save_db()
        print("Vector database loaded and saved.")

    def search(self, query, k=5, similarity_threshold=0.85):
        """Return up to ``k`` items most similar to ``query``, best first.

        Only items whose dot-product similarity meets
        ``similarity_threshold`` are returned; each result is a dict with
        "metadata" and "similarity" keys. Raises ValueError when no data
        has been loaded.

        NOTE(review): np.dot equals cosine similarity only if the stored
        embeddings are unit-normalized — confirm for the embedding model
        in use.
        """
        if query in self.query_cache:
            query_embedding = self.query_cache[query]
        else:
            query_embedding = self.client.embed([query], model="voyage-2").embeddings[0]
            self.query_cache[query] = query_embedding
        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")
        similarities = np.dot(self.embeddings, query_embedding)
        top_indices = np.argsort(similarities)[::-1]  # best match first
        top_examples = []
        for idx in top_indices:
            if similarities[idx] >= similarity_threshold:
                example = {
                    "metadata": self.metadata[idx],
                    "similarity": similarities[idx],
                }
                top_examples.append(example)
                if len(top_examples) >= k:
                    break
        return top_examples

    def save_db(self):
        """Pickle embeddings, metadata, and the query cache to ``db_path``.

        Missing before this fix: load_db() reads exactly this layout
        ("embeddings", "metadata", and a JSON-encoded "query_cache"),
        so the cache is stored with json.dumps to mirror its json.loads.
        """
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        # Ensure the target directory exists before writing.
        os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        """Restore embeddings, metadata, and query cache from ``db_path``.

        Raises ValueError if the pickle file does not exist.
        """
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])