notebooks/Benchmarking.ipynb (524 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"id": "1d01dd88-caac-43d8-a4c2-f1059ca8ee59",
"metadata": {},
"source": [
"This notebook benchmarks ML models (downloaded into the model directory) against various validation sets.\n",
"\n",
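"The multi-tab 'all_users' dataset (`data/individual_tests/private/all_users2.csv`) may not be public, but the single-tab validation dataset (`data/individual_tests/single_tab_validation.csv`) is.\n"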
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "581755c5-3d91-4c1f-b01e-40d9de375c66",
"metadata": {
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb473ed7-aa9e-42fd-9446-a081d9ba9c38",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%pwd\n",
"%cd \"~/Documents/GitHub/smart-tab-grouping\"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d82dbb51-bb12-46c4-b532-87a2735caccb",
"metadata": {},
"outputs": [],
"source": [
"from rouge_score import rouge_scorer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3c8d2f3-2be3-4dc8-9461-9ed71947a0e0",
"metadata": {},
"outputs": [],
"source": [
"multitab_tests = pd.read_csv(\"data/individual_tests/private/all_users2.csv\")\n",
"single_tab_tests = pd.read_csv(\"data/individual_tests/single_tab_validation.csv\")\n",
"# Use item assignment: `df.keywords = ...` only sets a Python attribute (and does not\n",
"# create a column) when the CSV has no `keywords` column.\n",
"single_tab_tests[\"keywords\"] = \"\"\n",
"\n",
"garbled_tests = pd.read_csv(\"data/individual_tests/garbled.csv\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4a81bc3-8bf0-46ab-aebb-c3b1f9b3e2ed",
"metadata": {},
"outputs": [],
"source": [
"garbled_tests.loc[:, \"keywords\"] = \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f60698b7-3cb8-484b-978f-798955a1f729",
"metadata": {},
"outputs": [],
"source": [
"garbled_tests"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6e68219-1f21-4620-a734-09fd140bf8ae",
"metadata": {},
"outputs": [],
"source": [
"from spellchecker import SpellChecker\n",
"spell = SpellChecker()\n",
"spell.word_frequency.load_words(['microsoft', 'apple', 'google', 'bing', 'search', 'duckduckgo', 'yahoo'])\n",
"\n",
"\n",
"def _has_lower_to_upper_switch(word: str) -> bool:\n",
"    \"\"\"Detect a lowercase-to-uppercase switch between adjacent letters of `word`.\n",
"\n",
"    Only pairs where both characters are alphabetic are examined; a pair matches\n",
"    when the first equals its own lowercase form and the second equals its own\n",
"    uppercase form (e.g. the 'oB' in 'fooBar').\n",
"    \"\"\"\n",
"    for last_char, cur_char in zip(word, word[1:]):\n",
"        if not (last_char.isalpha() and cur_char.isalpha()):\n",
"            continue\n",
"        if cur_char.upper() == cur_char and last_char.lower() == last_char:\n",
"            return True\n",
"    return False\n",
"\n",
"\n",
"def is_clean_string(s: str) -> bool:\n",
"    \"\"\"Return True when every whitespace-separated word of `s` looks well formed.\n",
"\n",
"    A word is rejected when it:\n",
"      * contains more than one apostrophe, or has more than one character on\n",
"        both sides of its single apostrophe;\n",
"      * is reported as unknown by the module-level `spell` checker; or\n",
"      * contains a lowercase-to-uppercase switch between adjacent letters.\n",
"\n",
"    Words with a single acceptable apostrophe (e.g. a trailing 's) skip the\n",
"    spelling and casing checks entirely.\n",
"    \"\"\"\n",
"    for word in s.split():\n",
"        if \"'\" in word:\n",
"            # A word containing an apostrophe always splits into at least two segments.\n",
"            segments = word.split(\"'\")\n",
"            if len(segments) > 2:\n",
"                return False\n",
"            if len(segments[0]) > 1 and len(segments[1]) > 1:\n",
"                return False\n",
"            continue  # don't check spelling with 's\n",
"        if len(spell.unknown([word])) == 1:\n",
"            return False\n",
"        if _has_lower_to_upper_switch(word):\n",
"            return False\n",
"    return True\n",
" \n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e711a63-5ece-4b73-a757-0cba0548572c",
"metadata": {},
"outputs": [],
"source": [
"scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6c987e5-aad7-4d3c-a2df-8696b96b7db4",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics.pairwise import cosine_similarity\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8bd3e3c8-907c-4084-a73c-3407d4906a00",
"metadata": {},
"outputs": [],
"source": [
"from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "575f3907-0cac-44a8-ad32-41bb01077c84",
"metadata": {},
"outputs": [],
"source": [
"embedder = pipeline(\"feature-extraction\", model=\"sentence-transformers/all-MiniLM-L6-v2\", device=-1)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b485fdef-ee22-441a-9cf5-45d8ef571a30",
"metadata": {},
"outputs": [],
"source": [
"def cos_sim(s1, s2):\n",
"    \"\"\"Cosine similarity between the mean-pooled `embedder` outputs of two strings.\n",
"\n",
"    Each string is passed through the module-level `embedder` pipeline; the first\n",
"    element of its output is averaged over axis 0 and the two resulting vectors\n",
"    are compared with scikit-learn's `cosine_similarity`.\n",
"\n",
"    Returns:\n",
"        The similarity as a plain Python float rather than a 0-d NumPy array.\n",
"    \"\"\"\n",
"    vec_a, vec_b = (np.mean(embedder(text)[0], axis=0).reshape(1, -1) for text in (s1, s2))\n",
"    # cosine_similarity returns a (1, 1) matrix for two single-row inputs; unwrap it to a scalar.\n",
"    return float(cosine_similarity(vec_a, vec_b)[0, 0])\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc7b43aa-0adb-4c63-a1c7-47478d24e405",
"metadata": {},
"outputs": [],
"source": [
"cos_sim(\"Dogs\", \"Apple\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51e52356-e47c-4910-863a-bec23e119ad5",
"metadata": {},
"outputs": [],
"source": [
"def compute_scores(row, pred_key=None):\n",
"    \"\"\"Score one row's prediction (`row[pred_key]`) against its reference `row['label']`.\n",
"\n",
"    Returns a dict with the ROUGE-1/2/L F-measures, the `len` of both values,\n",
"    their `cos_sim` similarity and a 0/1 flag from `is_clean_string`.\n",
"    \"\"\"\n",
"    reference = row['label']\n",
"    prediction = row[pred_key]\n",
"    rouge = scorer.score(reference, prediction)\n",
"    metrics = {variant: rouge[variant].fmeasure for variant in ('rouge1', 'rouge2', 'rougeL')}\n",
"    metrics['pred_len'] = len(prediction)\n",
"    metrics['label_len'] = len(reference)\n",
"    metrics['cos_sim'] = cos_sim(reference, prediction)\n",
"    metrics['clean'] = int(bool(is_clean_string(prediction)))\n",
"    return metrics\n",
"\n",
"def compute_scores_no_label(row, pred_key=None):\n",
"    \"\"\"Label-free scoring: report only whether `row[pred_key]` passes `is_clean_string`.\"\"\"\n",
"    clean_flag = int(bool(is_clean_string(row[pred_key])))\n",
"    return {'clean': clean_flag}\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c14856f7-8329-419c-9e2e-1c96a07d3984",
"metadata": {},
"outputs": [],
"source": [
"def get_avg_scores(input_df: DataFrame, compare_column: str):\n",
"    \"\"\"Average the per-row metrics obtained by scoring `compare_column` of `input_df`.\n",
"\n",
"    Rows are scored with `compute_scores` when the frame has a 'label' column and\n",
"    with `compute_scores_no_label` otherwise. Returns a dict of metric name to mean.\n",
"    \"\"\"\n",
"    if 'label' in input_df.columns:\n",
"        row_scorer = compute_scores\n",
"    else:\n",
"        row_scorer = compute_scores_no_label\n",
"    per_row_scores = input_df.apply(partial(row_scorer, pred_key=compare_column), axis=1, result_type='expand')\n",
"    return per_row_scores.mean().to_dict()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc08ad0e-5750-4823-8181-6b6c58bdb11b",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"from pathlib import Path\n",
"\n",
"# The %cd cell near the top switches the working directory to the repository root,\n",
"# so resolve src/ relative to it instead of hardcoding one user's home directory.\n",
"sys.path.append(str(Path.cwd() / \"src\"))\n",
"from util.tab_titles import T5TopicGenerator, OnnxT5TopicGenerator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "388913c9-0c7a-473c-8581-a8de4502f1ee",
"metadata": {},
"outputs": [],
"source": [
"def compute_topic_keywords(row, legacy=False, prob_limit=None):\n",
"    \"\"\"Generate a topic for a multi-tab row using the current global `topic_gen`.\n",
"\n",
"    `row['three_titles']` is split on newlines into documents and\n",
"    `row['keywords']` is split on commas into keywords.\n",
"    \"\"\"\n",
"    # NOTE(review): ''.split(',') yields [''] rather than [], so a row with blank\n",
"    # keywords passes one empty keyword -- confirm the generator tolerates that.\n",
"    topic_input = {\n",
"        \"documents\": row[\"three_titles\"].split('\\n'),\n",
"        \"keywords\": row[\"keywords\"].split(','),\n",
"    }\n",
"    return topic_gen.get_topic_with_keywords(topic_input, legacy=legacy, prob_limit=prob_limit)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "477ec933-aae1-45fa-9c9a-01a746b329d6",
"metadata": {},
"outputs": [],
"source": [
"topic_gen = T5TopicGenerator(\"./models/still-durian-309\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f295b891-5ec7-4376-806e-54560593f071",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "930c9da7-935f-4f1a-b7f1-b354a63c5668",
"metadata": {},
"outputs": [],
"source": [
"#topic_gen.tokenizer.decode(topic_gen.model.generation_config.bad_words_ids[88])\n",
"\n",
"topic_gen.tokenizer.convert_ids_to_tokens(topic_gen.model.generation_config.bad_words_ids[600])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce161173-8988-4ba4-901f-3d8b1802747a",
"metadata": {},
"outputs": [],
"source": [
"def compute_topic_keywords_single(row, legacy=False, prob_limit=None):\n",
"    \"\"\"Generate a topic for a single-tab row using the current global `topic_gen`.\n",
"\n",
"    The only document is `row['title']`; `row['keywords']` is split on commas.\n",
"    \"\"\"\n",
"    # NOTE(review): ''.split(',') yields [''] rather than [], so a row with blank\n",
"    # keywords passes one empty keyword -- confirm the generator tolerates that.\n",
"    topic_input = {\"documents\": [row[\"title\"]], \"keywords\": row[\"keywords\"].split(',')}\n",
"    return topic_gen.get_topic_with_keywords(topic_input, legacy=legacy, prob_limit=prob_limit)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07b9dd63-1203-4293-b501-03c089b475ec",
"metadata": {},
"outputs": [],
"source": [
"def compute_topic(row):\n",
"    \"\"\"Generate a topic from the newline-separated titles in `row['three_titles']`, without keywords.\"\"\"\n",
"    documents = row[\"three_titles\"].split('\\n')\n",
"    return topic_gen.get_topic({\"documents\": documents})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bc64325-9f70-40fc-bca7-b15d4ef028d0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9e6e08c-d0d7-482d-a8ed-0f2d37f71a94",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Baseline column produced with whichever generator `topic_gen` currently refers to.\n",
"multitab_tests[\"recomputed_titles_keywords\"] = multitab_tests.apply(compute_topic_keywords, axis=1)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1f199de-575c-4d9d-82fb-4d5205ba0614",
"metadata": {},
"outputs": [],
"source": [
"\n",
"torch_models = [\n",
" {\"name\": \"cool-yogurt-98\", \"legacy_data_format\": False},\n",
" {\"name\": \"dainty-blaze-127\", \"legacy_data_format\": False},\n",
" {\"name\": \"dainty-river-189\",\"legacy_data_format\": False},\n",
" {\"name\": \"gallant-sunset-190\",\"legacy_data_format\": False},\n",
" {\"name\": \"upbeat-eon-195\", \"legacy_data_format\": False},\n",
" {\"name\": \"devoted-puddle-246\", \"legacy_data_format\": False},\n",
" {\"name\": \"genial-tree-283\", \"legacy_data_format\": False},\n",
" {\"name\": \"major-elevator-302\", \"legacy_data_format\": False},\n",
" {\"name\": \"olive-silence-303\", \"legacy_data_format\": False},\n",
" {\"name\": \"sandy-forest-305\", \"legacy_data_format\": False},\n",
" {\"name\": \"still-durian-309\", \"legacy_data_format\": False},\n",
" {\"name\": \"eager-plant-323\", \"legacy_data_format\": False},\n",
" {\"name\": \"dulcet-durian-136\", \"legacy_data_format\": False},\n",
" {\"name\": \"lively-planet-17\", \"legacy_data_format\": False},\n",
" {\"name\": \"eager-fog-84\", \"legacy_data_format\": False},\n",
" {\"name\": \"dry-meadow-86\", \"legacy_data_format\": False},\n",
" {\"name\": \"classic-forest-87\", \"legacy_data_format\": False},\n",
" {\"name\": \"laced-terrain-88\", \"legacy_data_format\": False},\n",
" {\"name\": \"drawn-water-93\", \"legacy_data_format\": False}\n",
" ]\n",
"\n",
"onnx_quantized_models = [\n",
" {\"name\": \"cool-yogurt-98\", \"legacy_data_format\": False},\n",
" {\"name\": \"dainty-blaze-127\", \"legacy_data_format\": False},\n",
" {\"name\": \"devoted-puddle-246\", \"legacy_data_format\": False},\n",
" {\"name\": \"sandy-forest-305\", \"legacy_data_format\": False},\n",
" {\"name\": \"still-durian-309\", \"legacy_data_format\": False},\n",
" {\"name\": \"eager-plant-323\", \"legacy_data_format\": False},\n",
" {\"name\": \"eager-fog-84\", \"legacy_data_format\": False},\n",
" {\"name\": \"dry-meadow-86\", \"legacy_data_format\": False},\n",
" {\"name\": \"classic-forest-87\", \"legacy_data_format\": False},\n",
" {\"name\": \"drawn-water-93\", \"legacy_data_format\": False}\n",
" ]\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92866d33-4f4f-4705-9c56-603d27f5feee",
"metadata": {},
"outputs": [],
"source": [
"TEST_ONNX = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9317c4f7-622e-43fb-b58c-7b4a9b191da8",
"metadata": {},
"outputs": [],
"source": [
"models = onnx_quantized_models if TEST_ONNX else torch_models"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e79ea79f-62c2-41d8-8377-ffe5099f5194",
"metadata": {},
"outputs": [],
"source": [
"# Blank out the keywords of every single-tab row and replace any remaining NaN with \"\".\n",
"single_tab_tests = single_tab_tests.assign(keywords=\"\").fillna(\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f84145f7-53d4-4f82-8503-fbba6d797e2b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "718dc98e-02dd-41c1-b212-5487b9477a0b",
"metadata": {},
"outputs": [],
"source": [
"single_tab_score = []\n",
"multi_tab_score = []\n",
"\n",
"for model_info in models:\n",
"    name = model_info[\"name\"]\n",
"    # compute_topic_keywords / compute_topic_keywords_single read the global\n",
"    # `topic_gen`, so it is rebound here for each model under test.\n",
"    if TEST_ONNX:\n",
"        topic_gen = OnnxT5TopicGenerator(model_name=f\"./models_onnx/{name}\")\n",
"    else:\n",
"        topic_gen = T5TopicGenerator(model_name=f\"./models/{name}\")\n",
"    col = f\"recomputed_title_keywords_{name}\"\n",
"\n",
"    # Multi-tab validation set: generate titles, then average the per-row scores.\n",
"    multitab_tests[col] = multitab_tests.apply(\n",
"        partial(compute_topic_keywords, legacy=model_info[\"legacy_data_format\"]), axis=1\n",
"    )\n",
"    print(f\"{name} - MultiTab Tests\")\n",
"    score = get_avg_scores(multitab_tests, col)\n",
"    score[\"model\"] = name\n",
"    multi_tab_score.append(score)\n",
"\n",
"    # Single-tab validation set: same procedure with the single-title input builder.\n",
"    single_tab_tests[col] = single_tab_tests.apply(\n",
"        partial(compute_topic_keywords_single, legacy=model_info[\"legacy_data_format\"]), axis=1\n",
"    )\n",
"    print(f\"{name} - Single Tab Tests\")\n",
"    score = get_avg_scores(single_tab_tests, col)\n",
"    score[\"model\"] = name\n",
"    single_tab_score.append(score)\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "750893b6-547c-44d5-bd3b-23162fc8ccb0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf56c4c7-eee9-463e-9439-55952ffd66e2",
"metadata": {},
"outputs": [],
"source": [
"garbled_tests.title.to_list()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afe39759-73a2-4d20-8b6c-553cb562d64c",
"metadata": {},
"outputs": [],
"source": [
"single_tab_df = pd.DataFrame(single_tab_score)\n",
"multi_tab_df = pd.DataFrame(multi_tab_score)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36ae52b0-72c6-47ac-a710-6548173428cf",
"metadata": {},
"outputs": [],
"source": [
"multi_tab_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eeb97658-4222-47c0-b814-4d4a22558f02",
"metadata": {},
"outputs": [],
"source": [
"single_tab_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "622de1a3-9bf8-4f1e-b762-df88478e10db",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}