in student_specialization/visualization/visualize.py [0:0]
def _log10_eval_losses(stats_by_run):
    """Stack per-epoch log10(eval_loss) of every run into an (epochs, runs) tensor.

    Args:
        stats_by_run: Mapping whose values are sequences of per-epoch records,
            each holding an ``"eval_loss"`` entry (the shape ``d["stats"]`` is
            read with below; keys are ignored, iteration order is kept).

    Returns:
        float64 tensor of shape (num_epochs, num_runs); column ``j`` is the
        j-th run in iteration order.

    Raises:
        ValueError: If ``stats_by_run`` is empty.
        RuntimeError: From ``torch.stack`` if runs have different epoch counts.
    """
    if not stats_by_run:
        raise ValueError("no runs found for this configuration")
    columns = [
        torch.tensor([math.log10(s["eval_loss"]) for s in stats], dtype=torch.float64)
        for stats in stats_by_run.values()
    ]
    return torch.stack(columns, dim=1)


def figure_loss(data, *, multis=(1, 2, 5, 10), decays=(0, 0.5, 1, 1.5, 2, 2.5), num_teacher=20):
    """Plot evaluation-loss convergence curves and save them as a PDF.

    Draws one subplot per teacher-strength decay value ``p`` on a two-row
    grid. Each subplot holds one curve per ``multi`` setting: the per-epoch
    mean of log10(eval_loss) across all runs selected by ``find_params``,
    with a shaded band of +/- one standard deviation across those runs.

    Args:
        data: Experiment records, passed unchanged to ``find_params``.
        multis: ``multi`` values to compare; one curve (labelled
            ``"{multi}x"``) per value in every subplot.
        decays: ``teacher_strength_decay`` values; one subplot per value.
        num_teacher: Value of ``m`` used to select runs; also used in the
            output file name.

    Side effects:
        Creates a new matplotlib figure and saves it to
        ``convergence_m{num_teacher}.pdf`` (relative to the working directory).

    Raises:
        ValueError: If a selected configuration has no runs.
    """
    nrows = 2
    # matplotlib requires integer grid sizes; round up so an odd number of
    # decays still gets a panel for every value.
    ncols = math.ceil(len(decays) / nrows)
    plt.figure(figsize=(15, 7))
    for index, decay in enumerate(decays, start=1):
        ax = plt.subplot(nrows, ncols, index)
        for multi in multis:
            d = find_params(data, dict(multi=multi, teacher_strength_decay=decay, m=num_teacher))
            losses = _log10_eval_losses(d["stats"])
            loss = losses.mean(dim=1)
            # Unbiased std across runs (NaN when there is only a single run).
            loss_std = losses.std(dim=1)
            (line,) = ax.plot(loss.numpy(), label=f"{multi}x")
            ax.fill_between(
                list(range(loss.size(0))),
                (loss - loss_std).numpy(),
                (loss + loss_std).numpy(),
                color=line.get_color(),
                alpha=0.2,
            )
        # Label the x axis only on panels with no panel below them.
        if index + ncols > len(decays):
            ax.set_xlabel('Epoch')
        # Only the left-most column keeps y labels and tick labels.
        if (index - 1) % ncols == 0:
            ax.set_ylabel('Evaluation log loss')
        else:
            ax.set_yticklabels([])
        ax.set_title(f"$p={decay}$")
        ax.axis([0, 100, -8, 0])
        if index == 1:
            ax.legend()
    plt.savefig(f"convergence_m{num_teacher}.pdf")