in sample_info/scripts/aggregate_data_summarization_results.py [0:0]
def main():
    """Aggregate data-summarization results of several baselines into one plot.

    Parses the command line, reads the results of each requested baseline
    into a shared ``results`` dict, plots test accuracy against the ratio of
    removed examples, and saves the figure to
    ``<root_dir>/aggregated/<exp_name>/plot.pdf``.

    Command-line arguments:
        --exp_name, -E: experiment name (required); selects the output folder.
        --root_dir, -r: root directory of the data-summarization results.
        --baselines, -b: one or more baseline names (required). ``'random'``
            is read with ``read_random``; the names in ``proposed_baselines``
            below are read with ``read_proposed``. Any other name is reported
            and skipped.
        --num_examples, -n: number of examples (required); not used directly
            here, but available to the readers through ``args``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    parser.add_argument('--root_dir', '-r', type=str,
                        default='sample_info/results/data-summarization/')
    parser.add_argument('--baselines', '-b', type=str, nargs='+', required=True)
    parser.add_argument('--num_examples', '-n', type=int, required=True)
    args = parser.parse_args()
    print(args)

    # Baselines handled by ``read_proposed``. They are read in this fixed
    # order (independent of the command-line order) because the insertion
    # order of ``results`` determines the order of the legend entries.
    proposed_baselines = (
        'predictions-top',
        'predictions-bottom',
        'weights-plain-top',
        'weights-plain-bottom',
        'predictions-iterative',
        'weights-plain-iterative',
    )

    # Unrecognized names would otherwise be dropped silently, leaving a curve
    # missing from the plot with no hint why; report them instead.
    known_baselines = {'random', *proposed_baselines}
    unknown_baselines = [b for b in args.baselines if b not in known_baselines]
    if unknown_baselines:
        print('Warning: ignoring unknown baselines: {}'.format(unknown_baselines))
    requested = set(args.baselines)

    # storage for all methods; each entry is expected to provide 'xs' and
    # 'ys' (both are read by the plotting code below)
    results = defaultdict(dict)
    if 'random' in requested:
        read_random(results, args)
    for baseline_name in proposed_baselines:
        if baseline_name in requested:
            read_proposed(results, args, baseline_name)

    fig, ax = plt.subplots(figsize=(7, 5))

    # The random baseline stores a collection of accuracies per x value
    # (presumably one per random draw -- confirm in read_random); plot the
    # mean with a +/- one standard deviation band.
    if 'random' in results:
        random_results = results.pop('random')
        xs = random_results['xs']
        means = np.array([np.mean(y) for y in random_results['ys']])
        stds = np.array([np.std(y) for y in random_results['ys']])
        ax.plot(xs, means, label='random')
        ax.fill_between(xs, means - stds, means + stds, alpha=0.2)

    # Legend labels for some baselines; other names are shown unchanged.
    rename_dict = {
        'predictions-iterative': 'bottom (iterative)',
        'predictions-top': 'top',
        'predictions-bottom': 'bottom',
    }
    for baseline_name, cur_results in results.items():
        ax.plot(cur_results['xs'], cur_results['ys'],
                label=rename_dict.get(baseline_name, baseline_name))
    ax.set_xlabel('Ratio of removed examples')
    ax.set_ylabel('Test accuracy')
    ax.legend()
    fig.tight_layout()

    save_path = os.path.join(args.root_dir, 'aggregated', args.exp_name, 'plot.pdf')
    savefig(fig, save_path)
    # pyplot keeps figures alive until they are closed; release it explicitly.
    plt.close(fig)