notebooks/Benchmarking.ipynb (524 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "id": "1d01dd88-caac-43d8-a4c2-f1059ca8ee59", "metadata": {}, "source": [ "This notebook benchmarks ML models (downloaded into the model directory) against various validation sets.\n", "\n", "The 'all_users' dataset may not be public, but the single tab dataset is.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "581755c5-3d91-4c1f-b01e-40d9de375c66", "metadata": { "pycharm": { "is_executing": true } }, "outputs": [], "source": [ "import pandas as pd\n", "from pandas import DataFrame\n", "from functools import partial" ] }, { "cell_type": "code", "execution_count": null, "id": "cb473ed7-aa9e-42fd-9446-a081d9ba9c38", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%pwd\n", "%cd \"~/Documents/GitHub/smart-tab-grouping\"\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d82dbb51-bb12-46c4-b532-87a2735caccb", "metadata": {}, "outputs": [], "source": [ "from rouge_score import rouge_scorer" ] }, { "cell_type": "code", "execution_count": null, "id": "e3c8d2f3-2be3-4dc8-9461-9ed71947a0e0", "metadata": {}, "outputs": [], "source": [ "multitab_tests = pd.read_csv(\"data/individual_tests/private/all_users2.csv\")\n", "single_tab_tests = pd.read_csv(\"data/individual_tests/single_tab_validation.csv\")\n", "single_tab_tests.keywords = \"\"\n", "\n", "garbled_tests = pd.read_csv(\"data/individual_tests/garbled.csv\")\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f4a81bc3-8bf0-46ab-aebb-c3b1f9b3e2ed", "metadata": {}, "outputs": [], "source": [ "garbled_tests.loc[:, \"keywords\"] = \"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "f60698b7-3cb8-484b-978f-798955a1f729", "metadata": {}, "outputs": [], "source": [ "garbled_tests" ] }, { "cell_type": "code", "execution_count": null, "id": "e6e68219-1f21-4620-a734-09fd140bf8ae", "metadata": {}, "outputs": [], "source": [ "from spellchecker import SpellChecker\n", "spell = SpellChecker()\n", "spell.word_frequency.load_words(['microsoft', 'apple', 'google', 'bing', 'search', 'duckduckgo', 'yahoo'])\n", "\n", "\n", "def is_clean_string(s: str):\n", " for word in s.split():\n", " if (\"'\" in word):\n", " segments = word.split(\"'\")\n", " if len(segments) == 1:\n", " break\n", " if len(segments) > 2:\n", " return False\n", " if len(segments) == 2:\n", " if len(segments[0]) > 1 and len(segments[1]) > 1:\n", " return False\n", " continue # don't check spelling with 's\n", " if (len(spell.unknown([word])) == 1):\n", " return False\n", " last_char = None\n", " for cur_char in word:\n", " if last_char is None:\n", " last_char = cur_char\n", " continue\n", " if (not last_char.isalpha()) or (not cur_char.isalpha()):\n", " last_char = cur_char\n", " continue\n", " if cur_char.upper() == cur_char and last_char.lower() == last_char: # switch to uppercase\n", " return False\n", " last_char = cur_char\n", " return True\n", " \n", " \n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "0e711a63-5ece-4b73-a757-0cba0548572c", "metadata": {}, "outputs": [], "source": [ "scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "f6c987e5-aad7-4d3c-a2df-8696b96b7db4", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics.pairwise import cosine_similarity\n" ] }, { "cell_type": "code", "execution_count": null, "id": "8bd3e3c8-907c-4084-a73c-3407d4906a00", "metadata": {}, "outputs": [], "source": [ "from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "id": "575f3907-0cac-44a8-ad32-41bb01077c84", "metadata": {}, "outputs": [], "source": [ "embedder = pipeline(\"feature-extraction\", model=\"sentence-transformers/all-MiniLM-L6-v2\", device=-1)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b485fdef-ee22-441a-9cf5-45d8ef571a30", "metadata": {}, "outputs": [], "source": [ "def cos_sim(s1, s2):\n", " embeddings = [np.mean(embedder(s)[0], axis=0) for s in [s1, s2]]\n", " similarity = cosine_similarity(embeddings[0].reshape(1,-1), embeddings[1].reshape(1,-1)).squeeze()\n", " return similarity\n", "\n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "fc7b43aa-0adb-4c63-a1c7-47478d24e405", "metadata": {}, "outputs": [], "source": [ "cos_sim(\"Dogs\", \"Apple\")" ] }, { "cell_type": "code", "execution_count": null, "id": "51e52356-e47c-4910-863a-bec23e119ad5", "metadata": {}, "outputs": [], "source": [ "def compute_scores(row, pred_key=None):\n", " scores = scorer.score(row['label'], row[pred_key])\n", " return {\n", " 'rouge1': scores['rouge1'].fmeasure,\n", " 'rouge2': scores['rouge2'].fmeasure,\n", " 'rougeL': scores['rougeL'].fmeasure,\n", " 'pred_len': len(row[pred_key]),\n", " 'label_len': len(row['label']),\n", " 'cos_sim': cos_sim(row['label'], row[pred_key]),\n", " 'clean': 1 if is_clean_string(row[pred_key]) else 0\n", " }\n", "\n", "def compute_scores_no_label(row, pred_key=None):\n", " return {\n", " 'clean': 1 if is_clean_string(row[pred_key]) else 0\n", " }\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c14856f7-8329-419c-9e2e-1c96a07d3984", "metadata": {}, "outputs": [], "source": [ "def get_avg_scores(input_df: DataFrame, compare_column: str):\n", " scorer = compute_scores_no_label if 'label' not in input_df.columns else compute_scores\n", " rouge_scores_df = input_df.apply(partial(scorer, pred_key=compare_column) , axis=1, result_type='expand')\n", " average_scores = rouge_scores_df.mean().to_dict()\n", " return average_scores\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fc08ad0e-5750-4823-8181-6b6c58bdb11b", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"/Users/Rrando/Documents/GitHub/smart-tab-grouping/src\")\n", "from util.tab_titles import T5TopicGenerator, OnnxT5TopicGenerator" ] }, { "cell_type": "code", "execution_count": null, "id": "388913c9-0c7a-473c-8581-a8de4502f1ee", "metadata": {}, "outputs": [], "source": [ "def compute_topic_keywords(row, legacy=False, prob_limit=None):\n", " return topic_gen.get_topic_with_keywords({\"documents\": row[\"three_titles\"].split('\\n'), \"keywords\": row[\"keywords\"].split(',')}, legacy=legacy, prob_limit=prob_limit)" ] }, { "cell_type": "code", "execution_count": null, "id": "477ec933-aae1-45fa-9c9a-01a746b329d6", "metadata": {}, "outputs": [], "source": [ "topic_gen = T5TopicGenerator(\"./models/still-durian-309\")" ] }, { "cell_type": "code", "execution_count": null, "id": "f295b891-5ec7-4376-806e-54560593f071", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "930c9da7-935f-4f1a-b7f1-b354a63c5668", "metadata": {}, "outputs": [], "source": [ "#topic_gen.tokenizer.decode(topic_gen.model.generation_config.bad_words_ids[88])\n", "\n", "topic_gen.tokenizer.convert_ids_to_tokens(topic_gen.model.generation_config.bad_words_ids[600])" ] }, { "cell_type": "code", "execution_count": null, "id": "ce161173-8988-4ba4-901f-3d8b1802747a", "metadata": {}, "outputs": [], "source": [ "def compute_topic_keywords_single(row, legacy=False, prob_limit=None):\n", " return topic_gen.get_topic_with_keywords({\"documents\": [row[\"title\"]], \"keywords\": row[\"keywords\"].split(',')}, legacy=legacy, \n", " prob_limit=prob_limit)" ] }, { "cell_type": "code", "execution_count": null, "id": "07b9dd63-1203-4293-b501-03c089b475ec", "metadata": {}, "outputs": [], "source": [ "def compute_topic(row):\n", " return topic_gen.get_topic({\"documents\": row[\"three_titles\"].split('\\n')})" ] }, { "cell_type": "code", "execution_count": null, "id": "0bc64325-9f70-40fc-bca7-b15d4ef028d0", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "d9e6e08c-d0d7-482d-a8ed-0f2d37f71a94", "metadata": {}, "outputs": [], "source": [ "\n", "multitab_tests[\"recomputed_titles_keywords\"] = multitab_tests.apply(lambda row: compute_topic_keywords(row), axis=1)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d1f199de-575c-4d9d-82fb-4d5205ba0614", "metadata": {}, "outputs": [], "source": [ "\n", "torch_models = [\n", " {\"name\": \"cool-yogurt-98\", \"legacy_data_format\": False},\n", " {\"name\": \"dainty-blaze-127\", \"legacy_data_format\": False},\n", " {\"name\": \"dainty-river-189\",\"legacy_data_format\": False},\n", " {\"name\": \"gallant-sunset-190\",\"legacy_data_format\": False},\n", " {\"name\": \"upbeat-eon-195\", \"legacy_data_format\": False},\n", " {\"name\": \"devoted-puddle-246\", \"legacy_data_format\": False},\n", " {\"name\": \"genial-tree-283\", \"legacy_data_format\": False},\n", " {\"name\": \"major-elevator-302\", \"legacy_data_format\": False},\n", " {\"name\": \"olive-silence-303\", \"legacy_data_format\": False},\n", " {\"name\": \"sandy-forest-305\", \"legacy_data_format\": False},\n", " {\"name\": \"still-durian-309\", \"legacy_data_format\": False},\n", " {\"name\": \"eager-plant-323\", \"legacy_data_format\": False},\n", " {\"name\": \"dulcet-durian-136\", \"legacy_data_format\": False},\n", " {\"name\": \"lively-planet-17\", \"legacy_data_format\": False},\n", " {\"name\": \"eager-fog-84\", \"legacy_data_format\": False},\n", " {\"name\": \"dry-meadow-86\", \"legacy_data_format\": False},\n", " {\"name\": \"classic-forest-87\", \"legacy_data_format\": False},\n", " {\"name\": \"laced-terrain-88\", \"legacy_data_format\": False},\n", " {\"name\": \"drawn-water-93\", \"legacy_data_format\": False}\n", " ]\n", "\n", "onnx_quantized_models = [\n", " {\"name\": \"cool-yogurt-98\", \"legacy_data_format\": False},\n", " {\"name\": \"dainty-blaze-127\", \"legacy_data_format\": False},\n", " {\"name\": \"devoted-puddle-246\", \"legacy_data_format\": False},\n", " {\"name\": \"sandy-forest-305\", \"legacy_data_format\": False},\n", " {\"name\": \"still-durian-309\", \"legacy_data_format\": False},\n", " {\"name\": \"eager-plant-323\", \"legacy_data_format\": False},\n", " {\"name\": \"eager-fog-84\", \"legacy_data_format\": False},\n", " {\"name\": \"dry-meadow-86\", \"legacy_data_format\": False},\n", " {\"name\": \"classic-forest-87\", \"legacy_data_format\": False},\n", " {\"name\": \"drawn-water-93\", \"legacy_data_format\": False}\n", " ]\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "92866d33-4f4f-4705-9c56-603d27f5feee", "metadata": {}, "outputs": [], "source": [ "TEST_ONNX = True" ] }, { "cell_type": "code", "execution_count": null, "id": "9317c4f7-622e-43fb-b58c-7b4a9b191da8", "metadata": {}, "outputs": [], "source": [ "models = onnx_quantized_models if TEST_ONNX else torch_models" ] }, { "cell_type": "code", "execution_count": null, "id": "e79ea79f-62c2-41d8-8377-ffe5099f5194", "metadata": {}, "outputs": [], "source": [ "single_tab_tests[\"keywords\"] = pd.Series(dtype=str)\n", "single_tab_tests = single_tab_tests.fillna(\"\")" ] }, { "cell_type": "code", "execution_count": null, "id": "f84145f7-53d4-4f82-8503-fbba6d797e2b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "718dc98e-02dd-41c1-b212-5487b9477a0b", "metadata": {}, "outputs": [], "source": [ "single_tab_score = []\n", "multi_tab_score = []\n", "\n", "for model_info in models:\n", " name = model_info[\"name\"]\n", " topic_gen = OnnxT5TopicGenerator(model_name=f\"./models_onnx/{name}\") if TEST_ONNX else T5TopicGenerator(model_name=f\"./models/{name}\")\n", " col = f\"recomputed_title_keywords_{name}\"\n", " multitab_tests[col] = multitab_tests.apply(lambda row: compute_topic_keywords(row, legacy=model_info[\"legacy_data_format\"]), axis=1)\n", " print(f\"{name} - MultiTab Tests\")\n", " score = get_avg_scores(multitab_tests, col)\n", " score[\"model\"] = name\n", " multi_tab_score.append(score)\n", " \n", " single_tab_tests[col] = single_tab_tests.apply(lambda row: compute_topic_keywords_single(row, legacy=model_info[\"legacy_data_format\"]), axis=1)\n", " print(f\"{name} - Single Tab Tests\")\n", " score = get_avg_scores(single_tab_tests, col)\n", " score[\"model\"] = name\n", " single_tab_score.append(score)\n", " \n", " " ] }, { "cell_type": "code", "execution_count": null, "id": "750893b6-547c-44d5-bd3b-23162fc8ccb0", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "cf56c4c7-eee9-463e-9439-55952ffd66e2", "metadata": {}, "outputs": [], "source": [ "garbled_tests.title.to_list()" ] }, { "cell_type": "code", "execution_count": null, "id": "afe39759-73a2-4d20-8b6c-553cb562d64c", "metadata": {}, "outputs": [], "source": [ "single_tab_df = pd.DataFrame(single_tab_score)\n", "multi_tab_df = pd.DataFrame(multi_tab_score)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "36ae52b0-72c6-47ac-a710-6548173428cf", "metadata": {}, "outputs": [], "source": [ "multi_tab_df" ] }, { "cell_type": "code", "execution_count": null, "id": "eeb97658-4222-47c0-b814-4d4a22558f02", "metadata": {}, "outputs": [], "source": [ "single_tab_df" ] }, { "cell_type": "code", "execution_count": null, "id": "622de1a3-9bf8-4f1e-b762-df88478e10db", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }