def main()

in sample_info/scripts/aggregate_data_summarization_results.py [0:0]


def main():
    """Aggregate data-summarization results and plot test accuracy per method.

    Command-line arguments:
        --exp_name / -E: experiment name; the plot is written to
            ``<root_dir>/aggregated/<exp_name>/plot.pdf``.
        --root_dir / -r: root directory of the data-summarization results.
        --baselines / -b: one or more method names to include. Recognized
            names are ``'random'`` plus those in ``proposed_baselines`` below;
            any other name is reported on stderr and skipped.
        --num_examples / -n: not used directly here; forwarded to the result
            readers through ``args``.

    Each selected method is loaded into ``results`` by ``read_random`` /
    ``read_proposed``. The plotting code below expects every entry to be a
    dict with ``'xs'`` and ``'ys'`` keys. The random baseline is drawn as the
    mean of each ``ys`` element with a +/- one-standard-deviation band. Every
    other method's ``ys`` is plotted as-is; presumably one value per x, so
    verify this against ``read_proposed``.
    """
    import sys  # local import: only needed for the unknown-baseline warning

    # Proposed methods, read (and hence plotted) in this fixed order regardless
    # of the order they were given on the command line.
    proposed_baselines = (
        'predictions-top',
        'predictions-bottom',
        'weights-plain-top',
        'weights-plain-bottom',
        'predictions-iterative',
        'weights-plain-iterative',
    )

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    parser.add_argument('--root_dir', '-r', type=str,
                        default='sample_info/results/data-summarization/')
    parser.add_argument('--baselines', '-b', type=str, nargs='+', required=True,
                        help='methods to plot; recognized: '
                             + ', '.join(('random',) + proposed_baselines))
    parser.add_argument('--num_examples', '-n', type=int, required=True)
    args = parser.parse_args()
    print(args)

    # Unrecognized names used to be dropped silently, so a typo simply made a
    # curve disappear from the plot. Surface it, but keep going as before.
    known_baselines = {'random', *proposed_baselines}
    unknown = [name for name in args.baselines if name not in known_baselines]
    if unknown:
        print(f'Warning: ignoring unrecognized baselines: {unknown}', file=sys.stderr)

    # storage for all methods, filled in place by the readers
    results = defaultdict(dict)

    # random baseline
    if 'random' in args.baselines:
        read_random(results, args)

    # proposed methods
    for baseline_name in proposed_baselines:
        if baseline_name in args.baselines:
            read_proposed(results, args, baseline_name)

    # plot
    fig, ax = plt.subplots(figsize=(7, 5))

    # The random baseline has several runs per x: plot mean +/- std. It is
    # popped so the generic loop below does not plot it a second time.
    if 'random' in results:
        cur_results = results.pop('random')
        means = np.array([np.mean(y) for y in cur_results['ys']])
        stds = np.array([np.std(y) for y in cur_results['ys']])
        ax.plot(cur_results['xs'], means, label='random')
        ax.fill_between(cur_results['xs'], means - stds, means + stds, alpha=0.2)

    # Legend labels for some methods; others are labelled by their raw name.
    rename_dict = {
        'predictions-iterative': 'bottom (iterative)',
        'predictions-top': 'top',
        'predictions-bottom': 'bottom'
    }

    for baseline_name, cur_results in results.items():
        ax.plot(cur_results['xs'], cur_results['ys'],
                label=rename_dict.get(baseline_name, baseline_name))

    ax.set_xlabel('Ratio of removed examples')
    ax.set_ylabel('Test accuracy')
    ax.legend()
    fig.tight_layout()
    save_path = os.path.join(args.root_dir, 'aggregated', args.exp_name, 'plot.pdf')
    savefig(fig, save_path)
    # Release the figure so repeated calls (e.g. from a notebook) don't leak it.
    plt.close(fig)