in evaluation_pipeline/retrieval.py [0:0]
def _load_or_create_embeddings(row_limit, browsing_history, cache_dir="data"):
    """Return ``(embeddings_dict, embeddings_sizes)``, preferring the on-disk pickle cache.

    Reads ``{cache_dir}/embeddings_dict_{row_limit}.pkl`` and
    ``{cache_dir}/embeddings_sizes_{row_limit}.pkl``. If either file is missing,
    unreadable, or fails to unpickle, both values are recomputed with
    ``create_embeddings`` instead.
    """
    dict_path = f"{cache_dir}/embeddings_dict_{row_limit}.pkl"
    sizes_path = f"{cache_dir}/embeddings_sizes_{row_limit}.pkl"
    try:
        # SECURITY: pickle.load can execute arbitrary code; only point
        # cache_dir at files you trust.
        with open(dict_path, "rb") as f:
            embeddings_dict = pickle.load(f)
        with open(sizes_path, "rb") as f:
            embeddings_sizes = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError,
            AttributeError, ImportError, IndexError) as err:
        # Missing or unusable cache: fall back to generating embeddings.
        # The exception list covers file I/O errors plus the exceptions
        # the pickle docs say unpickling may raise.
        logging.info(f"Embeddings cache not usable ({err!r}); generating embeddings")
        return create_embeddings(
            row_limit, browsing_history, embeddings_model_dict=EMBEDDING_MODELS_DICT
        )
    return embeddings_dict, embeddings_sizes


def run_history_in_vector_db(row_limit, history_file_path, golden_set_file_path, cache_dir="data"):
    """Load browsing history into a vector DB and build one embeddings table per model.

    Steps:
      1. Load history via ``process_history(row_limit, ...)``.
      2. Create a vector DB and load the history into it (for joining later).
      3. Build ground truth: from the history's ``search_query`` column when no
         golden set is given, otherwise from the golden-set file.
      4. Load cached embeddings from ``cache_dir`` (or generate them).
      5. Create an embeddings table per model and log its size.

    Args:
        row_limit: Row limit passed to ``process_history``; also part of the
            embeddings cache file names.
        history_file_path: Path to the browsing-history file.
        golden_set_file_path: Path to the golden-set file. A falsy value means
            queries and ground truth are derived from the history itself.
        cache_dir: Directory holding the ``embeddings_dict_{row_limit}.pkl`` and
            ``embeddings_sizes_{row_limit}.pkl`` caches. Defaults to ``"data"``.

    Returns:
        Tuple ``(query_ids, db, ground_truth, ground_truth_urls)``.
        ``ground_truth_urls`` is ``None`` when no golden set is provided.
    """
    browsing_history = process_history(row_limit, history_file_path=history_file_path)
    # Create the vector DB and load the history for joining later.
    db = create_db()
    load_history_in_db(db, browsing_history)

    # Only the golden-set path produces URLs. Default to None so the
    # history-only path does not raise UnboundLocalError at the return.
    ground_truth_urls = None
    if not golden_set_file_path:
        # No golden set: assume the queries come with the history.
        # NOTE(review): built-in hash() on str is salted per process
        # (PYTHONHASHSEED), so these ids are only stable within a single run.
        query_ids = {query: str(hash(query)) for query in browsing_history['search_query'].unique()}
        browsing_history['query_id'] = browsing_history['search_query'].map(query_ids)
        ground_truth = create_ground_truth(browsing_history, query_ids)
    else:
        print("Getting doc ids for history")
        ground_truth, query_ids, ground_truth_urls = load_ground_truth_from_golden(
            db, golden_df_file_path=golden_set_file_path
        )

    print("Generating Embeddings")
    embeddings_dict, embeddings_sizes = _load_or_create_embeddings(
        row_limit, browsing_history, cache_dir=cache_dir
    )

    # Create and size one table per model/embedding type.
    for model_name in embeddings_dict:
        # Table name passed to get_table_size. Presumably this matches how
        # create_embeddings_table_in_vector_db names the table — confirm.
        model_name_normalized = model_name.replace("/", "_").replace("-", "_").replace(".", "_")
        create_embeddings_table_in_vector_db(
            db, model_name, embeddings_sizes=embeddings_sizes, embeddings_dict=embeddings_dict
        )
        table_size = get_table_size(db, table_name=model_name_normalized)
        logging.info(f"{model_name_normalized} table size: {table_size}")
        logging.info(f"Table size {model_name_normalized}: {format_size(table_size)}")
    return query_ids, db, ground_truth, ground_truth_urls