notebooks/bin_quantization_approach.ipynb (591 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d83680ff-b52f-4c4a-a565-62774533f944",
   "metadata": {},
   "source": [
    "# Explore binary quantization with sqlite-vec\n",
    "\n",
    "Compares two retrieval approaches over embedded browsing-history text:\n",
    "\n",
    "1. KNN search over binary-quantized vectors only.\n",
    "2. Coarse KNN search over binary-quantized vectors, re-scored with the full float vectors.\n",
    "\n",
    "Reference: https://alexgarcia.xyz/sqlite-vec/guides/binary-quant.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "563e4f5f-8c6b-4058-9c2c-00da1329b03c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "335c4f20-f649-4ea4-ae44-9e300718447b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "import sqlite_vec\n",
    "from typing import List\n",
    "import struct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "978fda31-9915-42c0-8bd5-ef3e8c95d8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the project root directory to the Python path\n",
    "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
    "if project_root not in sys.path:  # keep the cell idempotent when it is re-run\n",
    "    sys.path.append(project_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0d86492-e8be-4acf-9635-6e6e54bbb5d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.constants import EMBEDDING_MODELS_DICT\n",
    "from src.feature_extractor import FeatureExtractor\n",
    "from src.metrics import run_traditional_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd1327d6-fe08-4d9f-8030-9a3cbec4cac9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# version of the interpreter this kernel is running (a bare `python` on PATH may be a different one)\n",
    "sys.version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abe464ec-6331-4a58-9b95-95f547edb899",
   "metadata": {},
   "outputs": [],
   "source": [
    "# query pip through the kernel's own interpreter so the listed packages are the ones imported above\n",
    "!{sys.executable} -m pip freeze | grep sqlite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0651ef95-2ca5-49eb-89c0-4e49a4641328",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !export LDFLAGS=\"-L/opt/homebrew/opt/sqlite/lib\"\n",
    "# !export CPPFLAGS=\"-I/opt/homebrew/opt/sqlite/include\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9b15421-08df-40c2-8837-1ca5cc1acc2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "db = sqlite3.connect(\":memory:\")\n",
    "db.enable_load_extension(True)\n",
    "sqlite_vec.load(db)\n",
    "db.enable_load_extension(False)\n",
    "\n",
    "sqlite_version, vec_version = db.execute(\n",
    "    \"select sqlite_version(), vec_version()\"\n",
    ").fetchone()\n",
    "print(f\"sqlite_version={sqlite_version}, vec_version={vec_version}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d71dc0-f5d5-4c69-b11f-fff4dea8e57a",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = db.execute(\"\"\"select vec_quantize_binary(\n",
    "    '[-0.73, -0.80, 0.12, -0.73, 0.79, -0.11, 0.23, 0.97]'\n",
    ");\"\"\").fetchall()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb556110-3b8a-4d59-b042-555d0c6dc5ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e94280d9-c129-4dd5-a258-b9be9dda6a61",
   "metadata": {},
   "outputs": [],
   "source": [
    "type(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31809d3a-70c4-446d-baae-055d2d8dba1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "byte_value = res[0][0]\n",
    "# zero-pad to the full bit width so leading 0 bits are not dropped from the display\n",
    "binary_representation = format(int.from_bytes(byte_value, \"big\"), f\"0{8 * len(byte_value)}b\")\n",
    "print(f\"Binary Representation: 0b{binary_representation}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52f6613d-024e-4580-88f0-159665fae6fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "row_limit = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ee5def5-432a-4958-93bc-76b21297be83",
   "metadata": {},
   "outputs": [],
   "source": [
    "firefox_conn = sqlite3.connect(\"../data/places.sqlite\")\n",
    "try:\n",
    "    input_data = firefox_conn.execute(\"\"\"\n",
    "        WITH TOP_FRECENT_PLACES AS (\n",
    "            SELECT p.url, p.title, COALESCE(p.description, '') AS description, p.id AS place_id,\n",
    "                   p.frecency, p.origin_id, p.url_hash, p.last_visit_date\n",
    "            FROM moz_places p\n",
    "            WHERE p.title NOTNULL\n",
    "              AND url not like '%google.com/search?%'\n",
    "            ORDER BY frecency DESC\n",
    "            LIMIT :row_limit\n",
    "        )\n",
    "        SELECT * FROM TOP_FRECENT_PLACES;\n",
    "    \"\"\", {\"row_limit\": row_limit}).fetchall()\n",
    "finally:\n",
    "    # the rows are already in memory; do not keep places.sqlite open for the rest of the session\n",
    "    firefox_conn.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "576306ff-6db4-4d17-958b-dcd7800b364f",
   "metadata": {},
   "outputs": [],
   "source": [
    "history = pd.read_csv(\"../data/history_output_file.csv\")\n",
    "# history = pd.DataFrame(input_data,\n",
    "#     columns=['url', 'title', 'description', 'place_id', 'frecency', 'origin_id', 'url_hash', 'last_visit_date'])\n",
    "history['last_visit_date'] = pd.to_datetime(history['last_visit_date'], unit='us')\n",
    "\n",
    "# fill empty last_visit_date with default value \"1970-01-01\"\n",
    "history['last_visit_date'] = history['last_visit_date'].fillna(pd.to_datetime(\"1970-01-01\"))\n",
    "history['combined_text'] = history['title'].fillna('') + \" \" + history['description'].fillna('')\n",
    "\n",
    "# combined_text always contains the \" \" separator, so compare the stripped text;\n",
    "# comparing the raw value with '' would never drop a row\n",
    "has_text = history['combined_text'].str.strip() != ''\n",
    "history = history.loc[has_text].reset_index(drop=True).head(row_limit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "295d6da5-c257-4a6b-ae37-485472824d88",
   "metadata": {},
   "outputs": [],
   "source": [
    "history.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b14078a-45d0-4694-b2e0-e6a9a3013841",
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_MODELS_DICT['Xenova/all-MiniLM-L6-v2']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "770c11b4-6288-44ea-be69-d0adca51c4e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "EMBEDDING_MODELS_DICT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1cb2048-0f42-4cc1-9daa-8b71764ceb10",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name, embeddings_size = 'Xenova/all-MiniLM-L6-v2', 384\n",
    "# model_name, embeddings_size = 'nomic-ai/nomic-embed-text-v1.5', 768\n",
    "# model_name, embeddings_size = \"Xenova/all-mpnet-base-v2\", 768\n",
    "# model_name, embeddings_size = 'Xenova/paraphrase-mpnet-base-v2', 768\n",
    "# model_name, embeddings_size = 'Xenova/all-MiniLM-L12-v2', 384\n",
    "# model_name, embeddings_size = 'nomic-ai/modernbert-embed-base', 768\n",
    "fe = FeatureExtractor(EMBEDDING_MODELS_DICT, model_name=model_name)\n",
    "texts = history['combined_text'].values.tolist()\n",
    "embeddings = fe.get_embeddings(texts)\n",
    "\n",
    "# the vec0 tables below are declared with `embeddings_size` dimensions, so fail early on a mismatch\n",
    "assert embeddings.shape[-1] == embeddings_size, (\n",
    "    f\"embeddings_size={embeddings_size} does not match embeddings of shape {embeddings.shape}\"\n",
    ")\n",
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "995cbbdd-b4b0-410e-85c3-b0b87a71b889",
   "metadata": {},
   "outputs": [],
   "source": [
    "# used as a suffix in the vec0 table names, so replace the separators that appear in model ids\n",
    "model_name_normalized = model_name.replace(\"/\",\"_\").replace(\"-\",\"_\").replace(\".\",\"_\")\n",
    "\n",
    "\n",
    "def serialize_f32_from_np(vector: np.ndarray) -> bytes:\n",
    "    \"\"\"Return `vector` as raw float32 bytes (native byte order) for binding as a SQLite BLOB.\n",
    "\n",
    "    Produces the same bytes as struct.pack(f\"{len(vector)}f\", *vector) without\n",
    "    unpacking every element into a Python float first.\n",
    "    \"\"\"\n",
    "    return np.asarray(vector, dtype=np.float32).tobytes()\n",
    "\n",
    "\n",
    "items = list(enumerate(embeddings))\n",
    "\n",
    "for item in items[:5]:\n",
    "    print(type(item[1][0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac2fdb82-52c9-48e2-b0ba-445d6b569cca",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "7f2c5192-7a38-4848-a096-e1c76e228956",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "#### Approach 1: binary quantization only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "658fd9ad-074d-4d27-b7a6-2fe3f14eef79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop first so re-running this cell does not fail on an existing table or duplicate rowids\n",
    "db.execute(f\"DROP TABLE IF EXISTS vec_items_{model_name_normalized}_1\")\n",
    "db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_1 USING vec0(embedding bit[{embeddings_size}])\")\n",
    "\n",
    "with db:\n",
    "    db.executemany(\n",
    "        f\"INSERT INTO vec_items_{model_name_normalized}_1(rowid, embedding) VALUES (?, vec_quantize_binary(?))\",\n",
    "        # rowid is the positional index into `history`; the float32 bytes are quantized to bits in SQL\n",
    "        ((idx, serialize_f32_from_np(vec)) for idx, vec in enumerate(embeddings)),\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a282abe6-9030-4dfd-ba3d-ab3b18c6b068",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_with_bin_quantized(query, k=2):\n",
    "    \"\"\"Return the `k` rows of `history` nearest to `query` using only the bit-vector table.\n",
    "\n",
    "    The query text is embedded with `fe`, quantized in SQL with vec_quantize_binary and\n",
    "    matched against the approach-1 table. Rowids are positional indices into `history`.\n",
    "    \"\"\"\n",
    "    query_serialized_vec = serialize_f32_from_np(fe.get_embeddings([query])[0])\n",
    "\n",
    "    retrieved_results = db.execute(f\"\"\"\n",
    "        select\n",
    "            rowid,\n",
    "            distance\n",
    "        from vec_items_{model_name_normalized}_1\n",
    "        where embedding match vec_quantize_binary(:query_serialized_vec)\n",
    "          and k = :k\n",
    "        order by distance;\n",
    "    \"\"\", {\"query_serialized_vec\": query_serialized_vec, \"k\": k}).fetchall()\n",
    "\n",
    "    return history.iloc[[rowid for rowid, _distance in retrieved_results]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8175b13-2356-4c9f-a512-a7075cc3684e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit predict_with_bin_quantized(query=\"mail box\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfe61dc9-4ed8-4180-8b09-bd40c8aa9467",
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_with_bin_quantized(query=\"canada news\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8a038ce-b277-4ec9-9993-25aac0359f15",
   "metadata": {},
   "source": [
    "#### Approach 2: binary quantization for the coarse search, then re-scoring with the float vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "521547c7-44e7-4607-a139-9d312015352e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop first so re-running this cell does not fail on an existing table or duplicate rowids\n",
    "db.execute(f\"DROP TABLE IF EXISTS vec_items_{model_name_normalized}_2\")\n",
    "db.execute(f\"CREATE VIRTUAL TABLE vec_items_{model_name_normalized}_2 USING vec0(embedding float[{embeddings_size}], embedding_coarse bit[{embeddings_size}])\")\n",
    "\n",
    "with db:\n",
    "    db.executemany(\n",
    "        f\"INSERT INTO vec_items_{model_name_normalized}_2(rowid, embedding, embedding_coarse) VALUES (?, ?, vec_quantize_binary(?))\",\n",
    "        # the same float32 bytes are stored as-is in `embedding` and quantized to bits for `embedding_coarse`\n",
    "        ((idx, blob, blob) for idx, blob in enumerate(map(serialize_f32_from_np, embeddings))),\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d9c2d65-2040-4af8-928c-c4e04469974e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_coarse(query, k=2, coarse_k=200):\n",
    "    \"\"\"Return the `k` rows of `history` nearest to `query` with coarse search plus re-scoring.\n",
    "\n",
    "    The `coarse_k` nearest rows are first selected on the bit-vector column, then ordered by\n",
    "    cosine distance between their float vectors and the query embedding. Rowids are positional\n",
    "    indices into `history`.\n",
    "    \"\"\"\n",
    "    query_serialized_vec = serialize_f32_from_np(fe.get_embeddings([query])[0])\n",
    "\n",
    "    retrieved_results = db.execute(f\"\"\"\n",
    "        with coarse_matches as (\n",
    "            select\n",
    "                rowid,\n",
    "                embedding\n",
    "            from vec_items_{model_name_normalized}_2\n",
    "            where embedding_coarse match vec_quantize_binary(:query_serialized_vec)\n",
    "              and k = :coarse_k\n",
    "            order by distance\n",
    "        )\n",
    "        select\n",
    "            rowid,\n",
    "            vec_distance_cosine(embedding, :query_serialized_vec)\n",
    "        from coarse_matches\n",
    "        order by 2\n",
    "        limit :k;\n",
    "    \"\"\", {\"query_serialized_vec\": query_serialized_vec, \"k\": k, \"coarse_k\": coarse_k}).fetchall()\n",
    "\n",
    "    return history.iloc[[rowid for rowid, _distance in retrieved_results]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3dd33aaa-0766-478c-983f-6f890659c30f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%timeit predict_coarse(query=\"scheduler\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c22429d8-a559-49e8-80b7-195b685cd23b",
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_coarse(query=\"usa news\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab5aae20-10b1-4b02-8da4-6cec58f64960",
   "metadata": {},
   "outputs": [],
   "source": [
    "page_count = db.execute(\"PRAGMA page_count;\").fetchone()[0]\n",
    "page_size = db.execute(\"PRAGMA page_size;\").fetchone()[0]\n",
    "db_size = page_count * page_size\n",
    "print(f\"Estimated in-memory SQLite DB size: {db_size / 1024**2:.2f} MB\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3db5a04-27ce-42c1-8c7a-f538da4f24b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "DISK_DB_PATH = \"temp_semantic_vec.db\"\n",
    "\n",
    "# Save the in-memory database to disk\n",
    "disk_db = sqlite3.connect(DISK_DB_PATH)\n",
    "try:\n",
    "    db.backup(disk_db)  # Copy in-memory DB to file\n",
    "finally:\n",
    "    disk_db.close()  # close even if the backup fails so the file handle is not leaked\n",
    "\n",
    "# Get file size\n",
    "disk_db_size = os.path.getsize(DISK_DB_PATH)\n",
    "print(f\"Size of SQLite database file: {disk_db_size / 1024**2:.2f} MB\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21bb0f44-3423-4734-a72c-09856ec1d277",
   "metadata": {},
   "source": [
    "#### Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b121abcd-8a52-4732-a5a1-27de52537d26",
   "metadata": {},
   "outputs": [],
   "source": [
    "golden_data = pd.read_csv(\"../data/chidam_golden_query.csv\", usecols=['search_query', 'url'])\n",
    "print(len(golden_data))\n",
    "golden_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33ba1159-5e05-40a7-b0d5-a1a16a45add9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate(pred_fn):\n",
    "    \"\"\"Run `pred_fn` on every golden query and collect the per-query evaluation rows.\n",
    "\n",
    "    `pred_fn(query)` must return a DataFrame with a `url` column. Prints the number of\n",
    "    queries whose expected url was retrieved and the resulting recall, and returns a\n",
    "    DataFrame with one `run_traditional_eval` result per golden query.\n",
    "    \"\"\"\n",
    "    eval_rows = []\n",
    "    print(f\"Validating approach `{pred_fn.__name__}`:\")\n",
    "    correct = 0\n",
    "    # select the columns by name: read_csv(usecols=...) keeps the file's column order,\n",
    "    # so positional unpacking could silently swap query and url\n",
    "    golden_pairs = golden_data[['search_query', 'url']].itertuples(index=True, name=None)\n",
    "    for idx, query, actual in golden_pairs:\n",
    "        retrieved = pred_fn(query)['url'].values.tolist()\n",
    "        if actual in retrieved:\n",
    "            correct += 1\n",
    "        # evaluated for hits and misses alike, so the averaged metrics cover every query\n",
    "        eval_rows.append(\n",
    "            run_traditional_eval(idx, query, [actual], retrieved, retrieved_distances=None, k=2)\n",
    "        )\n",
    "    print(f\"correct count = {correct}\")\n",
    "    # an empty golden set has no defined recall; report nan instead of dividing by zero\n",
    "    recall = correct / len(golden_data) if len(golden_data) else float(\"nan\")\n",
    "    print(f\"recall = {recall}\")\n",
    "    print(\"\\n\")\n",
    "    return pd.DataFrame(eval_rows)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8a991aa-1864-47e2-bdef-b1b368c66c59",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_df_bin = validate(predict_with_bin_quantized)\n",
    "eval_df_bin[['precision@2', 'recall@2', 'ndcg@2', 'reciprocal_rank', 'average_precision']].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d4ce022-a0c8-4bb2-833c-231305796785",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_df_coarse = validate(predict_coarse)\n",
    "eval_df_coarse[['precision@2', 'recall@2', 'ndcg@2', 'reciprocal_rank', 'average_precision']].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09a2f878-5d6a-4269-8833-0fa0768c2335",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validating approach `predict_coarse`:\n",
    "# correct count = 14\n",
    "# recall = 0.2857142857142857\n",
    "\n",
    "\n",
    "# precision@2 0.142857\n",
    "# recall@2 0.285714\n",
    "# ndcg@2 0.263118\n",
    "# reciprocal_rank 0.255102\n",
    "# average_precision 0.183673\n",
    "# dtype: float64\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}