notebooks/kg_exploration.ipynb (1,142 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "kg-exploration-title",
   "metadata": {},
   "source": [
    "# Knowledge-graph retrieval over Firefox history\n",
    "\n",
    "Builds a small knowledge graph from a Firefox profile (`places.sqlite`) that links pages (by `url_hash`) to the keywords typed for them (`moz_inputhistory`), their host, URL path segments, spaCy tags and GLiNER topics. Retrieval over that graph is then evaluated against `moz_inputhistory` and a golden query set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "80d8736e-c8d1-4165-a772-607f1f2812f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2ff7facf-49e4-4074-9e06-08d0a9915fde",
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "from tqdm import tqdm\n",
    "\n",
    "nlp = spacy.load(\"en_core_web_sm\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "879f1b19-4cc9-4604-8049-0aa7f96b1f6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the project root directory to the Python path (only once, so re-running this cell is harmless)\n",
    "import os\n",
    "import sys\n",
    "\n",
    "project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
    "if project_root not in sys.path:\n",
    "    sys.path.append(project_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f04980e3-18c9-4503-99a6-4463e7e8141d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.metrics import run_traditional_eval"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4070835-b2ed-4304-ba80-115d9a359626",
   "metadata": {},
   "source": [
    "#### Fetch the top frecent pages\n",
    "\n",
    "Pull the `row_limit` titled pages with the highest `frecency` (Google search-result pages excluded) plus every page that has `moz_inputhistory` entries. Each page gets its typed keywords (`keyword_data`, a JSON array) and its host with `domain_frecency`, the host's share of the total frecency over `moz_origins`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d8257221-4fc8-4c09-ba35-d7b421d65b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "row_limit = 10000  # number of top-frecency pages pulled from moz_places\n",
    "GENERATE_TOPIC = False  # True: re-run GLiNER topic tagging and rewrite the parquet cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af8ae3e-bc9b-4515-be5d-259e28d7fc75",
   "metadata": {},
   "outputs": [],
   "source": [
    "firefox_conn = sqlite3.connect(\"../data/places.sqlite\")\n",
    "firefox_cursor = firefox_conn.cursor()\n",
    "\n",
    "# row_limit is bound as a SQL parameter instead of being formatted into the query text.\n",
    "# Keyword order inside keyword_data is not sorted.\n",
    "input_data = firefox_cursor.execute(\"\"\"\n",
    "WITH TOP_FRECENT_PLACES AS\n",
    "(SELECT p.url, p.title, p.description, p.id AS place_id, p.frecency, p.origin_id, p.url_hash\n",
    "FROM moz_places p\n",
    "WHERE p.title NOTNULL\n",
    "AND url not like '%google.com/search?%'\n",
    "ORDER BY frecency DESC\n",
    "LIMIT ?\n",
    ")\n",
    "\n",
    ", TOP_PLACES_INFO AS\n",
    "(select * from TOP_FRECENT_PLACES\n",
    "UNION\n",
    "\n",
    "SELECT p.url, p.title, p.description, p.id AS place_id, p.frecency, p.origin_id, p.url_hash\n",
    "FROM moz_places p\n",
    "WHERE p.id in (select distinct(place_id) from moz_inputhistory)\n",
    ")\n",
    ", KEYWORDS_INFO AS\n",
    "(SELECT\n",
    "  ih.place_id,\n",
    "  json_group_array(\n",
    "    json_object(\n",
    "      'keyword', ih.input,\n",
    "      'use_count', ih.use_count\n",
    "    )\n",
    "  ) AS keyword_data\n",
    "FROM\n",
    "  moz_inputhistory ih\n",
    "WHERE ih.input != ''\n",
    "GROUP BY\n",
    "  ih.place_id\n",
    ")\n",
    "\n",
    ", DOMAIN_INFO AS\n",
    "(SELECT\n",
    "  id AS origin_id,\n",
    "  host,\n",
    "  CAST(frecency AS REAL) / (SELECT SUM(frecency) * 1.0 FROM moz_origins WHERE frecency IS NOT NULL) AS domain_frecency\n",
    "FROM\n",
    "  moz_origins\n",
    "WHERE\n",
    "  frecency IS NOT NULL\n",
    ")\n",
    "\n",
    "SELECT p.*, kw.keyword_data, d.host, d.domain_frecency\n",
    "FROM TOP_PLACES_INFO p\n",
    "LEFT JOIN KEYWORDS_INFO kw\n",
    "  ON p.place_id = kw.place_id\n",
    "LEFT JOIN DOMAIN_INFO d\n",
    "  ON p.origin_id = d.origin_id\n",
    "ORDER BY p.frecency DESC\n",
    "\"\"\", (row_limit,)).fetchall()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6e5a21-bb6e-4b17-b469-ea853f25deb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "input_data_df = pd.DataFrame(\n",
    "    input_data,\n",
    "    columns=['url', 'title', 'description', 'place_id', 'frecency', 'origin_id', 'url_hash', 'keyword_data', 'host', 'domain_frecency'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "818a07d8-a6f1-422c-9496-51e28e826bac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.parse import urlsplit\n",
    "\n",
    "\n",
    "def extract_additional_path_info(row):\n",
    "    \"\"\"Split a page URL into path segments, with scheme/host and .html/.htm removed.\n",
    "\n",
    "    `row` must provide `url` and `host`. `host` is missing (None/NaN) when the\n",
    "    page's origin has no DOMAIN_INFO row (LEFT JOIN); the host is then parsed\n",
    "    from the URL so the scheme and host do not end up among the segments.\n",
    "    \"\"\"\n",
    "    url = row['url']\n",
    "    host = row['host']\n",
    "    if not isinstance(host, str) or not host:\n",
    "        host = urlsplit(url).netloc\n",
    "    path = url.replace(f\"https://{host}\", \"\").replace(f\"http://{host}\", \"\")\n",
    "    path = path.strip(\"/\")\n",
    "    path = path.replace(\".html\", \"\").replace(\".htm\", \"\")\n",
    "    return path.split(\"/\")\n",
    "\n",
    "\n",
    "def extract_tags_batch(df):\n",
    "    \"\"\"Return, per row of `df`, the distinct lower-cased ADJ/PROPN/NOUN non-stop-word tokens of title + description.\"\"\"\n",
    "    # Combine title and description into a single text column\n",
    "    texts = (df['title'].fillna('') + \" \" + df['description'].fillna('')).str.strip()\n",
    "\n",
    "    # Batch through spaCy; NER is disabled because only POS tags are used\n",
    "    docs = nlp.pipe(texts, disable=[\"ner\"])\n",
    "\n",
    "    tags_list = []\n",
    "    for doc in docs:\n",
    "        tags = {\n",
    "            token.text.strip().lower()\n",
    "            for token in doc\n",
    "            if token.pos_ in [\"ADJ\", \"PROPN\", \"NOUN\"] and not token.is_stop\n",
    "        }\n",
    "        tags_list.append(list(tags))\n",
    "    return tags_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fba770c4-f70f-4b62-a5bd-60ba797d279d",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data_df['path_info'] = input_data_df.apply(extract_additional_path_info, axis=1)\n",
    "input_data_df['tags'] = extract_tags_batch(input_data_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2873a369-137b-46f9-b6f5-6290cffcd505",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "input_data_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4486e17a-5f30-49df-8032-d9c3fe1b2555",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data_df['domain_frecency'].describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06600fb3-428d-43d9-a4fc-50799b8e441e",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data_df['tags'].values[:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cf74cb7-13b4-4e12-b63b-cc8f6ce56ede",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data_df.sample(20, random_state=42).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c99325-5928-4df4-8700-245f5477f44e",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_data_df.sample(20, random_state=42)['title'].values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac60d084-d58a-4a37-87c8-8debdd881c32",
   "metadata": {},
"source": [ "#### Extract the topics" ] }, { "cell_type": "code", "execution_count": null, "id": "3d914222-b55d-4456-9616-ddab30ec28c7", "metadata": {}, "outputs": [], "source": [ "from gliner import GLiNER\n", "\n", "gliner_model = GLiNER.from_pretrained(\"urchade/gliner_largev2\")\n", "\n", "labels = [ \"Arts & Entertainment\",\n", " \"Business and Consumer Services\",\n", " \"Community and Society\",\n", " \"Computers Electronics and Technology\",\n", " \"Ecommerce & Shopping\",\n", " \"Finance\",\n", " \"Food and Drink\",\n", " \"Gambling\",\n", " \"Games\",\n", " \"Health\",\n", " \"Heavy Industry and Engineering\",\n", " \"Hobbies and Leisure\",\n", " \"Home and Garden\",\n", " \"Jobs and Career\",\n", " \"Law and Government\",\n", " \"Lifestyle\",\n", " \"News & Media Publishers\",\n", " \"Pets and Animals\",\n", " \"Reference Materials\",\n", " \"Science and Education\",\n", " \"Sports\",\n", " \"Travel and Tourism\",\n", " \"Vehicles\",\n", " \"Adult\"\n", " ]" ] }, { "cell_type": "code", "execution_count": null, "id": "95893b70-18f9-427f-b453-f1357811086a", "metadata": {}, "outputs": [], "source": [ "texts = (input_data_df['title'].fillna('') + \" \" + input_data_df['description'].fillna('')).values.tolist()" ] }, { "cell_type": "code", "execution_count": null, "id": "a2321b3d-3c68-437b-b42b-78c58a9f7503", "metadata": {}, "outputs": [], "source": [ "len(texts)" ] }, { "cell_type": "code", "execution_count": null, "id": "0817bfb8-45f1-4426-be32-d65adce44a69", "metadata": {}, "outputs": [], "source": [ "texts[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "6f75dc90-9dfb-4ce2-a935-01d9aa00f696", "metadata": {}, "outputs": [], "source": [ "## Very first time set this to True and then switch to False and read from saved file\n", "# GENERATE_TOPIC = False\n", "\n", "if GENERATE_TOPIC:\n", " topics = []\n", " for text in tqdm(texts):\n", " entities = gliner_model.predict_entities(text, labels, threshold=0.3)\n", " themes = 
list({entity[\"label\"] for entity in entities})\n", " topics.append(themes)\n", " input_data_df['topics'] = topics\n", " input_data_df.to_parquet(\"../data/input_data_df.parquet\", index=False)\n", "else:\n", " input_data_df_bkp = pd.read_parquet(\"../data/input_data_df.parquet\")\n", " topics_lkp = input_data_df_bkp.set_index('url_hash')['topics'].to_dict()\n", " input_data_df['topics'] = input_data_df['url_hash'].map(topics_lkp)" ] }, { "cell_type": "code", "execution_count": null, "id": "083bd3da-180a-4062-80f2-7c8974d5fb4a", "metadata": {}, "outputs": [], "source": [ "len(input_data_df)" ] }, { "cell_type": "code", "execution_count": null, "id": "4016c11a-673c-40e6-a304-8704db9cd363", "metadata": {}, "outputs": [], "source": [ "input_data_df" ] }, { "cell_type": "code", "execution_count": null, "id": "1d495bd4-e90f-4cdd-a4dc-21f396adef6a", "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "\n", "tags_counter = Counter()\n", "tags_counter.update([tag for tags in input_data_df['tags'].values.tolist() for tag in tags if tag.isalnum()])" ] }, { "cell_type": "code", "execution_count": null, "id": "5df477b2-3168-488c-bbed-62e12e7a88ab", "metadata": {}, "outputs": [], "source": [ "len(tags_counter)" ] }, { "cell_type": "code", "execution_count": null, "id": "33c508b5-086c-4302-976f-4d47c916fe2e", "metadata": { "scrolled": true }, "outputs": [], "source": [ "tags_counter.most_common(10)" ] }, { "cell_type": "code", "execution_count": null, "id": "e73e61d1-5150-4d9f-b3ac-0be08a01adb4", "metadata": {}, "outputs": [], "source": [ "path_info_counter = Counter()\n", "path_info_counter.update(\n", " [path_i for path_info in input_data_df['path_info'].values.tolist() for path_i in path_info if len(path_i) > 2 and path_i.isalpha()]\n", ")\n", "print(len(path_info_counter))\n", "path_info_counter.most_common(30)" ] }, { "cell_type": "code", "execution_count": null, "id": "822772b7-e4d1-4e07-bbab-d18fcd1f1c86", "metadata": {}, "outputs": [], 
"source": [ "import json\n", "\n", "def extract_keywords_adhoc(json_str):\n", " try:\n", " # Parse the string as JSON\n", " data = json.loads(json_str)\n", " # Extract the \"keyword\" field from each dictionary\n", " return [item[\"keyword\"] for item in data]\n", " except (json.JSONDecodeError, TypeError):\n", " # Handle invalid JSON or None\n", " return []\n", "\n", "keywords_list = input_data_df['keyword_data'].apply(extract_keywords_adhoc).values.tolist()\n", "kws_counter = Counter()\n", "kws_counter.update([kw for kws in keywords_list for kw in kws])" ] }, { "cell_type": "code", "execution_count": null, "id": "4351367b-5f3e-48f2-9d18-de02c296c413", "metadata": {}, "outputs": [], "source": [ "len(kws_counter)" ] }, { "cell_type": "code", "execution_count": null, "id": "3a38ccf7-2e41-475c-8529-7f622d68f37f", "metadata": {}, "outputs": [], "source": [ "input_data_df['keyword_data'][(~input_data_df['keyword_data'].isna())]" ] }, { "cell_type": "code", "execution_count": null, "id": "d680ab4d-ab9b-4f78-a067-414ceb48f433", "metadata": {}, "outputs": [], "source": [ "def generate_entity_rltn_score(src_entity, src_entity_type, relation, tgt_entity, score):\n", " return (src_entity, src_entity_type, relation, tgt_entity, score)\n", "\n", "def extract_keyword_entities_rltn_score(df, entity_name, entity_type, relation, tgt_entity_name, score_col=None):\n", " sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n", " for ers_info, tgt_val in zip(sel_df[entity_name].apply(json.loads), sel_df[tgt_entity_name]):\n", " for ers in ers_info:\n", " for key, val in ers.items():\n", " # print(key, val, tgt_val)\n", " if key == entity_type:\n", " src_entity = val\n", " if score_col and key == score_col:\n", " score = 1+val\n", " else:\n", " score = None\n", " yield generate_entity_rltn_score(src_entity, entity_type, relation, tgt_val, score)\n", " \n", " \n", "def extract_domain_entities_rltn_score(df, entity_name, relation, 
tgt_entity_name, score_col=None):\n", " sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name, score_col]].reset_index(drop=True)\n", " for idx, row in sel_df.iterrows():\n", " yield generate_entity_rltn_score(row[entity_name], entity_name, relation, row[tgt_entity_name], row[score_col])\n", "\n", "def extract_path_info_entities_rltn_score(df, entity_name, relation, tgt_entity_name, score_col=None):\n", " sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n", " for idx, row in sel_df.iterrows():\n", " for entity_val in row[entity_name]:\n", " if len(entity_val) > 2 and entity_val.isalpha():\n", " yield generate_entity_rltn_score(entity_val, entity_name, relation, row[tgt_entity_name], score_col) \n", "\n", "def extract_tags_entities_rltn_score(df, entity_name, relation, tgt_entity_name, score_col=None):\n", " sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n", " for idx, row in sel_df.iterrows():\n", " for entity_val in row[entity_name]:\n", " if len(entity_val) > 2 and entity_val.isalnum():\n", " yield generate_entity_rltn_score(entity_val, 'tag', relation, row[tgt_entity_name], score_col) \n", "\n", "def extract_topics_entities_rltn_score(df, entity_name, relation, tgt_entity_name, score_col=None):\n", " sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n", " for idx, row in sel_df.iterrows():\n", " for entity_val in row[entity_name]:\n", " if len(entity_val) > 1:\n", " yield generate_entity_rltn_score(entity_val, 'topic', relation, row[tgt_entity_name], score_col) " ] }, { "cell_type": "code", "execution_count": null, "id": "ce114ec5-3359-46bc-8e5f-959f4a536aa3", "metadata": {}, "outputs": [], "source": [ "# print(next(generate_entity_rltn_score('cloud', 'keyword', 'refers_to', 'place_id1', 0.391895954969)))\n", "# print(next(extract_entities_rltn_score(input_data_df, 'keyword_data', 'keyword', 'refers_to', 
'place_id', 'use_count')))\n", "keyword_ers = [ers for ers in (extract_keyword_entities_rltn_score(input_data_df, 'keyword_data', 'keyword', 'refers_to', 'url_hash', 'use_count'))]\n", "print(len(keyword_ers))\n", "keyword_ers[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "fde16d52-a2b7-48c0-a688-b1d457d758e0", "metadata": {}, "outputs": [], "source": [ "domain_ers = [ers for ers in extract_domain_entities_rltn_score(input_data_df, 'host', 'contains', 'url_hash', 'domain_frecency')]\n", "print(len(domain_ers))\n", "domain_ers[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "bfbb48d3-4fed-4e8b-b43d-734a610e832a", "metadata": {}, "outputs": [], "source": [ "path_info_ers = [ers for ers in extract_path_info_entities_rltn_score(input_data_df, 'path_info', 'parses_to', 'url_hash', 1.0)]\n", "print(len(path_info_ers))\n", "path_info_ers[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "4c70fedd-1250-42e0-be9d-147fc3fcd253", "metadata": {}, "outputs": [], "source": [ "tags_ers = [ers for ers in extract_tags_entities_rltn_score(input_data_df, 'tags', 'tagged_has', 'url_hash', 1.0)]\n", "print(len(tags_ers))\n", "tags_ers[:5]" ] }, { "cell_type": "code", "execution_count": null, "id": "53f90fed-3730-4c1a-a3b4-409b2cf35894", "metadata": {}, "outputs": [], "source": [ "topics_ers = [ers for ers in extract_topics_entities_rltn_score(input_data_df, 'topics', 'belongs_to', 'url_hash', 1.0)]\n", "print(len(topics_ers))\n", "topics_ers[:5]" ] }, { "cell_type": "markdown", "id": "9e59e845-2bce-45b6-97a6-247f136dcbfa", "metadata": {}, "source": [ "#### Combining all entities and relationships" ] }, { "cell_type": "code", "execution_count": null, "id": "af778f6a-91fd-4e86-9460-e97787ec8fc5", "metadata": {}, "outputs": [], "source": [ "len(keyword_ers) + len(domain_ers ) + len(path_info_ers) + len(tags_ers) + len(topics_ers)" ] }, { "cell_type": "code", "execution_count": null, "id": "7d8cf0c5-f176-4b22-ae0d-121164b872a8", "metadata": 
{},
   "outputs": [],
   "source": [
    "ers_df = pd.DataFrame(\n",
    "    keyword_ers + domain_ers + path_info_ers + tags_ers + topics_ers,\n",
    "    columns=['entity', 'entity_type', 'relation', 'url_hash', 'score'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f55106-d807-4816-bb80-d9166f960051",
   "metadata": {},
   "outputs": [],
   "source": [
    "ers_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcba8ec6-5ba9-48a0-8c7e-f729997aade9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create (or reuse) the SQLite database holding the knowledge graph\n",
    "db_path = \"../data/ml_kg.db\"\n",
    "conn = sqlite3.connect(db_path)\n",
    "\n",
    "conn.execute(\"\"\"\n",
    "CREATE TABLE IF NOT EXISTS ml_kg_info (\n",
    "    entity TEXT NOT NULL,\n",
    "    entity_type TEXT NOT NULL,\n",
    "    relation TEXT NOT NULL,\n",
    "    url_hash INTEGER NOT NULL,\n",
    "    score REAL NOT NULL\n",
    ");\n",
    "\"\"\")\n",
    "\n",
    "# Clear edges from earlier runs first: appending on every run would duplicate them\n",
    "# and inflate the SUM(score) ranking used by fetch_entity_relations_with_keywords.\n",
    "with conn:\n",
    "    conn.execute(\"DELETE FROM ml_kg_info\")\n",
    "ers_df.to_sql(\"ml_kg_info\", conn, if_exists=\"append\", index=False)\n",
    "\n",
    "# Attach places.sqlite so KG edges can be joined to moz_places (url, title)\n",
    "places_db_path = \"../data/places.sqlite\"\n",
    "conn.execute(\"ATTACH DATABASE ? AS places_db\", (places_db_path,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ace6fe6-49fe-4cf5-8b21-d11507f7372a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_tags_for_queries(queries):\n",
    "    \"\"\"Return, per query string, its distinct lower-cased ADJ/PROPN/NOUN non-stop-word tokens.\"\"\"\n",
    "    tags_list = []\n",
    "    for doc in nlp.pipe(queries, disable=[\"ner\"]):\n",
    "        tags = {\n",
    "            token.text.strip().lower()\n",
    "            for token in doc\n",
    "            if token.pos_ in [\"ADJ\", \"PROPN\", \"NOUN\"] and not token.is_stop\n",
    "        }\n",
    "        tags_list.append(list(tags))\n",
    "    return tags_list\n",
    "\n",
    "\n",
    "def infer_topics(queries, pbar=True, threshold=0.3):\n",
    "    \"\"\"Return, per query, the distinct GLiNER `labels` predicted at `threshold`.\n",
    "\n",
    "    `pbar` toggles a tqdm progress bar over the queries.\n",
    "    \"\"\"\n",
    "    iterable = tqdm(queries) if pbar else queries\n",
    "    topics = []\n",
    "    for query in iterable:\n",
    "        entities = gliner_model.predict_entities(query, labels, threshold=threshold)\n",
    "        topics.append(list({entity[\"label\"] for entity in entities}))\n",
    "    return topics\n",
    "\n",
    "\n",
    "def fetch_entity_relations_with_keywords(conn, search_keyword, search_tags, search_topics):\n",
    "    \"\"\"Rank pages by the summed score of knowledge-graph edges matching a query.\n",
    "\n",
    "    An ml_kg_info edge matches when it is a 'keyword' entity equal to\n",
    "    `search_keyword`, a 'tag' entity in `search_tags`, or a 'topic' entity in\n",
    "    `search_topics`. Matches are joined to places_db.moz_places (attached on\n",
    "    `conn`) by url_hash. Returns a DataFrame with columns url_hash, url, title,\n",
    "    total_score, sorted by total_score descending.\n",
    "\n",
    "    Raises ValueError if `search_keyword` is empty.\n",
    "    \"\"\"\n",
    "    if not search_keyword:\n",
    "        raise ValueError(\"search_keywords list cannot be empty.\")\n",
    "\n",
    "    # Every value is bound as a parameter instead of being pasted between quotes,\n",
    "    # so inputs containing an apostrophe (e.g. \"don't\") no longer break the SQL.\n",
    "    # SQLite accepts an empty `IN ()` list, which matches nothing.\n",
    "    tag_placeholders = \", \".join([\"?\"] * len(search_tags))\n",
    "    topic_placeholders = \", \".join([\"?\"] * len(search_topics))\n",
    "    params = [search_keyword, *search_tags, *search_topics]\n",
    "\n",
    "    query = f\"\"\"\n",
    "    WITH entity_relations_info AS (\n",
    "        SELECT\n",
    "            m.entity,\n",
    "            m.entity_type,\n",
    "            m.relation,\n",
    "            m.url_hash,\n",
    "            m.score,\n",
    "            p.url,\n",
    "            p.title,\n",
    "            p.frecency\n",
    "        FROM\n",
    "            ml_kg_info m\n",
    "        JOIN\n",
    "            places_db.moz_places p\n",
    "        ON\n",
    "            m.url_hash = p.url_hash\n",
    "        WHERE\n",
    "            (m.entity = ? AND\n",
    "             m.entity_type = 'keyword') OR\n",
    "            (m.entity IN ({tag_placeholders}) AND\n",
    "             m.entity_type = 'tag') OR\n",
    "            (m.entity IN ({topic_placeholders}) AND\n",
    "             m.entity_type = 'topic')\n",
    "    )\n",
    "\n",
    "    SELECT\n",
    "        url_hash,\n",
    "        url,\n",
    "        title,\n",
    "        SUM(score) AS total_score\n",
    "    FROM\n",
    "        entity_relations_info\n",
    "    GROUP BY\n",
    "        url_hash, url, title\n",
    "    ORDER BY\n",
    "        total_score DESC;\n",
    "    \"\"\"\n",
    "\n",
    "    return pd.read_sql_query(query, conn, params=params)\n",
    "\n",
    "\n",
    "search_query = \"healthy food and education\"\n",
    "search_tags = extract_tags_for_queries([search_query])[0]\n",
    "search_topics = infer_topics([search_query])[0]\n",
    "\n",
    "results = fetch_entity_relations_with_keywords(conn, search_query, search_tags, search_topics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bcaace9-7008-4864-a9ec-1cb5b2fe9122",
   "metadata": {},
   "outputs": [],
   "source": [
    "results.head(10).T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e4aacd0-84b1-4c51-91f5-28113431eb8a",
   "metadata": {},
   "source": [
    "#### Validation\n",
    "\n",
    "Ground truth: for every non-empty `input` in `moz_inputhistory`, the pages (`url_hash`) it is recorded against. Each keyword is run through the retrieval above, and the top-2 pages are scored with `run_traditional_eval`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23cd571-8ecf-4efc-a335-abb338f7188e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fetch_ground_truths(conn=None):\n",
    "    \"\"\"Return (keyword, url_hash, use_count, url) rows for every non-empty moz_inputhistory input.\n",
    "\n",
    "    Reads from `conn`, defaulting to the global `firefox_conn` (places.sqlite).\n",
    "    Rows are ordered by keyword, then use_count descending.\n",
    "    \"\"\"\n",
    "    conn = firefox_conn if conn is None else conn\n",
    "    return conn.execute(\n",
    "        \"\"\"\n",
    "        SELECT ih.input AS keyword,\n",
    "               p.url_hash,\n",
    "               ih.use_count,\n",
    "               p.url\n",
    "        FROM moz_inputhistory ih\n",
    "        JOIN moz_places p\n",
    "          ON ih.place_id = p.id\n",
    "        WHERE input != ''\n",
    "        ORDER BY keyword, use_count DESC\n",
    "        \"\"\"\n",
    "    ).fetchall()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a97ce8a3-30bd-4c11-a301-96a573f1fcc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "val_data = fetch_ground_truths()\n",
    "actuals_df = pd.DataFrame(val_data, columns=['keyword', 'url_hash', 'use_count', 'url'])\n",
    "to_be_predicted_queries = actuals_df.groupby('keyword')['url_hash'].agg(list).reset_index()\n",
    "print(len(to_be_predicted_queries))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2daaaa4-ded5-414a-86b3-9d7ce1049430",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Metric columns reported for the default k=2\n",
    "METRIC_COLS = ['precision@2', 'recall@2', 'ndcg@2', 'reciprocal_rank', 'average_precision']\n",
    "\n",
    "\n",
    "def perform_traditional_evals(to_be_predicted_queries, use_tags=True, use_topics=True, k=2):\n",
    "    \"\"\"Evaluate KG retrieval for each row of `to_be_predicted_queries`.\n",
    "\n",
    "    Each row needs `keyword` (the query) and `url_hash` (list of relevant\n",
    "    url_hashes). The keyword, plus optionally its spaCy tags and GLiNER topics,\n",
    "    is sent to fetch_entity_relations_with_keywords on the global `conn`; the\n",
    "    top-`k` url_hashes are scored with run_traditional_eval.\n",
    "    Returns one row of metrics per query.\n",
    "    \"\"\"\n",
    "    eval_rows = []\n",
    "    for idx, row in tqdm(to_be_predicted_queries.iterrows(), total=len(to_be_predicted_queries), desc=\"queries\"):\n",
    "        search_keyword = row['keyword']\n",
    "        relevant_docs = row['url_hash']\n",
    "\n",
    "        search_tags = extract_tags_for_queries([search_keyword])[0] if use_tags else []\n",
    "        search_topics = infer_topics([search_keyword], pbar=False)[0] if use_topics else []\n",
    "        results = fetch_entity_relations_with_keywords(conn, search_keyword, search_tags, search_topics).head(k)\n",
    "        retrieved_docs = results['url_hash'].tolist()\n",
    "        eval_rows.append(\n",
    "            run_traditional_eval(idx, search_keyword, relevant_docs, retrieved_docs, retrieved_distances=None, k=k)\n",
    "        )\n",
    "    return pd.DataFrame(eval_rows)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2039503-1086-469f-b43b-99e036c09da5",
   "metadata": {},
   "source": [
    "#### Use keywords + tags + topics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f42d10c-4922-47f2-bbda-f33ad6451df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_df = perform_traditional_evals(to_be_predicted_queries)\n",
    "eval_df[METRIC_COLS].mean()\n",
    "\n",
    "# Recorded from an earlier run (keywords + tags + topics):\n",
    "# precision@2          0.539931\n",
    "# recall@2             0.964699\n",
    "# ndcg@2               0.590278\n",
    "# reciprocal_rank      0.968750\n",
    "# average_precision    0.470486"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c28c25da-8f48-49d8-990c-14b4971b718e",
   "metadata": {},
   "source": [
    "#### Use only keywords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7eac15-88ab-46f6-a103-a1f2efd3ed3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_df = perform_traditional_evals(to_be_predicted_queries, use_tags=False, use_topics=False)\n",
    "eval_df[METRIC_COLS].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "691d0db3-f830-4bca-afdd-65a94ae14289",
   "metadata": {},
   "source": [
    "#### Use keywords + tags and no topics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e2cf0bd-c4b6-4deb-a452-5a0ec106b375",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_df = perform_traditional_evals(to_be_predicted_queries, use_tags=True, use_topics=False)\n",
    "eval_df[METRIC_COLS].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dedd0007-e773-4ce9-93cd-f6857dfa35dd",
   "metadata": {},
   "source": [
    "#### Use keywords + topics and no tags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bbaf7b5-bb16-4199-87b1-5518f8a7dee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_df = perform_traditional_evals(to_be_predicted_queries, use_tags=False, use_topics=True)\n",
    "eval_df[METRIC_COLS].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66a4296-4230-4a85-9ff8-bb253039438d",
   "metadata": {},
   "outputs": [],
   "source": [
    "to_be_predicted_queries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "golden-queries-header",
   "metadata": {},
   "source": [
    "#### Golden query set\n",
    "\n",
    "`search_query` → `url` pairs loaded from `../data/chidam_golden_query.csv`, mapped to `url_hash` through `moz_places` and evaluated with the same retrieval configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ad2bca1-f250-43a9-9c0b-91ce161e1ab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "golden_queries = pd.read_csv(\"../data/chidam_golden_query.csv\", usecols=['search_query', 'url'])\n",
    "print(len(golden_queries))\n",
    "golden_queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6240272-3660-48a4-8c1d-56ef7920e239",
   "metadata": {},
   "outputs": [],
"source": [ "set(golden_queries['search_query'].tolist()).intersection(set(to_be_predicted_queries['keyword'].values.tolist()))" ] }, { "cell_type": "code", "execution_count": null, "id": "535ec425-4768-4819-97c6-f08231ec49ee", "metadata": {}, "outputs": [], "source": [ "set(golden_queries['search_query'].tolist())" ] }, { "cell_type": "code", "execution_count": null, "id": "a08ad79b-6646-4d7b-b0a7-986de8042bbf", "metadata": {}, "outputs": [], "source": [ "set(to_be_predicted_queries['keyword'].values.tolist())" ] }, { "cell_type": "code", "execution_count": null, "id": "bf10a452-28d5-4182-a28a-401f0453b201", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "9626d69d-d2e0-4d18-82bb-97cad56793db", "metadata": {}, "outputs": [], "source": [ "def get_url_hash_batch(golden_queries, firefox_conn):\n", " # Convert URLs into a tuple for the IN clause\n", " urls = tuple(golden_queries['url'].tolist())\n", "\n", " # Query all URL hashes in one go\n", " query = f\"\"\"\n", " SELECT url, url_hash\n", " FROM moz_places\n", " WHERE url IN ({','.join(['?'] * len(urls))})\n", " \"\"\"\n", " \n", " # Execute the query and fetch results\n", " cursor = firefox_conn.cursor()\n", " results = cursor.execute(query, urls).fetchall()\n", "\n", " # Convert results into a DataFrame\n", " return pd.DataFrame(results, columns=[\"url\", \"url_hash\"])\n", "\n", "# Example usage\n", "url_hashes = get_url_hash_batch(golden_queries, firefox_conn)\n", "golden_queries_updated = golden_queries.merge(url_hashes, on='url', how='inner')\n", "golden_queries_updated = golden_queries_updated.groupby('search_query')['url_hash'].agg(list).reset_index()\\\n", " .rename(columns={'search_query': 'keyword'})\n", "print(len(golden_queries_updated))" ] }, { "cell_type": "code", "execution_count": null, "id": "5c158c0c-c48d-4d19-af00-f901e21bbaf5", "metadata": {}, "outputs": [], "source": [ "golden_queries_updated" ] }, { "cell_type": "code", "execution_count": 
null,
   "id": "eb3fe3aa-2a40-4f7e-8f76-6f3cecae904f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Golden-set ablation: one run per combination of retrieval signals\n",
    "golden_metric_cols = ['precision@2', 'recall@2', 'ndcg@2', 'reciprocal_rank', 'average_precision']\n",
    "golden_ablations = {\n",
    "    \"keywords + tags + topics\": dict(use_tags=True, use_topics=True),\n",
    "    \"only keywords\": dict(use_tags=False, use_topics=False),\n",
    "    \"keywords + tags and no topics\": dict(use_tags=True, use_topics=False),\n",
    "    \"keywords + topics and no tags\": dict(use_tags=False, use_topics=True),\n",
    "}\n",
    "\n",
    "# Per-query results stay available as golden_eval_dfs[<configuration name>]\n",
    "golden_eval_dfs = {\n",
    "    name: perform_traditional_evals(golden_queries_updated, **flags)\n",
    "    for name, flags in golden_ablations.items()\n",
    "}\n",
    "print(f\"{len(golden_queries_updated)} golden queries evaluated per configuration\")\n",
    "pd.DataFrame({name: df[golden_metric_cols].mean() for name, df in golden_eval_dfs.items()}).T"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}