in domainbed_measures/extract_generalization_features.py [0:0]
def _write_sweep_features_csv(run_dir, datasets, canonicalize,
                              experiment_path_list, all_results,
                              feature_names, wd_or_ood):
    """Build and write one ``sweeps_*_<wd_or_ood>.csv`` table.

    Prints the feature names used, calls ``experiment_per_dataset`` for each
    experiment path (in order), writes the header returned for the first
    path followed by the rows of every path.

    Args:
        run_dir: Directory the CSV is written into.
        datasets: The original comma-separated dataset string; it is embedded
            verbatim in the output filename.
        canonicalize: Value embedded (via ``%s``) in the output filename.
        experiment_path_list: Experiment paths, in output order.
        all_results: Mapping experiment path -> results from ``load_results``.
        feature_names: Tuple of feature names passed to
            ``experiment_per_dataset``.
        wd_or_ood: ``'ood'`` or ``'wd'``; forwarded to
            ``experiment_per_dataset`` and used as the filename suffix.
    """
    print("Found the following common features across all datasets for %s"
          % wd_or_ood.upper())
    for feature_name in feature_names:
        print(feature_name)

    all_headers = []
    all_rows = []
    for experiment_path in experiment_path_list:
        header, rows = experiment_per_dataset(all_results[experiment_path],
                                              feature_names,
                                              experiment_path,
                                              wd_or_ood=wd_or_ood)
        all_rows.extend(rows)
        all_headers.append(header)

    # Format the file name on its own before joining, so a '%' inside
    # run_dir can never be interpreted as a format directive.
    filename = 'sweeps_%s_canon_%s_%s.csv' % (datasets, canonicalize,
                                             wd_or_ood)
    # newline='' is required by the csv module; otherwise the writer's
    # '\r\n' terminator becomes '\r\r\n' on Windows (blank lines).
    with open(os.path.join(run_dir, filename), 'w', newline='') as f:
        writer = csv.writer(f)
        # Only the first dataset's header is written; rows from all
        # datasets are assumed to share that column layout.
        writer.writerow(all_headers[0])
        writer.writerows(all_rows)


def analyze_results(run_dir, datasets, *, canonicalize=None):
    """Extract generalization features for several sweeps into two CSVs.

    For every comma-separated entry of ``datasets``, results and feature
    names are loaded from ``run_dir`` via ``load_results``. Only features
    present in every dataset are kept; when ``canonicalize == True`` they are
    further restricted to the keys of ``CANONICALIZATION``. Two files are
    then written into ``run_dir``:

    * ``sweeps_<datasets>_canon_<canonicalize>_ood.csv``: rows built with
      ``wd_or_ood='ood'`` from the common features plus the extra
      ``'wd_out_domain_err'`` feature.
    * ``sweeps_<datasets>_canon_<canonicalize>_wd.csv``: rows built with
      ``wd_or_ood='wd'`` from the common features only.

    Args:
        run_dir: Directory results are loaded from and CSVs are written to.
        datasets: Comma-separated experiment paths.
        canonicalize: Restrict features to ``CANONICALIZATION`` keys when
            equal to ``True``. ``None`` (default) falls back to the
            module-level ``args.canonicalize`` for backward compatibility.

    Raises:
        ValueError: If no common feature set could be computed.
    """
    if canonicalize is None:
        canonicalize = args.canonicalize

    experiment_path_list = datasets.split(",")
    print("Running extraction on following datasets:")
    print(experiment_path_list)

    common_feature_names = None
    all_results = {}
    for experiment_path in experiment_path_list:
        results, feature_names = load_results(run_dir, experiment_path)
        all_results[experiment_path] = results
        if common_feature_names is None:
            common_feature_names = set(feature_names)
        else:
            common_feature_names &= set(feature_names)
    if common_feature_names is None:
        raise ValueError("No datasets to analyze in %r" % (datasets,))

    # Equality test (not truthiness) kept so non-bool values behave as before.
    if canonicalize == True:  # noqa: E712
        common_feature_names &= set(CANONICALIZATION.keys())

    # The OOD table additionally gets 'wd_out_domain_err' (described in the
    # original code as the within-domain validation metric) as a feature.
    ood_feature_names = tuple(
        sorted(list(common_feature_names) + ['wd_out_domain_err']))
    # Identical to removing one 'wd_out_domain_err' from the OOD tuple.
    wd_feature_names = tuple(sorted(common_feature_names))

    _write_sweep_features_csv(run_dir, datasets, canonicalize,
                              experiment_path_list, all_results,
                              ood_feature_names, 'ood')
    _write_sweep_features_csv(run_dir, datasets, canonicalize,
                              experiment_path_list, all_results,
                              wd_feature_names, 'wd')