# evaluation_pipeline/retrieval.py
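"""Retrieval evaluation pipeline.

Loads a browsing-history CSV, embeds each page's title + description with one or more
candidate embedding models, stores the vectors in an in-memory sqlite-vec database,
runs the (golden) search queries against each model's table, and writes the retrieved
results alongside the ground-truth document ids/URLs to a CSV for downstream scoring.
"""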
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
import functools
import tracemalloc
import time
import pandas as pd
import sqlite3
import sqlite_vec
from typing import List
import struct
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import os
import sys
import argparse
import pickle
# Add the parent directory of `evaluation_pipeline` to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.constants import EMBEDDING_MODELS_DICT
from src.feature_extractor import FeatureExtractor
def log_performance(func):
    """Decorator that logs wall-clock time and tracemalloc memory usage of the wrapped call."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
start_time = time.time()
tracemalloc.start()
result = func(*args, **kwargs)
current, peak = tracemalloc.get_traced_memory()
end_time = time.time()
logging.info(f"Function: {func.__name__}")
logging.info(f"Execution Time: {end_time - start_time:.2f} seconds")
logging.info(f"Current Memory Usage: {current / 10**6:.2f} MB")
logging.info(f"Peak Memory Usage: {peak / 10**6:.2f} MB")
tracemalloc.stop()
return result
return wrapper
def format_size(size_in_bytes):
    """Convert bytes to a human-readable format (KB, MB, GB, etc.)."""
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if size_in_bytes < 1024:
            return f"{size_in_bytes:.2f} {unit}"
        size_in_bytes /= 1024
    return f"{size_in_bytes:.2f} PB"
def process_golden_queries(golden_query_file_path):
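    """Map each unique `search_query` in the golden CSV to a string id derived from hash().

    Note: Python's hash() is salted per process, so these ids are only stable within a single run.
    """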
golden_df = pd.read_csv(golden_query_file_path)
query_ids = {query: str(hash(query)) for query in golden_df['search_query'].unique()}
return query_ids
@log_performance
def process_history(row_limit, history_file_path):
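    """Load up to `row_limit` rows of browsing history and prepare the text used for embedding.

    Parses `last_visit_date` (missing values default to 1970-01-01) and builds
    `combined_text` from title + description, dropping rows where both are empty.
    """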
browsing_history = pd.read_csv(history_file_path).head(row_limit)
browsing_history['last_visit_date'] = pd.to_datetime(browsing_history['last_visit_date'], unit='us')
# fill empty last_visit_date with default value "1970-01-01"
browsing_history['last_visit_date'] = browsing_history['last_visit_date'].fillna(pd.to_datetime("1970-01-01"))
browsing_history['combined_text'] = browsing_history['title'].fillna('') + " " + browsing_history['description'].fillna('')
    browsing_history['combined_text_url'] = browsing_history['title'].fillna('') + " " + browsing_history['description'].fillna('') + " " + browsing_history['url'].fillna('')
    browsing_history = browsing_history.loc[browsing_history['combined_text'].str.strip() != ''].reset_index(drop=True)
    print(f"Loaded {len(browsing_history)} history rows with non-empty text")
return browsing_history
@log_performance
def create_embeddings(row_limit, history, embeddings_model_dict):
    """Embed the combined history text with every candidate model and cache the results as pickles."""
    texts = history['combined_text'].values.tolist()
    embeddings_dict = {}
    embeddings_sizes = {}
    for model in embeddings_model_dict.keys():
        # nomic expects a task prefix; use a per-model copy so later models don't see the prefixed text
        if model == 'nomic-ai/nomic-embed-text-v1.5':
            model_texts = ['search_document: ' + text for text in texts]
        else:
            model_texts = texts
        fe = FeatureExtractor(embeddings_model_dict, model_name=model)
        embeddings_dict[model] = fe.get_embeddings(model_texts)
        print(model, embeddings_dict[model].shape)
        embeddings_sizes[model] = embeddings_dict[model].shape[1]
    with open(f"data/embeddings_dict_{row_limit}.pkl", "wb") as f:
        pickle.dump(embeddings_dict, f)
    with open(f"data/embeddings_sizes_{row_limit}.pkl", "wb") as f:
        pickle.dump(embeddings_sizes, f)
    return embeddings_dict, embeddings_sizes
@log_performance
def create_db():
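    """Open an in-memory SQLite database with the sqlite-vec extension loaded and print its versions."""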
# vector database
    db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)
sqlite_version, vec_version = db.execute(
"select sqlite_version(), vec_version()"
).fetchone()
print(f"sqlite_version={sqlite_version}, vec_version={vec_version}")
return db
def serialize_f32(vector: List[float]) -> bytes:
"""serializes a list of floats into a compact "raw bytes" format"""
return struct.pack("%sf" % len(vector), *vector)
@log_performance
def create_embeddings_table_in_vector_db(db, model_name, embeddings_sizes, embeddings_dict):
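    """Create a vec0 virtual table for `model_name` and insert that model's embeddings keyed by rowid."""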
print("creating table")
EMBEDDING_SIZE = embeddings_sizes[model_name]
items = []
for idx, vec in enumerate(embeddings_dict[model_name]):
items.append((idx, list(vec)))
model_name_normalized = model_name.replace("/","_").replace("-","_").replace(".","_")
db.execute(f"CREATE VIRTUAL TABLE vec_items_{model_name_normalized} USING vec0(embedding float[{EMBEDDING_SIZE}])")
with db:
for item in items:
db.execute(
f"INSERT INTO vec_items_{model_name_normalized}(rowid, embedding) VALUES (?, ?)",
[item[0], serialize_f32(item[1])],
)
return db
# retrieval
@log_performance
def query_and_result(fe, query, db, model_name, threshold, k):
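    """Embed `query` and run a cosine-distance search against the model's vec_items table.

    Rows matching Google search-result URLs are excluded; at most `k` nearest rows are
    considered, and those within `threshold` are returned with title, URL, and combined
    text joined from search_data.
    """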
model_name_normalized = model_name.replace("/","_").replace("-","_").replace(".","_")
if model_name == 'nomic-ai/nomic-embed-text-v1.5':
query = 'search_query: ' + query
query_embedding = fe.get_embeddings([query])[0]
# using cosine distance
rows = db.execute(
f"""
SELECT
a.rowid,
vec_distance_cosine(embedding, ?) AS cosine_distance
FROM vec_items_{model_name_normalized} a
            INNER JOIN search_data b
                ON a.rowid = b.rowid
            WHERE b.url NOT LIKE '%google.com/search?%'
ORDER BY cosine_distance
LIMIT {k}
""",
[serialize_f32(query_embedding)],
).fetchall()
results = []
for row in rows:
rowid = row[0] # Get the row ID
distance = row[1] # Get the cosine distance
        # Skip rows where the cosine distance exceeds the threshold
        if distance > threshold:
            print(f"Skipping rowid {rowid}: distance {distance:.3f} exceeds threshold {threshold}")
            continue
        # Query additional details for the matching row from search_data
res = db.execute(
"""
SELECT rowid, title, url, combined_text
FROM search_data
WHERE rowid = ?
""",
(rowid,)
).fetchone()
# Add the result to the results list
if res:
results.append({
"id": res[0],
"title": res[1],
"url": res[2],
"combined_text": res[3],
"distance": distance,
})
return results
def get_table_size(connection, table_name):
"""Get the approximate size of a table in SQLite."""
with connection:
cursor = connection.cursor()
# Get page size and page count
cursor.execute("PRAGMA page_count;")
page_count = cursor.fetchone()[0]
cursor.execute("PRAGMA page_size;")
page_size = cursor.fetchone()[0]
total_db_size = page_count * page_size
# If dbstat is available, use it for more precision
try:
cursor.execute(f"""
SELECT SUM(pgsize) AS table_size
FROM dbstat
WHERE name = ?;
""", (table_name,))
result = cursor.fetchone()
if result and result[0]:
return result[0]
except sqlite3.DatabaseError:
print("dbstat is not available; estimating based on database size.")
# Fallback to database size as an estimate
return total_db_size
@log_performance
def load_history_in_db(db, browsing_history):
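    """Store the browsing history in the `search_data` table (rowid = DataFrame index) for joining
    retrieval results back to titles/URLs, and keep a full copy in `full_history` via to_sql.
    """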
db.execute('''CREATE TABLE IF NOT EXISTS search_data (
url TEXT,
title TEXT,
description TEXT,
combined_text TEXT
)
''')
for idx, row in browsing_history.iterrows():
db.execute("""
INSERT INTO search_data (rowid, url, title, description, combined_text)
VALUES (?, ?, ?, ?, ?)
""", (idx, row['url'], row['title'], row['description'], row['combined_text'])
)
browsing_history.to_sql("full_history", db, if_exists="replace", index=True) # Creates the table automatically
def create_ground_truth(history, query_ids):
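    """Build {query_id: [relevant row indices]} by matching each query against the history's own
    `search_query` column (used when no golden set is supplied).
    """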
queries = []
relevant_ids = []
for query in history['search_query'].unique():
query_id = query_ids[query]
queries.append(query_id)
ground_truths = []
for i, row in history.iterrows():
relevant = 1 if query == row['search_query'] else 0
if relevant == 1:
ground_truths.append(int(i))
relevant_ids.append(ground_truths)
return dict(zip(queries, relevant_ids))
def load_ground_truth_from_golden(db, golden_df_file_path):
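    """Load golden (search_query, url) pairs into a `ground_truth` table, join on URL to recover
    document rowids, and return ground-truth ids, the query-id mapping, and ground-truth URLs.
    Golden URLs that do not appear in the history join to a None rowid.
    """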
golden_df = pd.read_csv(golden_df_file_path)
query_ids = {query: str(hash(query)) for query in golden_df['search_query'].unique()}
db.execute('''CREATE TABLE IF NOT EXISTS ground_truth (
search_query TEXT,
url TEXT
)
''')
with db:
for _, row in golden_df.iterrows():
db.execute("""
INSERT INTO ground_truth (search_query, url)
VALUES (?, ?)
""", ( row['search_query'], row['url'])
)
# join to search query to get doc ID for ground truth URL
results = db.execute(
'''SELECT a.search_query, a.url, b.rowid
FROM ground_truth a
left join search_data b
on a.url = b.url'''
).fetchall()
ground_truth = {}
ground_truth_urls = {}
for query, url, id_ in results:
query_id = query_ids[query]
if query_id not in ground_truth:
ground_truth[query_id] = []
ground_truth_urls[query_id] = []
ground_truth[query_id].append(id_)
ground_truth_urls[query_id].append(url)
return ground_truth, query_ids, ground_truth_urls
@log_performance
def run_history_in_vector_db(row_limit, history_file_path, golden_set_file_path):
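    """End-to-end setup: load the history, create the sqlite-vec DB, resolve ground truth
    (from the golden set if provided, otherwise from the history itself), load or compute
    embeddings for every candidate model, and store them in per-model vec_items tables.
    """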
browsing_history = process_history(row_limit, history_file_path=history_file_path)
# create vector DB
db = create_db()
# load in history for joining later
load_history_in_db(db, browsing_history)
    # if a golden set is not provided, derive queries and ground truth from the history itself
    if not golden_set_file_path:
        query_ids = {query: str(hash(query)) for query in browsing_history['search_query'].unique()}
        browsing_history['query_id'] = browsing_history['search_query'].map(query_ids)
        ground_truth = create_ground_truth(browsing_history, query_ids)
        # also collect the relevant URLs per query so this branch returns the same shape as the golden-set path
        ground_truth_urls = {query_ids[q]: browsing_history.loc[browsing_history['search_query'] == q, 'url'].tolist()
                             for q in query_ids}
else:
print("Getting doc ids for history")
ground_truth, query_ids, ground_truth_urls = load_ground_truth_from_golden(db, golden_df_file_path=golden_set_file_path)
# create embeddings for candidate models
print("Generating Embeddings")
    try:
        # reuse cached embeddings for this row limit if they exist
        with open(f"data/embeddings_dict_{row_limit}.pkl", "rb") as f:
            embeddings_dict = pickle.load(f)
        with open(f"data/embeddings_sizes_{row_limit}.pkl", "rb") as f:
            embeddings_sizes = pickle.load(f)
    except (OSError, pickle.UnpicklingError, EOFError):
        embeddings_dict, embeddings_sizes = create_embeddings(row_limit, browsing_history, embeddings_model_dict=EMBEDDING_MODELS_DICT)
# loop through each model/embedding type and store in db
for model_name in embeddings_dict.keys():
model_name_normalized = model_name.replace("/","_").replace("-","_").replace(".","_")
# create table for embeddings for model
create_embeddings_table_in_vector_db(db, model_name, embeddings_sizes=embeddings_sizes, embeddings_dict=embeddings_dict)
        # the embeddings live in the vec_items_<model> virtual table, not a table named after the model
        table_size = get_table_size(db, table_name=f"vec_items_{model_name_normalized}")
        logging.info(f"vec_items_{model_name_normalized} table size: {format_size(table_size)} ({table_size} bytes)")
return query_ids, db, ground_truth, ground_truth_urls
@log_performance
def run_retrieval(fe, query_ids, db, model_name, threshold, k):
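    """Run query_and_result for every query and return ({query_id: results}, {query_id: query})."""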
# loop through each query
    retrieval_dict = {}
query_lookup = {}
for query in query_ids.keys():
# for later identifying query
query_id = query_ids[query]
query_lookup[query_id] = query
# perform retrieval
results = query_and_result(fe, query, db=db,
model_name=model_name, threshold=threshold, k=k)
        retrieval_dict[query_id] = results
    return retrieval_dict, query_lookup
def convert_dict_to_df(retrieval_dict, query_lookup, ground_truth, ground_truth_urls, model_name, k):
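    """Flatten per-query retrievals into one wide row per query (a column group per retrieved
    document) plus the ground-truth ids/URLs, and return the result as a DataFrame.
    """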
rows = []
for query_id, retrievals in retrieval_dict.items():
# Flatten each retrieval into a single row with column names based on retrieval index
row = {'query_id': str(query_id)}
retrieved_ids = [] # List to collect all retrieved IDs
retrieved_distances = [] # collect the distances
for i, retrieval in enumerate(retrievals, start=1):
row[f'retrieval_{i}_id'] = retrieval.get('id')
row[f'retrieval_{i}_title'] = retrieval.get('title')
row[f'retrieval_{i}_url'] = retrieval.get('url')
row[f'retrieval_{i}_combined_text'] = retrieval.get('combined_text')
row[f'retrieval_{i}_distance'] = retrieval.get('distance')
retrieved_ids.append(retrieval.get('id'))
retrieved_distances.append(retrieval.get('distance'))
# Collect the ID for the list
row['retrieved_ids'] = retrieved_ids
row['retrieved_distances'] = retrieved_distances
row['model_name'] = model_name
row['query'] = query_lookup[query_id]
row['relevant_docs'] = ground_truth[query_id]
row['relevant_urls'] = ground_truth_urls[query_id]
row['k'] = k
rows.append(row)
df = pd.DataFrame(rows)
return df
def main(model_name, k, threshold, history_file_path, golden_path=None, row_limit=100):
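    """Configure logging, build the vector DB and ground truth, run retrieval for `model_name`,
    and write the flattened results to results/<model_name>_results.csv.
    """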
model_name_normalized = model_name.replace("/","_").replace("-","_").replace(".","_")
# Configure logging
logging.basicConfig(
filename=f"performance_{model_name_normalized}.log",
level=logging.INFO,
format="%(asctime)s - %(message)s"
)
query_ids, db, ground_truth, ground_truth_urls = run_history_in_vector_db(row_limit, history_file_path=history_file_path, golden_set_file_path=golden_path)
fe = FeatureExtractor(EMBEDDING_MODELS_DICT, model_name=model_name)
retrieval_results, query_lookup = run_retrieval(fe, query_ids, db, model_name, threshold, k)
# reshape & save to df and csv
df = convert_dict_to_df(retrieval_dict=retrieval_results, query_lookup=query_lookup, ground_truth=ground_truth, ground_truth_urls=ground_truth_urls, model_name=model_name, k=k)
    df.to_csv(f"results/{model_name_normalized}_results.csv", index=False)
return db, retrieval_results, df
if __name__ == "__main__":
# Create the argument parser
parser = argparse.ArgumentParser(description="Run the retrieval pipeline with specified parameters.")
# Add arguments
parser.add_argument("history_file_path", type=str, help="Path to the browsing history file.")
parser.add_argument("--model_name", type=str,default='Xenova/all-MiniLM-L6-v2', help="Name of the model to use.")
parser.add_argument("--k", type=int, default=2, help="Top-K results to retrieve.")
parser.add_argument("--threshold", type=float, default=10.0, help="Threshold for retrieval.")
parser.add_argument("--golden_path", type=str, default=None, help="Path to the golden query set file (optional).")
parser.add_argument("--row_limit", type=int, default=100, help="Whether to limit rows from browsing history upon load")
# Parse arguments
args = parser.parse_args()
# Call the main function with parsed arguments
main(
model_name=args.model_name,
k=args.k,
threshold=args.threshold,
history_file_path=args.history_file_path,
golden_path=args.golden_path,
row_limit=args.row_limit
)
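
# Example invocation (the CSV paths below are illustrative placeholders, not files shipped with this repo):
#   python evaluation_pipeline/retrieval.py data/places_history.csv \
#       --model_name Xenova/all-MiniLM-L6-v2 --k 5 --threshold 0.6 \
#       --golden_path data/golden_queries.csv --row_limit 1000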