def figure_success_rate()

in student_specialization/visualization/visualize.py [0:0]


def _recovery_rate(stats_by_seed, snapshot, thres):
    """Return the per-teacher fraction of seeds counted as a successful recovery.

    For each seed, ``stats[snapshot]["counts_eval"][thres] > 0`` is turned into
    a 0/1 float tensor and the tensors are averaged over seeds.

    Args:
        stats_by_seed: mapping ``seed -> stats``; ``stats`` is indexable by
            *snapshot* and yields a dict with a ``"counts_eval"`` dict keyed by
            threshold (presumably a per-teacher count tensor -- confirm against
            the producer of these stats). Must be non-empty.
        snapshot: index into each seed's ``stats`` sequence (e.g. ``-1`` = last).
        thres: key into ``counts_eval``.
    """
    hits = [(stats[snapshot]["counts_eval"][thres] > 0).float()
            for stats in stats_by_seed.values()]
    # Indicators are exactly 0/1, so summation order cannot change the result;
    # Python's sum keeps the original broadcasting semantics of `+=`.
    return sum(hits) / len(hits)


def figure_success_rate(data, *, multis=(1, 2, 5, 10), thres=0.95,
                        num_teacher=20, decays=(0.5, 1, 1.5, 2, 2.5)):
    """Plot per-teacher recovery rates and save them to a PDF.

    Draws one panel per value in *decays* (passed to ``find_params`` as
    ``teacher_strength_decay``). Each panel has one color per value in *multis*
    (legend label ``"{multi}x"``). The x-axis is the teacher index. The y-axis
    is the fraction of seeds whose ``counts_eval[thres]`` entry for that
    teacher is positive. Dotted curves use each run's stats snapshot at index
    5; solid curves use the last snapshot (index -1).

    The figure is written to ``rate_drop_m{num_teacher}_thres{thres}.pdf`` in
    the current working directory and is left open as the current figure.

    Args:
        data: run collection forwarded to ``find_params``; each match is
            expected to carry a ``"stats"`` mapping ``seed -> per-snapshot
            stats`` (assumed from usage -- confirm against ``find_params``).
        multis: values of the ``multi`` run parameter to plot (keyword-only).
        thres: threshold key used to index ``counts_eval`` (keyword-only).
        num_teacher: value of the ``m`` run parameter; also sets the x-range
            (keyword-only).
        decays: ``teacher_strength_decay`` values, one panel each
            (keyword-only).

    Raises:
        ValueError: if a (multi, decay) combination has no seeds to average.
    """
    # (stats snapshot index, line style): early snapshot dotted, final solid.
    snapshots = ((5, ':'), (-1, '-'))
    colors = ('r', 'g', 'b', 'c')
    teacher_idx = list(range(num_teacher))
    ticks = teacher_idx[::4]
    num_panels = len(decays)

    # 2.4in per panel keeps the original 12in width for the default 5 panels.
    fig = plt.figure(figsize=(2.4 * num_panels, 2.5))

    for col, decay in enumerate(decays):
        ax = fig.add_subplot(1, num_panels, col + 1)

        # Look each run set up once; it is reused for both snapshots.
        runs = []
        for multi in multis:
            params = dict(multi=multi, teacher_strength_decay=decay, m=num_teacher)
            stats_by_seed = find_params(data, params)["stats"]
            if not stats_by_seed:
                raise ValueError(f"no seeds found for {params}")
            runs.append((multi, stats_by_seed))

        # Snapshot loop is outermost so all dotted curves are drawn before the
        # solid ones (same z-order as before).
        for snapshot, style in snapshots:
            for i, (multi, stats_by_seed) in enumerate(runs):
                rate = _recovery_rate(stats_by_seed, snapshot, thres)
                ax.plot(teacher_idx, rate.cpu().numpy(),
                        color=colors[i % len(colors)],
                        linestyle=style,
                        # Only the final snapshot gets a legend entry.
                        label=f"{multi}x" if snapshot == -1 else None)

        ax.set_xlabel('Teacher idx')
        ax.set_title(f"$p={decay}$")
        ax.axis([-1, num_teacher, 0, 1.1])
        ax.set_xticks(ticks)
        ax.set_xticklabels([str(t) for t in ticks])
        if col == 0:
            ax.set_ylabel('Successful Recovery Rate')
            ax.legend()
        else:
            # Panels share the same y-range; label only the leftmost one.
            ax.set_yticklabels([])

    fig.tight_layout()
    fig.savefig(f"rate_drop_m{num_teacher}_thres{thres}.pdf")