def figure_loss()

in student_specialization/visualization/visualize.py [0:0]


def figure_loss(data, *, multis=(1, 2, 5, 10),
                decays=(0, 0.5, 1, 1.5, 2, 2.5), num_teacher=20):
    """Plot evaluation-loss convergence curves and save them as a PDF.

    One subplot is drawn per value in ``decays``, laid out on a 2-row grid.
    Each subplot has one curve per value in ``multis``: the mean over seeds of
    log10(eval_loss) at every recorded step, with a shaded band of +/- one
    standard deviation across seeds. The figure is saved to the relative path
    ``convergence_m{num_teacher}.pdf``.

    Args:
        data: Experiment results, forwarded unchanged to ``find_params``.
            NOTE(review): assumes ``find_params(data, params)["stats"]`` is a
            dict mapping seed -> list of per-step dicts, each holding an
            ``"eval_loss"`` key; confirm against ``find_params``.
        multis: ``multi`` values to select; one curve each (keyword-only).
        decays: ``teacher_strength_decay`` values to select; one subplot each
            (keyword-only).
        num_teacher: ``m`` value used to select runs and to name the output
            file (keyword-only).

    Raises:
        ValueError: if a selected run has no seeds in its ``"stats"``.
    """
    # matplotlib requires integer grid sizes; ``len(decays) / 2`` would be a
    # float (e.g. 3.0), which current matplotlib versions reject.
    nrows = 2
    ncols = max(1, math.ceil(len(decays) / nrows))

    plt.figure(figsize=(15, 7))

    for idx, decay in enumerate(decays):
        ax = plt.subplot(nrows, ncols, idx + 1)
        col = idx % ncols

        for multi in multis:
            d = find_params(data, dict(multi=multi, teacher_strength_decay=decay, m=num_teacher))
            per_seed = [
                torch.tensor([math.log10(s["eval_loss"]) for s in stats], dtype=torch.float64)
                for stats in d["stats"].values()
            ]
            if not per_seed:
                raise ValueError(
                    f"no seeds found for multi={multi}, "
                    f"teacher_strength_decay={decay}, m={num_teacher}"
                )
            # Shape (num_steps, num_seeds); statistics are taken across seeds.
            # torch.stack raises if seeds recorded different numbers of steps.
            losses = torch.stack(per_seed, dim=1)
            loss = losses.mean(dim=1)
            loss_std = losses.std(dim=1)

            p = ax.plot(loss.numpy(), label=f"{multi}x")
            ax.fill_between(
                list(range(loss.size(0))),
                (loss - loss_std).numpy(),
                (loss + loss_std).numpy(),
                color=p[0].get_color(),
                alpha=0.2,
            )

        # Label the x-axis on every panel that has no panel directly below it.
        if idx + ncols >= len(decays):
            ax.set_xlabel('Epoch')

        # Only the left-most column shows the y label and y tick labels; all
        # panels use the same fixed axis range set below.
        if col == 0:
            ax.set_ylabel('Evaluation log loss')
        else:
            ax.set_yticklabels([])

        ax.set_title(f"$p={decay}$")
        ax.axis([0, 100, -8, 0])

        # A single legend on the first panel covers all panels.
        if idx == 0:
            ax.legend()

    plt.savefig(f"convergence_m{num_teacher}.pdf")