in student_specialization/visualization/visualize.py [0:0]
def figure_success_rate(data):
multis = (1, 2, 5, 10)
thres = 0.95
num_teacher = 20
plt.figure(figsize=(12, 2.5))
# plt.figure()
counter = 0
# fig, ax = plt.subplots(figsize=(6, 5))
for decay in (0.5, 1, 1.5, 2, 2.5):
ax = plt.subplot(1, 5, counter + 1)
counter += 1
for iter, style in zip((5, -1), (':', '-')):
bars = []
ind = torch.FloatTensor(list(range(num_teacher)))
# width = 0.15
colors = ['r', 'g','b','c']
for i, multi in enumerate(multis):
#plt.subplot(1, len(multis), counter)
#counter += 1
d = find_params(data, dict(multi=multi, teacher_strength_decay=decay, m=num_teacher))
losses = []
counts = None
for seed, stats in d["stats"].items():
s = stats[iter]
v = (s["counts_eval"][thres] > 0).float()
if counts is None:
counts = v
else:
counts += v
losses.append(s["eval_loss"])
counts /= len(d["stats"])
plt.plot(ind.numpy(), counts.numpy(), colors[i], label=f"{multi}x" if iter == -1 else None, linestyle=style)
# plt.scatter(ind.numpy(), counts.numpy(), color=colors[i])
# plt.title(f"multi={multi}, loss={sum(losses) / len(losses):#.5f}")
# plt.title(f"iter={iter}")
plt.xlabel('Teacher idx')
plt.title(f"$p={decay}$")
plt.axis([-1, num_teacher, 0, 1.1])
if counter == 1:
plt.ylabel('Successful Recovery Rate')
plt.legend()
ticks = ind[::4].numpy()
ax.set_xticks(ticks)
ax.set_xticklabels([ str(int(i)) for i in ticks ])
if counter > 1:
ax.set_yticklabels([])
# ax.legend(bars, [ f"{multi}x" for multi in multis ])
plt.tight_layout()
plt.savefig(f"rate_drop_m{num_teacher}_thres{thres}.pdf")