nbs/flatness.ipynb (71 lines of code) (raw):
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "geographic-personal",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def paired_sort(list1, list2):\n",
"    \"\"\"Sort two parallel sequences together, ordered by ``list1``.\n",
"\n",
"    Items are paired by position and the pairs are sorted as\n",
"    ``(list1[i], list2[i])`` tuples, so ties in ``list1`` are broken by the\n",
"    matching ``list2`` value. If the inputs differ in length, ``zip`` drops\n",
"    the unpaired tail of the longer one.\n",
"\n",
"    Args:\n",
"        list1: Sequence of sort keys.\n",
"        list2: Sequence of values paired with ``list1`` by position.\n",
"\n",
"    Returns:\n",
"        A ``(keys, values)`` pair of tuples in sorted order; two empty\n",
"        tuples when there is nothing to sort.\n",
"    \"\"\"\n",
"    pairs = sorted(zip(list1, list2))\n",
"    if not pairs:\n",
"        # zip(*[]) yields nothing, so unpacking it below would raise ValueError.\n",
"        return (), ()\n",
"    keys, values = zip(*pairs)\n",
"    return keys, values\n",
"\n",
"def plot_phi_by_ckpt(results_dir=\"../results/\", out_path=\"phi-by-ckpt.png\"):\n",
"    \"\"\"Plot the phi value of every checkpoint pickle and save the figure.\n",
"\n",
"    Each ``*.pkl`` file in ``results_dir`` is unpickled. The checkpoint number\n",
"    is parsed from the text between the first ``-`` and the ``.pkl`` suffix of\n",
"    the file name, and the first value of the unpickled mapping is converted\n",
"    with ``.item()`` and plotted as phi.\n",
"\n",
"    Args:\n",
"        results_dir: Directory holding the ``<prefix>-<n>.pkl`` result files.\n",
"        out_path: File the rendered figure is written to.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``results_dir`` contains no ``.pkl`` files, or the\n",
"            parsed checkpoint number is not an integer.\n",
"    \"\"\"\n",
"    nums = []\n",
"    flatness = []\n",
"\n",
"    for fname in os.listdir(results_dir):\n",
"        # Match on the suffix: a substring test would also pick up files that\n",
"        # merely contain \"pkl\" somewhere in their name.\n",
"        if not fname.endswith(\".pkl\"):\n",
"            continue\n",
"        num = int(fname.split(\"-\")[1].split(\".pkl\")[0])\n",
"        # NOTE(review): pickle.load can execute arbitrary code; only point this\n",
"        # at result files you produced yourself.\n",
"        with open(os.path.join(results_dir, fname), \"rb\") as fh:\n",
"            dat = pickle.load(fh)\n",
"        nums.append(num)\n",
"        # assumes the first value is a scalar tensor/array exposing .item() -- TODO confirm\n",
"        flatness.append(next(iter(dat.values())).item())\n",
"\n",
"    if not nums:\n",
"        raise ValueError(f\"no .pkl result files found in {results_dir!r}\")\n",
"\n",
"    # Directory listing order is arbitrary; order the points by checkpoint number.\n",
"    nums, flatness = paired_sort(nums, flatness)\n",
"\n",
"    # A fresh figure per call keeps repeated runs from drawing onto the same axes.\n",
"    fig, ax = plt.subplots()\n",
"    ax.plot(nums, flatness)\n",
"    # Ticks belong at the checkpoint numbers themselves; range(len(nums)) only\n",
"    # lines up when the checkpoints happen to be numbered 0..N-1.\n",
"    ax.set_xticks(nums)\n",
"    ax.set_ylabel(\"phi\")\n",
"    ax.set_xlabel(\"SD-{n} checkpoint\")\n",
"    fig.savefig(out_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mineral-assembly",
"metadata": {},
"outputs": [],
"source": [
"# Build the phi-vs-checkpoint figure from the pickles in ../results/ and save it as phi-by-ckpt.png.\n",
"plot_phi_by_ckpt()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}