nbs/flatness.ipynb (71 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "flatness-title",
   "metadata": {},
   "source": [
    "# Flatness (phi) by checkpoint\n",
    "\n",
    "Reads every `*.pkl` file in `../results/`, takes the first value stored in each pickled dict as the flatness metric `phi`, and plots it against the checkpoint number parsed from the file name (`<prefix>-<n>.pkl`). The figure is saved to `phi-by-ckpt.png`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "flatness-imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "geographic-personal",
   "metadata": {},
   "outputs": [],
   "source": [
    "def paired_sort(list1, list2):\n",
    "    \"\"\"Sort two parallel sequences together by the values of ``list1``.\n",
    "\n",
    "    Ties in ``list1`` are broken by ``list2``. Returns a pair of tuples; empty\n",
    "    input returns ``((), ())`` instead of failing to unpack.\n",
    "    \"\"\"\n",
    "    pairs = sorted(zip(list1, list2))\n",
    "    if not pairs:\n",
    "        return (), ()\n",
    "    sorted1, sorted2 = zip(*pairs)\n",
    "    return sorted1, sorted2\n",
    "\n",
    "\n",
    "def load_phi_by_ckpt(results_dir=\"../results/\"):\n",
    "    \"\"\"Read the flatness value ``phi`` for every ``*.pkl`` file in ``results_dir``.\n",
    "\n",
    "    File names are parsed as ``<prefix>-<n>.pkl`` and ``n`` is used as the\n",
    "    checkpoint number. Each pickle is assumed to hold a dict whose first value\n",
    "    is a scalar with an ``.item()`` method (e.g. a numpy/torch scalar -- TODO\n",
    "    confirm against the code that writes these files).\n",
    "\n",
    "    Returns ``(nums, flatness)`` as tuples sorted by checkpoint number.\n",
    "\n",
    "    WARNING: ``pickle.load`` can execute arbitrary code; only point this at\n",
    "    trusted result files.\n",
    "    \"\"\"\n",
    "    nums = []\n",
    "    flatness = []\n",
    "    for fname in os.listdir(results_dir):\n",
    "        if not fname.endswith(\".pkl\"):\n",
    "            continue\n",
    "        stem = os.path.splitext(fname)[0]\n",
    "        nums.append(int(stem.split(\"-\")[1]))\n",
    "        with open(os.path.join(results_dir, fname), \"rb\") as fh:\n",
    "            result = pickle.load(fh)\n",
    "        # Only the first entry of the pickled dict is used as phi.\n",
    "        flatness.append(next(iter(result.values())).item())\n",
    "    return paired_sort(nums, flatness)\n",
    "\n",
    "\n",
    "def plot_phi_by_ckpt(results_dir=\"../results/\", out_path=\"phi-by-ckpt.png\"):\n",
    "    \"\"\"Plot phi against checkpoint number, save it to ``out_path``, and return the figure.\"\"\"\n",
    "    nums, flatness = load_phi_by_ckpt(results_dir)\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(nums, flatness)\n",
    "    # Ticks go at the checkpoint numbers themselves (the x data), one per checkpoint.\n",
    "    ax.set_xticks(nums)\n",
    "    ax.set_title(\"Flatness (phi) by checkpoint\")\n",
    "    ax.set_ylabel(\"phi\")\n",
    "    ax.set_xlabel(\"SD-{n} checkpoint\")\n",
    "    fig.savefig(out_path)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mineral-assembly",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_phi_by_ckpt();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}