notebooks/bin_quantization_approach.ipynb (591 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"id": "d83680ff-b52f-4c4a-a565-62774533f944",
"metadata": {},
"source": [
"Explore binary quantization\n",
"\n",
"https://alexgarcia.xyz/sqlite-vec/guides/binary-quant.html"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "563e4f5f-8c6b-4058-9c2c-00da1329b03c",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "335c4f20-f649-4ea4-ae44-9e300718447b",
"metadata": {},
"outputs": [],
"source": [
"import sqlite3\n",
"import sqlite_vec\n",
"from typing import List\n",
"import struct"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "978fda31-9915-42c0-8bd5-ef3e8c95d8f3",
"metadata": {},
"outputs": [],
"source": [
"# Add the project root directory to the Python path\n",
"project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
"sys.path.append(project_root)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0d86492-e8be-4acf-9635-6e6e54bbb5d4",
"metadata": {},
"outputs": [],
"source": [
"from src.constants import EMBEDDING_MODELS_DICT\n",
"from src.feature_extractor import FeatureExtractor\n",
"from src.metrics import run_traditional_eval"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd1327d6-fe08-4d9f-8030-9a3cbec4cac9",
"metadata": {},
"outputs": [],
"source": [
"!python -V"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "abe464ec-6331-4a58-9b95-95f547edb899",
"metadata": {},
"outputs": [],
"source": [
"!python -m pip freeze| grep sqlite"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0651ef95-2ca5-49eb-89c0-4e49a4641328",
"metadata": {},
"outputs": [],
"source": [
"# !export LDFLAGS=\"-L/opt/homebrew/opt/sqlite/lib\"\n",
"# !export CPPFLAGS=\"-I/opt/homebrew/opt/sqlite/include\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9b15421-08df-40c2-8837-1ca5cc1acc2f",
"metadata": {},
"outputs": [],
"source": [
"db = sqlite3.connect(\":memory:\")\n",
"db.enable_load_extension(True)\n",
"sqlite_vec.load(db)\n",
"db.enable_load_extension(False)\n",
"\n",
"sqlite_version, vec_version = db.execute(\n",
" \"select sqlite_version(), vec_version()\"\n",
").fetchone()\n",
"print(f\"sqlite_version={sqlite_version}, vec_version={vec_version}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0d71dc0-f5d5-4c69-b11f-fff4dea8e57a",
"metadata": {},
"outputs": [],
"source": [
"res = db.execute(f\"\"\"select vec_quantize_binary(\n",
" '[-0.73, -0.80, 0.12, -0.73, 0.79, -0.11, 0.23, 0.97]'\n",
");\"\"\").fetchall()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bb556110-3b8a-4d59-b042-555d0c6dc5ed",
"metadata": {},
"outputs": [],
"source": [
"res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e94280d9-c129-4dd5-a258-b9be9dda6a61",
"metadata": {},
"outputs": [],
"source": [
"type(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31809d3a-70c4-446d-baae-055d2d8dba1f",
"metadata": {},
"outputs": [],
"source": [
"# int(res[0][0])\n",
"byte_value = res[0][0]\n",
"binary_representation = bin(int.from_bytes(byte_value, \"big\"))\n",
"print(f\"Binary Representation: {binary_representation}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52f6613d-024e-4580-88f0-159665fae6fb",
"metadata": {},
"outputs": [],
"source": [
"row_limit = 10000"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ee5def5-432a-4958-93bc-76b21297be83",
"metadata": {},
"outputs": [],
"source": [
"firefox_conn = sqlite3.connect(\"../data/places.sqlite\") \n",
"firefox_cursor = firefox_conn.cursor()\n",
"\n",
"input_data = firefox_cursor.execute(f\"\"\"\n",
"WITH TOP_FRECENT_PLACES AS\n",
"(SELECT p.url, p.title, COALESCE(p.description, '') AS description, p.id AS place_id, p.frecency, p.origin_id, p.url_hash,\n",
" p.last_visit_date\n",
"FROM moz_places p\n",
"WHERE p.title NOTNULL\n",
"AND url not like '%google.com/search?%'\n",
"ORDER BY frecency DESC\n",
"LIMIT {row_limit}\n",
") \n",
"\n",
"SELECT * FROM TOP_FRECENT_PLACES;\n",
"\"\"\").fetchall()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "576306ff-6db4-4d17-958b-dcd7800b364f",
"metadata": {},
"outputs": [],
"source": [
"history = pd.read_csv(\"../data/history_output_file.csv\")\n",
"# history = pd.DataFrame(input_data, \n",
"# columns=['url', 'title', 'description', 'place_id', 'frecency', 'origin_id', 'url_hash', 'last_visit_date'])\n",
"history['last_visit_date'] = pd.to_datetime(history['last_visit_date'], unit='us')\n",
"\n",
"# fill empty last_visit_date with default value \"1970-01-01\"\n",
"history['last_visit_date'] = history['last_visit_date'].fillna(pd.to_datetime(\"1970-01-01\"))\n",
"history['combined_text'] = history['title'].fillna('') + \" \" + history['description'].fillna('')\n",
"history = history.loc[history['combined_text'] != ''].reset_index(drop=True).head(row_limit)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "295d6da5-c257-4a6b-ae37-485472824d88",
"metadata": {},
"outputs": [],
"source": [
"history"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b14078a-45d0-4694-b2e0-e6a9a3013841",
"metadata": {},
"outputs": [],
"source": [
"EMBEDDING_MODELS_DICT['Xenova/all-MiniLM-L6-v2']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "770c11b4-6288-44ea-be69-d0adca51c4e3",
"metadata": {},
"outputs": [],
"source": [
"EMBEDDING_MODELS_DICT"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1cb2048-0f42-4cc1-9daa-8b71764ceb10",
"metadata": {},
"outputs": [],
"source": [
"model_name, embeddings_size = 'Xenova/all-MiniLM-L6-v2', 384\n",
"# model_name, embeddings_size = 'nomic-ai/nomic-embed-text-v1.5', 768\n",
"# model_name, embeddings_size = \"Xenova/all-mpnet-base-v2\", 768\n",
"# model_name, embeddings_size = 'Xenova/paraphrase-mpnet-base-v2', 768\n",
"# model_name, embeddings_size = 'Xenova/all-MiniLM-L12-v2', 384\n",
"# model_name, embeddings_size = 'nomic-ai/modernbert-embed-base', 768\n",
"fe = FeatureExtractor(EMBEDDING_MODELS_DICT, model_name=model_name)\n",
"texts = history['combined_text'].values.tolist()\n",
"embeddings = fe.get_embeddings(texts)\n",
"embeddings.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "995cbbdd-b4b0-410e-85c3-b0b87a71b889",
"metadata": {},
"outputs": [],
"source": [
"model_name_normalized = model_name.replace(\"/\",\"_\").replace(\"-\",\"_\").replace(\".\",\"_\")\n",
"\n",
"# Function to convert float vectors to binary format for SQLite\n",
"def serialize_f32_from_np(vector: np.ndarray) -> bytes:\n",
" \"\"\"Serializes a NumPy float32 vector into raw bytes format for SQLite.\"\"\"\n",
" return struct.pack(f\"{len(vector)}f\", *vector.astype(np.float32)) # Convert to float32\n",
"\n",
"\n",
"items = []\n",
"for idx, vec in enumerate(embeddings):\n",
" items.append((idx, vec))\n",
"\n",
"for item in items[:5]:\n",
" print(type(item[1][0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac2fdb82-52c9-48e2-b0ba-445d6b569cca",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "7f2c5192-7a38-4848-a096-e1c76e228956",
"metadata": {
"scrolled": true
},
"source": [
"#### Approach 1 just using the binary quantization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "658fd9ad-074d-4d27-b7a6-2fe3f14eef79",
"metadata": {},
"outputs": [],
"source": [
"db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_1 USING vec0(embedding bit[{embeddings_size}])\")\n",
"# db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_1 USING vec0(embedding bit[768])\")\n",
"# db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_1 USING vec0(embedding bit[128])\")\n",
"\n",
"with db:\n",
" for idx, vec in enumerate(embeddings):\n",
" db.execute(\n",
" f\"INSERT INTO vec_items_{model_name_normalized}_1(rowid, embedding) VALUES (?, vec_quantize_binary(?))\",\n",
" [idx, serialize_f32_from_np(vec)], # Convert vector to binary format\n",
" )\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a282abe6-9030-4dfd-ba3d-ab3b18c6b068",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def predict_with_bin_quantized(query):\n",
" query_serialized_vec = serialize_f32_from_np(fe.get_embeddings([query])[0])\n",
" \n",
" retrived_results = db.execute(f\"\"\"\n",
" select\n",
" rowid,\n",
" distance\n",
" from vec_items_{model_name_normalized}_1\n",
" where embedding match vec_quantize_binary(:query_serialized_vec)\n",
" order by distance\n",
" limit 2;\n",
" \"\"\", {\"query_serialized_vec\": query_serialized_vec}).fetchall()\n",
" \n",
" return history.iloc[[row for row,dist in retrived_results]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8175b13-2356-4c9f-a512-a7075cc3684e",
"metadata": {},
"outputs": [],
"source": [
"%timeit predict_with_bin_quantized(query=\"mail box\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfe61dc9-4ed8-4180-8b09-bd40c8aa9467",
"metadata": {},
"outputs": [],
"source": [
"predict_with_bin_quantized(query=\"canada news\")"
]
},
{
"cell_type": "markdown",
"id": "b8a038ce-b277-4ec9-9993-25aac0359f15",
"metadata": {},
"source": [
"#### Approach 2 just using the binary quantization & re-scoring"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "521547c7-44e7-4607-a139-9d312015352e",
"metadata": {},
"outputs": [],
"source": [
"db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_2 USING vec0(embedding float[{embeddings_size}], embedding_coarse bit[{embeddings_size}])\")\n",
"# db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_2 USING vec0(embedding float[768], embedding_coarse bit[768])\")\n",
"# db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_2 USING vec0(embedding float[128], embedding_coarse bit[128])\")\n",
"\n",
"with db:\n",
" for idx, vec in enumerate(embeddings):\n",
" embedding = serialize_f32_from_np(vec)\n",
" db.execute(\n",
" f\"INSERT INTO vec_items_{model_name_normalized}_2(rowid, embedding, embedding_coarse) VALUES (?, ?, vec_quantize_binary(?))\",\n",
" [idx, embedding, embedding], # Convert vector to binary format\n",
" )\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d9c2d65-2040-4af8-928c-c4e04469974e",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"def predict_coarse(query):\n",
" query_serialized_vec = serialize_f32_from_np(fe.get_embeddings([query])[0])\n",
" \n",
" retrived_results = db.execute(f\"\"\"\n",
" with coarse_matches as (\n",
" select\n",
" rowid,\n",
" embedding\n",
" from vec_items_{model_name_normalized}_2\n",
" where embedding_coarse match vec_quantize_binary(:query_serialized_vec)\n",
" order by distance\n",
" limit 200\n",
" )\n",
" select\n",
" rowid,\n",
" vec_distance_cosine(embedding, :query_serialized_vec)\n",
" from coarse_matches\n",
" order by 2\n",
" limit 2;\n",
" \"\"\", {\"query_serialized_vec\": query_serialized_vec}).fetchall()\n",
" return history.iloc[[row for row,dist in retrived_results]]\n",
" \n",
" # final_res = history.iloc[[row for row,dist in retrived_results]]\n",
" # final_res['distance'] = [dist for row,dist in retrived_results]\n",
" # return final_res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3dd33aaa-0766-478c-983f-6f890659c30f",
"metadata": {},
"outputs": [],
"source": [
"%timeit predict_coarse(query=\"scheduler\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c22429d8-a559-49e8-80b7-195b685cd23b",
"metadata": {},
"outputs": [],
"source": [
"predict_coarse(query=\"usa news\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab5aae20-10b1-4b02-8da4-6cec58f64960",
"metadata": {},
"outputs": [],
"source": [
"db_size = db.execute(\"PRAGMA page_count;\").fetchone()[0] * db.execute(\"PRAGMA page_size;\").fetchone()[0]\n",
"print(f\"Estimated in-memory SQLite DB size: {db_size / (1024)**2} mb\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3db5a04-27ce-42c1-8c7a-f538da4f24b3",
"metadata": {},
"outputs": [],
"source": [
"DISK_DB_PATH = \"temp_semantic_vec.db\"\n",
"\n",
"# Save the in-memory database to disk\n",
"disk_db = sqlite3.connect(DISK_DB_PATH)\n",
"db.backup(disk_db) # Copy in-memory DB to file\n",
"disk_db.close()\n",
"\n",
"# Get file size\n",
"db_size = os.path.getsize(DISK_DB_PATH)\n",
"print(f\"Size of SQLite database file: {db_size / (1024)**2} mb\")\n"
]
},
{
"cell_type": "markdown",
"id": "21bb0f44-3423-4734-a72c-09856ec1d277",
"metadata": {},
"source": [
"#### Validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b121abcd-8a52-4732-a5a1-27de52537d26",
"metadata": {},
"outputs": [],
"source": [
"golden_data = pd.read_csv(\"../data/chidam_golden_query.csv\", usecols=['search_query', 'url'])\n",
"print(len(golden_data))\n",
"golden_data.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33ba1159-5e05-40a7-b0d5-a1a16a45add9",
"metadata": {},
"outputs": [],
"source": [
"def validate(pred_fn):\n",
" eval_rows = []\n",
" print(f\"Validating approach `{pred_fn.__name__}`:\")\n",
" correct = 0\n",
" for idx, (query, actual) in golden_data.iterrows():\n",
" retrieved = pred_fn(query)['url'].values.tolist()\n",
" if actual in retrieved:\n",
" correct += 1\n",
" eval_row = run_traditional_eval(idx, query, [actual], retrieved, retrieved_distances=None, k=2)\n",
" eval_rows.append(eval_row)\n",
" # else:\n",
" # print(query, actual, retrieved)\n",
" print(f\"correct count = {correct}\")\n",
" print(f\"recall = {correct/len(golden_data)}\")\n",
" print(\"\\n\")\n",
" return pd.DataFrame(eval_rows)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8a991aa-1864-47e2-bdef-b1b368c66c59",
"metadata": {},
"outputs": [],
"source": [
"eval_df = validate(predict_with_bin_quantized)\n",
"eval_df[['precision@2', 'recall@2', 'ndcg@2', 'reciprocal_rank', 'average_precision']].mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d4ce022-a0c8-4bb2-833c-231305796785",
"metadata": {},
"outputs": [],
"source": [
"eval_df = validate(predict_coarse)\n",
"eval_df[['precision@2', 'recall@2', 'ndcg@2', 'reciprocal_rank', 'average_precision']].mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09a2f878-5d6a-4269-8833-0fa0768c2335",
"metadata": {},
"outputs": [],
"source": [
"# Validating approach `predict_coarse`:\n",
"# correct count = 14\n",
"# recall = 0.2857142857142857\n",
"\n",
"\n",
"# precision@2 0.142857\n",
"# recall@2 0.285714\n",
"# ndcg@2 0.263118\n",
"# reciprocal_rank 0.255102\n",
"# average_precision 0.183673\n",
"# dtype: float64\n",
"\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}