in sample_info/scripts/aggregate_ground_truth_results.py [0:0]
def _masked_squared_norms(vectors, mask):
    """Return a 1-D numpy array with the squared L2 norm of every vector kept by ``mask``.

    ``vectors[i]`` is kept iff ``mask[i]`` is truthy. Note: this is the sum of
    squares (squared norm), not the norm itself.
    """
    kept = [torch.sum(x ** 2) for idx, x in enumerate(vectors) if mask[idx]]
    return utils.to_numpy(torch.stack(kept).flatten())


def main():
    """Compare ground-truth results against the proposed method (and influence functions).

    For each of ``weights_diff`` and ``pred_diff``:
      * compute the squared L2 norm of every stored vector, restricted to the
        entries selected by the mask returned from ``read_ground_truth``;
      * save a ground-truth-vs-proposed scatter plot to
        ``<root_dir>/aggregated/<exp_name>/<key>-norm-scatter.pdf``;
      * print the correlation matrices of the proposed method (and, when
        influence functions were read, of influence functions) against ground truth.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    parser.add_argument('--root_dir', '-r', type=str,
                        default='sample_info/results/ground-truth/')
    parser.add_argument('--num_examples', '-n', type=int, required=True)
    args = parser.parse_args()
    print(args)

    # results[method][key] -> list of vectors (presumably one per example),
    # filled in by the reader helpers below.
    results = defaultdict(lambda: defaultdict(list))
    # The ground-truth reader also returns the mask used to select entries below.
    mask = read_ground_truth(args, results)
    read_proposed(args, results)
    influence_functions = read_influence_functions(args, results)

    for key in ['weights_diff', 'pred_diff']:
        # method name -> numpy array of squared norms of the masked vectors
        norms = {method: _masked_squared_norms(D[key], mask)
                 for method, D in results.items()}
        gt = norms['ground-truth']

        fig, ax = plt.subplots()
        ax.set_title(key)
        ax.scatter(gt, norms['proposed'], label='gt-vs-proposed', s=5)
        # Both axes use the ground-truth range; proposed values outside it are
        # clipped from view.
        vmin, vmax = np.min(gt), np.max(gt)
        ax.set_xlim(left=vmin, right=vmax)
        ax.set_ylim(bottom=vmin, top=vmax)
        # Influence-function norms are only reported as correlations below, not plotted.
        ax.set_xlabel('ground_truth')
        ax.set_ylabel('proposed')
        ax.legend()
        fig.tight_layout()
        save_path = os.path.join(args.root_dir, 'aggregated', args.exp_name,
                                 f'{key}-norm-scatter.pdf')
        savefig(fig, save_path)
        # Release the figure so pyplot does not keep one open per loop iteration.
        plt.close(fig)

        # Label the printed matrices, since this runs once per key.
        print(f"[{key}]")
        print("Correlations of proposed:")
        print(np.corrcoef(gt, norms['proposed']))
        if influence_functions:
            print("Correlations of influence functions:")
            print(np.corrcoef(gt, norms['influence-functions']))