# def plot_acc_multiple_runs()
#
# in utils/vis_utils.py [0:0]


def plot_acc_multiple_runs(data, task_labels, valid_measures, n_stats, plot_name=None):
    """
    Plots per-task accuracies and their running average with error bars.

    One panel is drawn per task plus a final "Average" panel; all panels
    share their x and y axes. One curve is drawn per key of data['mean'].

    Args:
        data            Dict with 'mean' and 'std' entries, each mapping a
                        numeric key (rendered as "c=<key>") to a 2-D numpy
                        array indexed as [row, task]. Rows are laid out
                        along the x-axis (presumably the number of tasks
                        trained so far -- confirm against the caller).
        task_labels     List of tasks; the first and last element of each
                        entry are shown in that task's panel title
        valid_measures  Legend labels for the curves of the "Average" panel,
                        matched by position to the sorted keys of
                        data['mean']
        n_stats         Number of runs; error bars are scaled by
                        1/sqrt(n_stats), i.e. standard error of the mean
        plot_name       Name of the file where the plot will be saved; the
                        figure is shown interactively when None

    Returns:
        None
    """
    n_tasks = len(task_labels)
    n_panels = n_tasks + 1
    plt.figure(figsize=(14, 3))
    axs = [plt.subplot(1, n_panels, 1)]
    for i in range(1, n_panels):
        axs.append(plt.subplot(1, n_panels, i + 1, sharex=axs[0], sharey=axs[0]))

    fmt_chars = ['o', 's', 'd']
    x_vals = np.arange(n_tasks) + 1
    sem_norm = np.sqrt(n_stats)  # np.sqrt(n_stats) for SEM or 1 for STDEV

    for k, cval in enumerate(sorted(data['mean'])):
        label = "c=%g" % cval
        mean_vals = data['mean'][cval]
        std_vals = data['std'][cval]
        # Cycle the markers by curve index so any number of curves is valid.
        errorbar_kwargs = dict(fmt="%s-" % fmt_chars[k % len(fmt_chars)], markersize=5)

        # Per-task panels: column j holds the accuracy on task j.
        for j in range(n_tasks):
            axs[j].errorbar(x_vals, mean_vals[:, j], yerr=std_vals[:, j] / sem_norm,
                            label=label, **errorbar_kwargs)

        # Average panel: row i is averaged over its first i+1 tasks.
        avg_means = []
        avg_errs = []
        for i, (mean_row, std_row) in enumerate(zip(mean_vals, std_vals)):
            n_seen = i + 1
            avg_means.append(mean_row[:n_seen].mean())
            # Error of a mean of n_seen independent values is
            # sqrt(sum(sigma^2)) / n_seen; the extra 1/sem_norm turns it
            # into a standard error over the runs.
            avg_errs.append(np.sqrt((std_row[:n_seen] ** 2).sum()) / (n_seen * sem_norm))

        # Fall back to the "c=..." label when no measure name is supplied
        # for this curve.
        avg_label = "%s" % valid_measures[k] if k < len(valid_measures) else label
        axs[-1].errorbar(x_vals, avg_means, yerr=avg_errs, label=avg_label, **errorbar_kwargs)

    chance_line = None
    for i, ax in enumerate(axs):
        if i < n_tasks:
            ax.set_title('Task %d (%d to %d)' % (i + 1, task_labels[i][0], task_labels[i][-1]),
                         fontsize=8)
        else:
            ax.set_title("Average", fontsize=8)
        ax.set_xticks(x_vals)
        # To remove clutter along the y-axis, hide the tick labels of every
        # panel but the first (plt.setp(ax.get_yticklabels(), visible=False))
        # and fix the range with ax.set_ylim((0.45, 1.1)).
        chance_line = ax.axhline(0.5, color='k', linestyle=':', label="chance", zorder=0)

    # Leave half a tick of room to the right of the last task.
    axs[0].set_xlim((1.0, n_tasks + 0.5))

    # Reorder legend so chance is last
    handles, labels = axs[-1].get_legend_handles_labels()
    entries = list(zip(handles, labels))
    curves = [entry for entry in entries if entry[0] is not chance_line]
    chance = [entry for entry in entries if entry[0] is chance_line]
    ordered = curves + chance
    axs[-1].legend([handle for handle, _ in ordered],
                   [text for _, text in ordered], loc='best', fontsize=6)

    axs[0].set_xlabel("Tasks")
    axs[0].set_ylabel("Accuracy")
    plt.gcf().tight_layout()
    # Only the last ("Average") panel gets a grid.
    axs[-1].grid(True)
    if plot_name is None:
        plt.show()
    else:
        plt.savefig(plot_name)