{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "80d8736e-c8d1-4165-a772-607f1f2812f1",
"metadata": {},
"outputs": [],
"source": [
"import sqlite3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2ff7facf-49e4-4074-9e06-08d0a9915fde",
"metadata": {},
"outputs": [],
"source": [
"import spacy\n",
"from tqdm import tqdm\n",
"\n",
"nlp = spacy.load(\"en_core_web_sm\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "879f1b19-4cc9-4604-8049-0aa7f96b1f6e",
"metadata": {},
"outputs": [],
"source": [
"# Add the project root directory to the Python path\n",
"import os\n",
"import sys\n",
"\n",
"project_root = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
"sys.path.append(project_root)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f04980e3-18c9-4503-99a6-4463e7e8141d",
"metadata": {},
"outputs": [],
"source": [
"from src.metrics import run_traditional_eval"
]
},
{
"cell_type": "markdown",
"id": "e4070835-b2ed-4304-ba80-115d9a359626",
"metadata": {},
"source": [
"#### Fetch top frecent items"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d8257221-4fc8-4c09-ba35-d7b421d65b54",
"metadata": {},
"outputs": [],
"source": [
"row_limit = 10000\n",
"GENERATE_TOPIC = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6af8ae3e-bc9b-4515-be5d-259e28d7fc75",
"metadata": {},
"outputs": [],
"source": [
"firefox_conn = sqlite3.connect(\"../data/places.sqlite\") \n",
"firefox_cursor = firefox_conn.cursor()\n",
"\n",
"input_data = firefox_cursor.execute(f\"\"\"\n",
"WITH TOP_FRECENT_PLACES AS\n",
"(SELECT p.url, p.title, p.description, p.id AS place_id, p.frecency, p.origin_id, p.url_hash\n",
"FROM moz_places p\n",
"WHERE p.title NOTNULL\n",
"AND url not like '%google.com/search?%'\n",
"ORDER BY frecency DESC\n",
"LIMIT {row_limit}\n",
") \n",
"\n",
", TOP_PLACES_INFO AS\n",
"(select * from TOP_FRECENT_PLACES\n",
"UNION\n",
"\n",
"SELECT p.url, p.title, p.description, p.id AS place_id, p.frecency, p.origin_id, p.url_hash\n",
"FROM moz_places p\n",
"WHERE p.id in (select distinct(place_id) from moz_inputhistory)\n",
")\n",
", KEYWORDS_INFO AS\n",
"(SELECT \n",
" ih.place_id, \n",
" json_group_array(\n",
" json_object(\n",
" 'keyword', ih.input,\n",
" 'use_count', ih.use_count\n",
" )\n",
" ) AS keyword_data\n",
"FROM \n",
" moz_inputhistory ih\n",
"WHERE ih.input != ''\n",
"GROUP BY \n",
" ih.place_id\n",
"ORDER BY \n",
" ih.use_count DESC\n",
")\n",
"\n",
", DOMAIN_INFO AS\n",
"(SELECT \n",
" id AS origin_id, \n",
" host, \n",
" CAST(frecency AS REAL) / (SELECT SUM(frecency) * 1.0 FROM moz_origins WHERE frecency IS NOT NULL) AS domain_frecency\n",
"FROM \n",
" moz_origins\n",
"WHERE \n",
" frecency IS NOT NULL\n",
")\n",
"\n",
"SELECT p.*, kw.keyword_data, d.host, d.domain_frecency \n",
"FROM TOP_PLACES_INFO p\n",
"LEFT JOIN KEYWORDS_INFO kw\n",
" ON p.place_id = kw.place_id\n",
"LEFT JOIN DOMAIN_INFO d\n",
" ON p.origin_id = d.origin_id\n",
"ORDER BY p.frecency DESC\n",
"\n",
"\"\"\").fetchall()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff6e5a21-bb6e-4b17-b469-ea853f25deb3",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"input_data_df = pd.DataFrame(input_data, \n",
" columns=['url', 'title', 'description', 'place_id', 'frecency', 'origin_id', 'url_hash', 'keyword_data', 'host', 'domain_frecency'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "818a07d8-a6f1-422c-9496-51e28e826bac",
"metadata": {},
"outputs": [],
"source": [
"def extract_additional_path_info(row):\n",
" url = row['url']\n",
" host = row['host']\n",
" path = url.replace(f\"https://{host}\", \"\").replace(f\"http://{host}\", \"\")\n",
" path = path.strip(\"/\")\n",
" path = path.replace(\".html\", \"\").replace(\".htm\", \"\")\n",
" path_info = path.split(\"/\")\n",
" return path_info\n",
"\n",
"def extract_tags_batch(df):\n",
" # Combine title and description into a single text column\n",
" texts = (df['title'].fillna('') + \" \" + df['description'].fillna('')).str.strip()\n",
" \n",
" # Process texts in batch using spaCy's pipe\n",
" docs = nlp.pipe(texts, disable=[\"ner\"]) # Disable unnecessary components for speed\n",
"\n",
" # Extract tags for each document\n",
" tags_list = []\n",
" for doc in docs:\n",
" tags = set()\n",
" \n",
" # Extract noun chunks and proper nouns\n",
" # for chunk in doc.noun_chunks:\n",
" # tags.add(chunk.text.strip().lower())\n",
" for token in doc:\n",
" if token.pos_ in [\"ADJ\", \"PROPN\", \"NOUN\"] and not token.is_stop:\n",
" tags.add(token.text.strip().lower())\n",
" \n",
" tags_list.append(list(tags)) # Append the tags for this document\n",
" \n",
" return tags_list"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fba770c4-f70f-4b62-a5bd-60ba797d279d",
"metadata": {},
"outputs": [],
"source": [
"input_data_df['path_info'] = input_data_df.apply(lambda row: extract_additional_path_info(row), axis=1)\n",
"input_data_df['tags'] = extract_tags_batch(input_data_df)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2873a369-137b-46f9-b6f5-6290cffcd505",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"input_data_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4486e17a-5f30-49df-8032-d9c3fe1b2555",
"metadata": {},
"outputs": [],
"source": [
"input_data_df['domain_frecency'].describe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06600fb3-428d-43d9-a4fc-50799b8e441e",
"metadata": {},
"outputs": [],
"source": [
"input_data_df['tags'].values[:20]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cf74cb7-13b4-4e12-b63b-cc8f6ce56ede",
"metadata": {},
"outputs": [],
"source": [
"input_data_df.sample(20).T"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8c99325-5928-4df4-8700-245f5477f44e",
"metadata": {},
"outputs": [],
"source": [
"input_data_df.sample(20)['title'].values"
]
},
{
"cell_type": "markdown",
"id": "ac60d084-de58-4a37-87c8-8debdd881c32",
"metadata": {},
"source": [
"#### Extract the topics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d914222-b55d-4456-9616-ddab30ec28c7",
"metadata": {},
"outputs": [],
"source": [
"from gliner import GLiNER\n",
"\n",
"gliner_model = GLiNER.from_pretrained(\"urchade/gliner_largev2\")\n",
"\n",
"labels = [ \"Arts & Entertainment\",\n",
" \"Business and Consumer Services\",\n",
" \"Community and Society\",\n",
" \"Computers Electronics and Technology\",\n",
" \"Ecommerce & Shopping\",\n",
" \"Finance\",\n",
" \"Food and Drink\",\n",
" \"Gambling\",\n",
" \"Games\",\n",
" \"Health\",\n",
" \"Heavy Industry and Engineering\",\n",
" \"Hobbies and Leisure\",\n",
" \"Home and Garden\",\n",
" \"Jobs and Career\",\n",
" \"Law and Government\",\n",
" \"Lifestyle\",\n",
" \"News & Media Publishers\",\n",
" \"Pets and Animals\",\n",
" \"Reference Materials\",\n",
" \"Science and Education\",\n",
" \"Sports\",\n",
" \"Travel and Tourism\",\n",
" \"Vehicles\",\n",
" \"Adult\"\n",
" ]"
]
},
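{
"cell_type": "code",
"execution_count": null,
"id": "3c9a51de-7b0f-4f2a-9e41-2d8c6b5a1f03",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (a sketch, not part of the original pipeline): run GLiNER's\n",
"# zero-shot prediction on one hypothetical title before labelling every row.\n",
"# Each prediction is a dict with 'text', 'label' and 'score' keys; the pipeline\n",
"# below keeps only the deduplicated label set.\n",
"sample_text = \"Easy weeknight dinner recipes for busy families\"  # hypothetical example\n",
"sample_entities = gliner_model.predict_entities(sample_text, labels, threshold=0.3)\n",
"for entity in sample_entities:\n",
"    print(entity[\"text\"], \"->\", entity[\"label\"], round(entity[\"score\"], 3))\n",
"print(\"themes:\", list({entity[\"label\"] for entity in sample_entities}))"
]
},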
{
"cell_type": "code",
"execution_count": null,
"id": "95893b70-18f9-427f-b453-f1357811086a",
"metadata": {},
"outputs": [],
"source": [
"texts = (input_data_df['title'].fillna('') + \" \" + input_data_df['description'].fillna('')).values.tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2321b3d-3c68-437b-b42b-78c58a9f7503",
"metadata": {},
"outputs": [],
"source": [
"len(texts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0817bfb8-45f1-4426-be32-d65adce44a69",
"metadata": {},
"outputs": [],
"source": [
"texts[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f75dc90-9dfb-4ce2-a935-01d9aa00f696",
"metadata": {},
"outputs": [],
"source": [
"## Very first time set this to True and then switch to False and read from saved file\n",
"# GENERATE_TOPIC = False\n",
"\n",
"if GENERATE_TOPIC:\n",
" topics = []\n",
" for text in tqdm(texts):\n",
" entities = gliner_model.predict_entities(text, labels, threshold=0.3)\n",
" themes = list({entity[\"label\"] for entity in entities})\n",
" topics.append(themes)\n",
" input_data_df['topics'] = topics\n",
" input_data_df.to_parquet(\"../data/input_data_df.parquet\", index=False)\n",
"else:\n",
" input_data_df_bkp = pd.read_parquet(\"../data/input_data_df.parquet\")\n",
" topics_lkp = input_data_df_bkp.set_index('url_hash')['topics'].to_dict()\n",
" input_data_df['topics'] = input_data_df['url_hash'].map(topics_lkp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "083bd3da-180a-4062-80f2-7c8974d5fb4a",
"metadata": {},
"outputs": [],
"source": [
"len(input_data_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4016c11a-673c-40e6-a304-8704db9cd363",
"metadata": {},
"outputs": [],
"source": [
"input_data_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d495bd4-e90f-4cdd-a4dc-21f396adef6a",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"tags_counter = Counter()\n",
"tags_counter.update([tag for tags in input_data_df['tags'].values.tolist() for tag in tags if tag.isalnum()])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5df477b2-3168-488c-bbed-62e12e7a88ab",
"metadata": {},
"outputs": [],
"source": [
"len(tags_counter)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33c508b5-086c-4302-976f-4d47c916fe2e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"tags_counter.most_common(10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e73e61d1-5150-4d9f-b3ac-0be08a01adb4",
"metadata": {},
"outputs": [],
"source": [
"path_info_counter = Counter()\n",
"path_info_counter.update(\n",
" [path_i for path_info in input_data_df['path_info'].values.tolist() for path_i in path_info if len(path_i) > 2 and path_i.isalpha()]\n",
")\n",
"print(len(path_info_counter))\n",
"path_info_counter.most_common(30)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "822772b7-e4d1-4e07-bbab-d18fcd1f1c86",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"def extract_keywords_adhoc(json_str):\n",
" try:\n",
" # Parse the string as JSON\n",
" data = json.loads(json_str)\n",
" # Extract the \"keyword\" field from each dictionary\n",
" return [item[\"keyword\"] for item in data]\n",
" except (json.JSONDecodeError, TypeError):\n",
" # Handle invalid JSON or None\n",
" return []\n",
"\n",
"keywords_list = input_data_df['keyword_data'].apply(extract_keywords_adhoc).values.tolist()\n",
"kws_counter = Counter()\n",
"kws_counter.update([kw for kws in keywords_list for kw in kws])"
]
},
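{
"cell_type": "code",
"execution_count": null,
"id": "5b2d8f61-9c3e-4a7b-8d20-4f6e1c9a7b35",
"metadata": {},
"outputs": [],
"source": [
"# Quick check (sketch) of the keyword_data shape produced by the SQL above: a JSON\n",
"# array of {\"keyword\", \"use_count\"} objects; invalid JSON or None falls back to [].\n",
"print(extract_keywords_adhoc('[{\"keyword\": \"cloud\", \"use_count\": 3}]'))  # ['cloud']\n",
"print(extract_keywords_adhoc(None))  # []"
]
},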
{
"cell_type": "code",
"execution_count": null,
"id": "4351367b-5f3e-48f2-9d18-de02c296c413",
"metadata": {},
"outputs": [],
"source": [
"len(kws_counter)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a38ccf7-2e41-475c-8529-7f622d68f37f",
"metadata": {},
"outputs": [],
"source": [
"input_data_df['keyword_data'][(~input_data_df['keyword_data'].isna())]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d680ab4d-ab9b-4f78-a067-414ceb48f433",
"metadata": {},
"outputs": [],
"source": [
"def generate_entity_rltn_score(src_entity, src_entity_type, relation, tgt_entity, score):\n",
" return (src_entity, src_entity_type, relation, tgt_entity, score)\n",
"\n",
"def extract_keyword_entities_rltn_score(df, entity_name, entity_type, relation, tgt_entity_name, score_col=None):\n",
" sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n",
" for ers_info, tgt_val in zip(sel_df[entity_name].apply(json.loads), sel_df[tgt_entity_name]):\n",
" for ers in ers_info:\n",
" for key, val in ers.items():\n",
" # print(key, val, tgt_val)\n",
" if key == entity_type:\n",
" src_entity = val\n",
" if score_col and key == score_col:\n",
" score = 1+val\n",
" else:\n",
" score = None\n",
" yield generate_entity_rltn_score(src_entity, entity_type, relation, tgt_val, score)\n",
" \n",
" \n",
"def extract_domain_entities_rltn_score(df, entity_name, relation, tgt_entity_name, score_col=None):\n",
" sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name, score_col]].reset_index(drop=True)\n",
" for idx, row in sel_df.iterrows():\n",
" yield generate_entity_rltn_score(row[entity_name], entity_name, relation, row[tgt_entity_name], row[score_col])\n",
"\n",
"def extract_path_info_entities_rltn_score(df, entity_name, relation, tgt_entity_name, score_col=None):\n",
" sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n",
" for idx, row in sel_df.iterrows():\n",
" for entity_val in row[entity_name]:\n",
" if len(entity_val) > 2 and entity_val.isalpha():\n",
" yield generate_entity_rltn_score(entity_val, entity_name, relation, row[tgt_entity_name], score_col) \n",
"\n",
"def extract_tags_entities_rltn_score(df, entity_name, relation, tgt_entity_name, score_col=None):\n",
" sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n",
" for idx, row in sel_df.iterrows():\n",
" for entity_val in row[entity_name]:\n",
" if len(entity_val) > 2 and entity_val.isalnum():\n",
" yield generate_entity_rltn_score(entity_val, 'tag', relation, row[tgt_entity_name], score_col) \n",
"\n",
"def extract_topics_entities_rltn_score(df, entity_name, relation, tgt_entity_name, score_col=None):\n",
" sel_df = df.loc[~df[entity_name].isna(), [entity_name, tgt_entity_name]].reset_index(drop=True)\n",
" for idx, row in sel_df.iterrows():\n",
" for entity_val in row[entity_name]:\n",
" if len(entity_val) > 1:\n",
" yield generate_entity_rltn_score(entity_val, 'topic', relation, row[tgt_entity_name], score_col) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce114ec5-3359-46bc-8e5f-959f4a536aa3",
"metadata": {},
"outputs": [],
"source": [
"# print(next(generate_entity_rltn_score('cloud', 'keyword', 'refers_to', 'place_id1', 0.391895954969)))\n",
"# print(next(extract_entities_rltn_score(input_data_df, 'keyword_data', 'keyword', 'refers_to', 'place_id', 'use_count')))\n",
"keyword_ers = [ers for ers in (extract_keyword_entities_rltn_score(input_data_df, 'keyword_data', 'keyword', 'refers_to', 'url_hash', 'use_count'))]\n",
"print(len(keyword_ers))\n",
"keyword_ers[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fde16d52-a2b7-48c0-a688-b1d457d758e0",
"metadata": {},
"outputs": [],
"source": [
"domain_ers = [ers for ers in extract_domain_entities_rltn_score(input_data_df, 'host', 'contains', 'url_hash', 'domain_frecency')]\n",
"print(len(domain_ers))\n",
"domain_ers[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfbb48d3-4fed-4e8b-b43d-734a610e832a",
"metadata": {},
"outputs": [],
"source": [
"path_info_ers = [ers for ers in extract_path_info_entities_rltn_score(input_data_df, 'path_info', 'parses_to', 'url_hash', 1.0)]\n",
"print(len(path_info_ers))\n",
"path_info_ers[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c70fedd-1250-42e0-be9d-147fc3fcd253",
"metadata": {},
"outputs": [],
"source": [
"tags_ers = [ers for ers in extract_tags_entities_rltn_score(input_data_df, 'tags', 'tagged_has', 'url_hash', 1.0)]\n",
"print(len(tags_ers))\n",
"tags_ers[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53f90fed-3730-4c1a-a3b4-409b2cf35894",
"metadata": {},
"outputs": [],
"source": [
"topics_ers = [ers for ers in extract_topics_entities_rltn_score(input_data_df, 'topics', 'belongs_to', 'url_hash', 1.0)]\n",
"print(len(topics_ers))\n",
"topics_ers[:5]"
]
},
{
"cell_type": "markdown",
"id": "9e59e845-2bce-45b6-97a6-247f136dcbfa",
"metadata": {},
"source": [
"#### Combining all entities and relationships"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af778f6a-91fd-4e86-9460-e97787ec8fc5",
"metadata": {},
"outputs": [],
"source": [
"len(keyword_ers) + len(domain_ers ) + len(path_info_ers) + len(tags_ers) + len(topics_ers)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d8cf0c5-f176-4b22-ae0d-121164b872a8",
"metadata": {},
"outputs": [],
"source": [
"ers_df = pd.DataFrame(keyword_ers + domain_ers + path_info_ers + tags_ers + topics_ers,\n",
" columns=['entity', 'entity_type', 'relation', 'url_hash', 'score'])\n",
"# + len(domain_ers ) + len(path_info_ers) + len(tags_ers) + len(topics_ers)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9f55106-d807-4816-bb80-d9166f960051",
"metadata": {},
"outputs": [],
"source": [
"ers_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcba8ec6-5ba9-48a0-8c7e-f729997aade9",
"metadata": {},
"outputs": [],
"source": [
"# Create a new SQLite database\n",
"db_path = \"../data/ml_kg.db\"\n",
"conn = sqlite3.connect(db_path)\n",
"\n",
"# Create the ml_kg_info table\n",
"conn.execute(\"\"\"\n",
"CREATE TABLE IF NOT EXISTS ml_kg_info (\n",
" entity TEXT NOT NULL,\n",
" entity_type TEXT NOT NULL,\n",
" relation TEXT NOT NULL,\n",
" url_hash INTEGER NOT NULL,\n",
" score REAL NOT NULL\n",
");\n",
"\"\"\")\n",
"\n",
"# Insert data from DataFrame into the table\n",
"ers_df.to_sql(\"ml_kg_info\", conn, if_exists=\"append\", index=False)\n",
"\n",
"# Attach the moz_places table from places.sqlite\n",
"places_db_path = \"../data/places.sqlite\"\n",
"conn.execute(f\"ATTACH DATABASE '{places_db_path}' AS places_db;\")"
]
},
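{
"cell_type": "code",
"execution_count": null,
"id": "8e4f2a90-1c5d-4b6e-a3f7-6d2b9c8e0a14",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: the retrieval query below filters ml_kg_info on\n",
"# (entity, entity_type) and joins on url_hash, so covering indexes can speed it\n",
"# up as the table grows. IF NOT EXISTS makes this cell safe to re-run.\n",
"conn.execute(\"CREATE INDEX IF NOT EXISTS idx_ml_kg_entity ON ml_kg_info (entity, entity_type);\")\n",
"conn.execute(\"CREATE INDEX IF NOT EXISTS idx_ml_kg_url_hash ON ml_kg_info (url_hash);\")\n",
"conn.commit()"
]
},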
{
"cell_type": "code",
"execution_count": null,
"id": "7ace6fe6-49fe-4cf5-8b21-d11507f7372a",
"metadata": {},
"outputs": [],
"source": [
"def extract_tags_for_queries(queries):\n",
" texts = queries[::]\n",
" \n",
" docs = nlp.pipe(texts, disable=[\"ner\"])\n",
"\n",
" tags_list = []\n",
" for doc in docs:\n",
" tags = set()\n",
" \n",
" for token in doc:\n",
" # print(token.pos_)\n",
" if token.pos_ in [\"ADJ\", \"PROPN\", \"NOUN\"] and not token.is_stop:\n",
" tags.add(token.text.strip().lower())\n",
" \n",
" tags_list.append(list(tags)) # Append the tags for this document\n",
" \n",
" return tags_list\n",
"\n",
"def infer_topics(queries, pbar=True):\n",
" topics = []\n",
" if pbar:\n",
" for query in tqdm(queries):\n",
" entities = gliner_model.predict_entities(query, labels, threshold=0.3)\n",
" themes = list({entity[\"label\"] for entity in entities})\n",
" topics.append(themes)\n",
" else:\n",
" for query in queries:\n",
" entities = gliner_model.predict_entities(query, labels, threshold=0.3)\n",
" themes = list({entity[\"label\"] for entity in entities})\n",
" topics.append(themes)\n",
" return topics\n",
" \n",
"def fetch_entity_relations_with_keywords(conn, search_keyword, search_tags, search_topics):\n",
" # Convert the list of search keywords into a string suitable for SQL\n",
" if not search_keyword:\n",
" raise ValueError(\"search_keywords list cannot be empty.\")\n",
"\n",
" keyword_placeholder = f\"'{search_keyword}'\"\n",
" # print(\"keyword_placeholder = \", keyword_placeholder)\n",
" \n",
" tag_placeholder = ', '.join(f\"'{tag}'\" for tag in search_tags)\n",
" # print(\"tag_placeholder = \", tag_placeholder)\n",
"\n",
" topic_placeholder = ', '.join(f\"'{topic}'\" for topic in search_topics)\n",
" # print(\"topic_placeholder = \", topic_placeholder)\n",
"\n",
" # Define the query with the dynamic IN clause\n",
" query = f\"\"\"\n",
" WITH entity_relations_info AS (\n",
" SELECT \n",
" m.entity,\n",
" m.entity_type,\n",
" m.relation,\n",
" m.url_hash,\n",
" m.score,\n",
" p.url,\n",
" p.title,\n",
" p.frecency\n",
" FROM \n",
" ml_kg_info m\n",
" JOIN \n",
" places_db.moz_places p\n",
" ON \n",
" m.url_hash = p.url_hash\n",
" WHERE\n",
" (m.entity IN ({keyword_placeholder}) AND\n",
" m.entity_type = 'keyword') OR\n",
" (m.entity IN ({tag_placeholder}) AND\n",
" m.entity_type = 'tag') OR\n",
" (m.entity IN ({topic_placeholder}) AND\n",
" m.entity_type = 'topic')\n",
" ORDER BY \n",
" m.score DESC\n",
" )\n",
"\n",
" SELECT \n",
" url_hash, \n",
" url, \n",
" title,\n",
" SUM(score) AS total_score\n",
" FROM \n",
" entity_relations_info\n",
" GROUP BY \n",
" url_hash, url, title\n",
" ORDER BY \n",
" total_score DESC;\n",
" \n",
" \"\"\"\n",
" \n",
" \n",
"\n",
" results = pd.read_sql_query(query, conn)\n",
" return results\n",
"\n",
"\n",
"# search_query = \"kanba\"\n",
"search_query = \"healthy food and education\"\n",
"# search_keywords = search_query.split(\" \")\n",
"search_tags = extract_tags_for_queries([search_query])[0]\n",
"search_topics = infer_topics([search_query])[0]\n",
"\n",
"\n",
"results = fetch_entity_relations_with_keywords(conn, search_query, search_tags, search_topics)\n"
]
},
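{
"cell_type": "code",
"execution_count": null,
"id": "2a7c6e19-4d8b-4f3a-b5e0-9c1d7f4a6b28",
"metadata": {},
"outputs": [],
"source": [
"# Parameterized variant (a sketch under the same schema assumptions):\n",
"# fetch_entity_relations_with_keywords interpolates quoted terms into the SQL,\n",
"# which breaks on terms containing quotes. This hypothetical helper binds every\n",
"# term with '?' placeholders instead.\n",
"def fetch_entity_relations_params(conn, search_keyword, search_tags, search_topics):\n",
"    tag_marks = ', '.join(['?'] * len(search_tags)) or \"''\"\n",
"    topic_marks = ', '.join(['?'] * len(search_topics)) or \"''\"\n",
"    query = f\"\"\"\n",
"    SELECT p.url_hash, p.url, p.title, SUM(m.score) AS total_score\n",
"    FROM ml_kg_info m\n",
"    JOIN places_db.moz_places p ON m.url_hash = p.url_hash\n",
"    WHERE (m.entity = ? AND m.entity_type = 'keyword')\n",
"       OR (m.entity IN ({tag_marks}) AND m.entity_type = 'tag')\n",
"       OR (m.entity IN ({topic_marks}) AND m.entity_type = 'topic')\n",
"    GROUP BY p.url_hash, p.url, p.title\n",
"    ORDER BY total_score DESC;\n",
"    \"\"\"\n",
"    params = [search_keyword] + list(search_tags) + list(search_topics)\n",
"    return pd.read_sql_query(query, conn, params=params)\n",
"\n",
"# Should match the results above for quote-free queries\n",
"fetch_entity_relations_params(conn, search_query, search_tags, search_topics).head(10).T"
]
},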
{
"cell_type": "code",
"execution_count": null,
"id": "1bcaace9-7008-4864-a9ec-1cb5b2fe9122",
"metadata": {},
"outputs": [],
"source": [
"results.head(10).T"
]
},
{
"cell_type": "markdown",
"id": "1e4aacd0-84b1-4c51-91f5-28113431eb8a",
"metadata": {},
"source": [
"#### Validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d23cd571-8ecf-4efc-a335-abb338f7188e",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def fetch_ground_truths():\n",
" val_cursor = firefox_conn.cursor()\n",
"\n",
" val_data = val_cursor.execute(\n",
" \"\"\"\n",
" SELECT ih.input AS keyword,\n",
" p.url_hash,\n",
" ih.use_count,\n",
" p.url\n",
" FROM moz_inputhistory ih\n",
" JOIN moz_places p\n",
" ON ih.place_id = p.id\n",
" WHERE input != ''\n",
" ORDER BY keyword, use_count DESC\n",
" \"\"\"\n",
" ).fetchall()\n",
" return val_data\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a97ce8a3-30bd-4c11-a301-96a573f1fcc3",
"metadata": {},
"outputs": [],
"source": [
"val_data = fetch_ground_truths()\n",
"actuals_df = pd.DataFrame(val_data, columns=['keyword', 'url_hash', 'use_count', 'url'])\n",
"to_be_predicted_queries = actuals_df.groupby('keyword')['url_hash'].agg(list).reset_index()\n",
"print(len(to_be_predicted_queries))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2daaaa4-ded5-414a-86b3-9d7ce1049430",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def perform_traditional_evals(to_be_predicted_queries, use_tags=True, use_topics=True):\n",
" eval_rows = []\n",
" for idx, row in to_be_predicted_queries.iterrows():\n",
" if (idx+1) % 50 == 0:\n",
" print(f\" {idx+1} queries evaluated\")\n",
" search_query = search_keyword = row['keyword']\n",
" # print(f\"search_keyword = {search_keyword}\")\n",
" relevant_docs = row['url_hash']\n",
" \n",
" search_tags = extract_tags_for_queries([search_query])[0] if use_tags else []\n",
" search_topics = infer_topics([search_query], pbar=False)[0] if use_topics else []\n",
" # print(f\"search_tags = {search_tags}\")\n",
" # print(f\"search_topics = {search_topics}\")\n",
" results = fetch_entity_relations_with_keywords(conn, search_keyword, search_tags, search_topics).head(2)\n",
" retrieved_docs = []\n",
" if len(results) > 0:\n",
" retrieved_docs = results['url_hash'].values.tolist()\n",
" eval_row = run_traditional_eval(idx, search_keyword, relevant_docs, retrieved_docs, retrieved_distances=None, k=2)\n",
" eval_rows.append(eval_row)\n",
" return pd.DataFrame(eval_rows)\n"
]
},
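{
"cell_type": "markdown",
"id": "7f1b3d52-6e9a-4c80-bd47-0a5e8c2f9d61",
"metadata": {},
"source": [
"For reference, assuming run_traditional_eval follows the standard definitions: with $R$ the relevant url_hashes and $D_k$ the top-$k$ retrieved, precision@k $= |D_k \\cap R|/k$, recall@k $= |D_k \\cap R|/|R|$, reciprocal rank $= 1/\\mathrm{rank}$ of the first relevant hit, and nDCG@k discounts hits by $1/\\log_2(\\mathrm{rank}+1)$ relative to the ideal ordering. Here $k=2$ because only the top two KG results are kept."
]
},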
{
"cell_type": "markdown",
"id": "d2039503-1086-469f-b43b-99e036c09da5",
"metadata": {},
"source": [
"#### Use keywords + tags + topics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f42d10c-4922-47f2-bbda-f33ad6451df2",
"metadata": {},
"outputs": [],
"source": [
"eval_df = perform_traditional_evals(to_be_predicted_queries)\n",
"# print(eval_df)\n",
"eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean()\n",
"\n",
"\n",
"# # keywords + tags + topics \n",
"# precision@2 0.539931\n",
"# recall@2 0.964699\n",
"# ndcg@2 0.590278\n",
"# reciprocal_rank 0.968750\n",
"# average_precision 0.470486"
]
},
{
"cell_type": "markdown",
"id": "c28c25da-8f48-49d8-990c-14b4971b718e",
"metadata": {},
"source": [
"#### Use only keywords "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a7eac15-88ab-46f6-a103-a1f2efd3ed3b",
"metadata": {},
"outputs": [],
"source": [
"eval_df = perform_traditional_evals(to_be_predicted_queries, use_tags=False, use_topics=False)\n",
"# print(eval_df)\n",
"eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean()"
]
},
{
"cell_type": "markdown",
"id": "691d0db3-f830-4bca-afdd-65a94ae14289",
"metadata": {},
"source": [
"#### Use keywords + Tags and no topics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e2cf0bd-c4b6-4deb-a452-5a0ec106b375",
"metadata": {},
"outputs": [],
"source": [
"eval_df = perform_traditional_evals(to_be_predicted_queries, use_tags=True, use_topics=False)\n",
"# print(eval_df)\n",
"eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean()"
]
},
{
"cell_type": "markdown",
"id": "dedd0007-e773-4ce9-93cd-f6857dfa35dd",
"metadata": {},
"source": [
"#### Use keywords + topics and no tags"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2bbaf7b5-bb16-4199-87b1-5518f8a7dee6",
"metadata": {},
"outputs": [],
"source": [
"eval_df = perform_traditional_evals(to_be_predicted_queries, use_tags=False, use_topics=True)\n",
"print(len(eval_df))\n",
"eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8f2b42e-28a0-4473-961a-0e9da9245512",
"metadata": {},
"outputs": [],
"source": [
"# firefox_conn.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f66a4296-4230-4a85-9ff8-bb253039438d",
"metadata": {},
"outputs": [],
"source": [
"to_be_predicted_queries"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9ad2bca1-f250-43a9-9c0b-91ce161e1ab0",
"metadata": {},
"outputs": [],
"source": [
"golden_queries = pd.read_csv(\"../data/chidam_golden_query.csv\", usecols=['search_query', 'url'])\n",
"print(len(golden_queries))\n",
"golden_queries\n",
"# set(golden_queries['search_query'].tolist()).intersection(set(to_be_predicted_queries['keyword'].values.tolist()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6240272-3660-48a4-8c1d-56ef7920e239",
"metadata": {},
"outputs": [],
"source": [
"set(golden_queries['search_query'].tolist()).intersection(set(to_be_predicted_queries['keyword'].values.tolist()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "535ec425-4768-4819-97c6-f08231ec49ee",
"metadata": {},
"outputs": [],
"source": [
"set(golden_queries['search_query'].tolist())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a08ad79b-6646-4d7b-b0a7-986de8042bbf",
"metadata": {},
"outputs": [],
"source": [
"set(to_be_predicted_queries['keyword'].values.tolist())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9626d69d-d2e0-4d18-82bb-97cad56793db",
"metadata": {},
"outputs": [],
"source": [
"def get_url_hash_batch(golden_queries, firefox_conn):\n",
" # Convert URLs into a tuple for the IN clause\n",
" urls = tuple(golden_queries['url'].tolist())\n",
"\n",
" # Query all URL hashes in one go\n",
" query = f\"\"\"\n",
" SELECT url, url_hash\n",
" FROM moz_places\n",
" WHERE url IN ({','.join(['?'] * len(urls))})\n",
" \"\"\"\n",
" \n",
" # Execute the query and fetch results\n",
" cursor = firefox_conn.cursor()\n",
" results = cursor.execute(query, urls).fetchall()\n",
"\n",
" # Convert results into a DataFrame\n",
" return pd.DataFrame(results, columns=[\"url\", \"url_hash\"])\n",
"\n",
"# Example usage\n",
"url_hashes = get_url_hash_batch(golden_queries, firefox_conn)\n",
"golden_queries_updated = golden_queries.merge(url_hashes, on='url', how='inner')\n",
"golden_queries_updated = golden_queries_updated.groupby('search_query')['url_hash'].agg(list).reset_index()\\\n",
" .rename(columns={'search_query': 'keyword'})\n",
"print(len(golden_queries_updated))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c158c0c-c48d-4d19-af00-f901e21bbaf5",
"metadata": {},
"outputs": [],
"source": [
"golden_queries_updated"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb3fe3aa-2a40-4f7e-8f76-6f3cecae904f",
"metadata": {},
"outputs": [],
"source": [
"## keywords + tags + topics\n",
"golden_eval_df = perform_traditional_evals(golden_queries_updated)\n",
"print(golden_eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean())\n",
"print(len(golden_eval_df))\n",
"golden_eval_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bce4474-14a5-4d1d-ba0b-57799390fad7",
"metadata": {},
"outputs": [],
"source": [
"## only keywords\n",
"golden_eval_df = perform_traditional_evals(golden_queries_updated, use_tags=False, use_topics=False)\n",
"print(golden_eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean())\n",
"print(len(golden_eval_df))\n",
"golden_eval_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eca306a0-d4bf-4a43-847d-005e2cf035f0",
"metadata": {},
"outputs": [],
"source": [
"## keywords + tags and no topics\n",
"golden_eval_df = perform_traditional_evals(golden_queries_updated, use_tags=True, use_topics=False)\n",
"print(golden_eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean())\n",
"print(len(golden_eval_df))\n",
"golden_eval_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "755945e6-3a76-43c3-b59a-58a3de65f25d",
"metadata": {},
"outputs": [],
"source": [
"## keywords + topics and no tags\n",
"golden_eval_df = perform_traditional_evals(golden_queries_updated, use_tags=False, use_topics=True)\n",
"print(golden_eval_df[['precision@2','recall@2','ndcg@2','reciprocal_rank','average_precision']].mean())\n",
"print(len(golden_eval_df))\n",
"golden_eval_df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}