notebooks/Plotting.ipynb (148 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"id": "eb9a4b5a",
"metadata": {},
"source": [
"# Simple Plotting\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88c7ff9f",
"metadata": {},
"outputs": [],
"source": [
"RESULTS_PATH = \"../../your_sweep_path/default\"\n",
"\n",
"PLOT_ALL_SEEDS = False\n",
"# Full sweep\n",
"MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n",
"# Minimal sweep\n",
"# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00ca073c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"sns.set_style('whitegrid')\n",
"\n",
"from IPython.display import display\n",
"\n",
"import os\n",
"import glob\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5caa051",
"metadata": {},
"outputs": [],
"source": [
"records = []\n",
"for result_filename in glob.glob(os.path.join(RESULTS_PATH, \"**/results_summary.json\"), recursive=True):\n",
" config_file = os.path.join(\"/\".join(result_filename.split(\"/\")[:-1]), \"config.json\")\n",
" config = json.load(open(config_file, \"r\"))\n",
" if config[\"model_size\"] not in MODELS_TO_PLOT:\n",
" continue\n",
" if 'seed' not in config:\n",
" config['seed'] = 0\n",
" record = config.copy()\n",
" if 'weak_model' in config:\n",
" for k in record['weak_model']:\n",
" if k == 'model_size':\n",
" assert record['weak_model'][k] == record['weak_model_size']\n",
" record['weak_' + k] = record['weak_model'][k]\n",
" del record['weak_model']\n",
" record.update(json.load(open(result_filename)))\n",
" records.append(record)\n",
"\n",
"df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'model_size'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f628577",
"metadata": {},
"outputs": [],
"source": [
"datasets = df.ds_name.unique()\n",
"for dataset in datasets:\n",
" cur_df = df[(df.ds_name == dataset)].copy()\n",
" base_accuracies = cur_df[cur_df['weak_model_size'].isna()].groupby('model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n",
" base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n",
" base_accuracies = base_accuracies.reset_index()\n",
"\n",
" cur_df['strong_model_accuracy'] = cur_df['model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
" cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_accuracy'] = cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
"\n",
" # Exclude cases where the weak model is better than the strong model from PGR calculation.\n",
" valid_pgr_index = (\n",
" (~cur_df['weak_model_size'].isna()) & \n",
" (cur_df['weak_model_size'] != cur_df['model_size']) & \n",
" (cur_df['strong_model_accuracy'] > cur_df['weak_model_accuracy'])\n",
" )\n",
" cur_df.loc[valid_pgr_index, 'pgr'] = (cur_df.loc[valid_pgr_index, 'accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy']) / (cur_df.loc[valid_pgr_index, 'strong_model_accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy'])\n",
"\n",
" cur_df.loc[cur_df['weak_model_size'].isna(), \"weak_model_size\"] = \"ground truth\"\n",
"\n",
" for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n",
" plot_df = cur_df.copy().sort_values(['strong_model_accuracy']).sort_values(['loss'], ascending=False)\n",
" if seed is not None:\n",
" plot_df = plot_df[plot_df['seed'] == seed]\n",
"\n",
" print(f\"Dataset: {dataset} (seed: {seed})\")\n",
"\n",
" pgr_results = plot_df[~plot_df['pgr'].isna()].groupby(['loss']).aggregate({\"pgr\": \"median\"})\n",
"\n",
" palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n",
" color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n",
"\n",
" sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n",
" pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n",
" plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['model_size']], rotation=90)\n",
" plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n",
" plt.legend(loc='upper left')\n",
" suffix = \"\"\n",
" if seed is not None:\n",
" suffix = f\"_{seed}\"\n",
" plt.savefig(f\"{dataset.replace('/', '-')}{suffix}.png\", dpi=300, bbox_inches='tight')\n",
" plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}