def analyze_results()

in domainbed_measures/extract_generalization_features.py [0:0]


def analyze_results(run_dir, datasets):
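    """Extract generalization measures for each dataset sweep and write CSV summaries.

    Args:
        run_dir: Directory containing the sweep results to load.
        datasets: Comma-separated list of dataset sweep names to process.
    """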
    experiment_path_list = datasets.split(",")
    print("Running extraction on following datasets:")
    print(experiment_path_list)

    common_feature_names = None
    all_results = {}
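    # Load each sweep's results and keep only the features present in every dataset.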
    for experiment_path in experiment_path_list:
        results, feature_names = load_results(run_dir, experiment_path)
        all_results[experiment_path] = results

        if common_feature_names is None:
            common_feature_names = set(feature_names)
        else:
            common_feature_names = common_feature_names.intersection(
                set(feature_names))

    if common_feature_names is None:
        raise ValueError("No datasets were provided, so no common feature set "
                         "could be computed.")
    # Optionally keep only features that have a canonicalization defined.
    if args.canonicalize:
        common_feature_names = common_feature_names.intersection(
            set(CANONICALIZATION.keys()))
    common_feature_names = list(common_feature_names)
    # Add 'wd_out_domain_err' (out-of-domain error under within-domain model
    # selection) as an additional feature.
    common_feature_names.append('wd_out_domain_err')
    common_feature_names = tuple(sorted(common_feature_names))

    print("Found the following common features across all datasets for OOD")
    for f in common_feature_names:
        print(f)

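    # Build one header/row set per dataset and write all rows into a single OOD CSV.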
    all_headers = []
    all_rows = []

    for experiment_path in experiment_path_list:
        header, rows = experiment_per_dataset(all_results[experiment_path],
                                              common_feature_names,
                                              experiment_path,
                                              wd_or_ood='ood')
        all_rows.extend(rows)
        all_headers.append(header)

    with open(
            os.path.join(
                args.run_dir,
                'sweeps_%s_canon_%s_ood.csv' % (datasets, args.canonicalize)),
            'w') as f:
        writer = csv.writer(f)
        writer.writerow(all_headers[0])
        for r in all_rows:
            writer.writerow(r)

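    # For the within-domain (WD) table, drop the OOD error column added above.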
    common_feature_names = list(common_feature_names)
    common_feature_names.remove('wd_out_domain_err')
    common_feature_names = tuple(common_feature_names)
    print("Found the following common features across all datasets for WD")
    for f in common_feature_names:
        print(f)

    all_headers = []
    all_rows = []

    for experiment_path in experiment_path_list:
        header, rows = experiment_per_dataset(all_results[experiment_path],
                                              common_feature_names,
                                              experiment_path,
                                              wd_or_ood='wd')
        all_rows.extend(rows)
        all_headers.append(header)

    with open(
            os.path.join(
                args.run_dir,
                'sweeps_%s_canon_%s_wd.csv' % (datasets, args.canonicalize)),
            'w') as f:
        writer = csv.writer(f)
        writer.writerow(all_headers[0])
        for r in all_rows:
            writer.writerow(r)
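
A minimal sketch of how this function might be invoked from the script's entry point. The flag names below mirror the attributes the function reads (args.run_dir, args.canonicalize) but are assumptions; the actual argparse setup lives elsewhere in the file and is not shown in this excerpt.

# Hypothetical entry point; flag names and defaults are assumptions, not taken from the source.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_dir', type=str, required=True,
                        help='Directory containing sweep results.')
    parser.add_argument('--datasets', type=str, required=True,
                        help='Comma-separated dataset sweep names.')
    parser.add_argument('--canonicalize', action='store_true',
                        help='Restrict features to those with a canonicalization defined.')
    args = parser.parse_args()

    # analyze_results also reads args.run_dir and args.canonicalize as module-level
    # globals, so args must be defined before the call, as it is in the original script.
    analyze_results(args.run_dir, args.datasets)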