def main()

in sample_info/scripts/aggregate_ground_truth_results.py [0:0]


def main():
    """Aggregate ground-truth, proposed and influence-function results and plot them.

    Parses command-line arguments, fills a nested ``results[method][key]``
    store via ``read_ground_truth``, ``read_proposed`` and
    ``read_influence_functions`` (defined elsewhere in this file), then for
    each of ``'weights_diff'`` and ``'pred_diff'``:

    * computes ``torch.sum(x ** 2)`` (a squared L2 norm) for every vector whose
      position is selected by the mask returned from ``read_ground_truth``;
    * saves a ground-truth-vs-proposed scatter plot to
      ``<root_dir>/aggregated/<exp_name>/<key>-norm-scatter.pdf``;
    * prints the correlation matrix of ground-truth vs. proposed, and of
      ground-truth vs. influence functions when those are available.

    Raises:
        ValueError: if no selected vectors exist for ``'ground-truth'`` or
            ``'proposed'`` under a key (the plot and correlations would be
            meaningless).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', '-E', type=str, required=True)
    parser.add_argument('--root_dir', '-r', type=str,
                        default='sample_info/results/ground-truth/')
    parser.add_argument('--num_examples', '-n', type=int, required=True)
    args = parser.parse_args()
    print(args)

    # storage for all methods: results[method][key] -> list
    # (presumably one tensor per example; filled by the read_* helpers)
    results = defaultdict(lambda: defaultdict(list))

    # read ground truths; the returned mask is indexed per example below to
    # select which examples are compared
    mask = read_ground_truth(args, results)

    # read proposed
    read_proposed(args, results)

    # read influence functions; the return value is only used for its
    # truthiness when deciding whether to report their correlations
    influence_functions = read_influence_functions(args, results)

    keys = ['weights_diff', 'pred_diff']
    for key in keys:
        # method name -> numpy array of squared L2 norms of the selected vectors
        norms = dict()

        for method, D in results.items():
            # .get avoids silently inserting an empty list into the defaultdict
            vectors = D.get(key, [])
            cur_norms = [torch.sum(x ** 2) for idx, x in enumerate(vectors) if mask[idx]]
            if not cur_norms:
                # torch.stack rejects an empty list; a method with nothing
                # selected for this key simply has nothing to compare
                continue
            norms[method] = utils.to_numpy(torch.stack(cur_norms).flatten())

        missing = [m for m in ('ground-truth', 'proposed') if m not in norms]
        if missing:
            raise ValueError(f"no selected '{key}' vectors for method(s): {missing}")

        fig, ax = plt.subplots()
        # both axes share the ground-truth range so the diagonal is y = x;
        # proposed values outside this range are clipped from view
        vmin = np.min(norms['ground-truth'])
        vmax = np.max(norms['ground-truth'])
        ax.set_title(key)
        ax.scatter(norms['ground-truth'], norms['proposed'], label='gt-vs-proposed', s=5)
        ax.set_xlim(left=vmin, right=vmax)
        ax.set_ylim(bottom=vmin, top=vmax)
        # ax.scatter(norms['ground-truth'], norms['influence-functions'], label='gt-vs-influence', s=5)
        ax.set_xlabel('ground_truth')
        ax.set_ylabel('proposed')
        ax.legend()
        fig.tight_layout()
        save_path = os.path.join(args.root_dir, 'aggregated', args.exp_name, f'{key}-norm-scatter.pdf')
        savefig(fig, save_path)
        # pyplot keeps every figure alive until explicitly closed
        plt.close(fig)

        # header so the per-key correlation printouts can be told apart
        print(f"=== {key} ===")
        print("Correlations of proposed:")
        print(np.corrcoef(norms['ground-truth'], norms['proposed']))

        if influence_functions and 'influence-functions' in norms:
            print("Correlations of influence functions:")
            print(np.corrcoef(norms['ground-truth'], norms['influence-functions']))