in scripts/create_metric_graphs.py [0:0]
def create_tsne_graphs(operation, expt, run_dir, image_dir=args.image_dir):
    """Embed per-equation validation-loss histories with t-SNE and save a scatter plot.

    Loads every ``{run_dir}/activations/activations_*.pt`` file (in sorted
    filename order). From each file it takes ``val_loss`` averaged over its
    last dimension. These are concatenated along dim 0 and transposed, so
    each row of the result becomes one t-SNE sample. The 2-D embedding is
    drawn as a scatter plot and written to
    ``{image_dir}/tsne/{operation}_{expt}.png``.

    Args:
        operation: Label used in the output filename.
        expt: Experiment label used in the output filename.
        run_dir: Run directory that contains an ``activations`` subdirectory.
        image_dir: Root directory for output images. NOTE: the default is
            bound to ``args.image_dir`` once, when this function is defined.

    Returns:
        None. If no activation files are found, a message is printed and
        nothing is written.
    """
    saved_pt_dir = f"{run_dir}/activations"
    pattern = saved_pt_dir + "/activations_*.pt"
    print(f"glob = {pattern}")
    files = sorted(glob.glob(pattern))
    print(f"files = {files}")
    if not files:
        # Without this guard, torch.cat([]) below fails with an unhelpful error.
        print(f"No activation files found in {saved_pt_dir}; skipping T-SNE.")
        return

    loss_ts = []
    accuracy_ts = []
    epochs_ts = []
    for file in files:
        print(f"Loading {file}")
        # Load onto the CPU. The data ends up in sklearn/numpy anyway, and
        # checkpoints saved from a GPU run must not require a GPU here.
        saved_pt = torch.load(file, map_location="cpu")
        # Only the needed tensors are kept. The full checkpoint dict is
        # released on the next iteration instead of being accumulated.
        loss_ts.append(saved_pt["val_loss"].mean(dim=-1))
        accuracy_ts.append(saved_pt["val_accuracy"])
        # squeeze() turns a single-epoch tensor into a 0-d tensor, which
        # torch.cat rejects. atleast_1d keeps it concatenable.
        epochs_ts.append(torch.atleast_1d(saved_pt["epochs"].squeeze()))

    loss_t = torch.cat(loss_ts, dim=0).T.detach()
    accuracy_t = torch.cat(accuracy_ts, dim=0).T.detach()
    epochs_t = torch.cat(epochs_ts, dim=0).detach()
    # Shape diagnostics only; accuracy and epochs are not plotted.
    print(loss_t.shape)
    print(accuracy_t.shape)
    print(epochs_t.shape)

    # Each row of loss_t is one t-SNE sample.
    num_samples = len(loss_t)
    # sklearn requires perplexity < n_samples. Use the library default (30)
    # whenever there are enough samples for it.
    perplexity = min(30.0, float(max(num_samples - 1, 1)))
    print("Doing T-SNE..")
    loss_tsne = TSNE(
        n_components=2, init="pca", perplexity=perplexity
    ).fit_transform(loss_t.numpy())
    print("...done T-SNE.")

    # Single-panel figure, 8 x 5 inches.
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.scatter(loss_tsne[:, 0], loss_tsne[:, 1])
        img_file = f"{image_dir}/tsne/{operation}_{expt}.png"
        os.makedirs(os.path.dirname(img_file), exist_ok=True)
        print(f"Writing {img_file}")
        fig.savefig(img_file)
    finally:
        # Close even on failure so repeated calls don't accumulate open figures.
        plt.close(fig)