# experiments/sample_datasets.py
def main():
    """Sample corruption combinations matched to a target error and write them out.

    Reads CLI arguments from the module-level ``parser``, loads corruption and
    baseline data, builds error/feature/distance tables, samples corruption
    combinations whose mean error matches ``args.target`` (within
    ``args.precision``), and writes the result to ``args.out``:
    a header line with the run configuration, then one line per sampled
    combination of the form ``aug_string,mean_err,dist_0,...,dist_D``.

    Side effects: prints a progress message and writes/overwrites ``args.out``.
    """
    args = parser.parse_args()
    data_dir = args.data_dir
    baseline_dir = args.baseline_dir
    precision = args.precision
    num_corr = args.num
    out_file = args.out
    log_name = args.log_name
    target_error = args.target

    # Corruptions held out of the baseline feature set; the candidate
    # corruption set uses no exclusions.
    baseline_exclusions = ['saturate', 'spatter', 'gaussian_blur', 'speckle_noise']
    corr_exclusions = []

    print("Loading data...")
    # Both dir arguments are comma-separated lists of directories.
    data_dirs = data_dir.split(",")
    baseline_dirs = baseline_dir.split(",")
    corr_errs, corr_features = get_data(data_dirs, corr_exclusions, log_file=log_name)
    baseline_errs, baseline_features = get_data(baseline_dirs, exclusions=baseline_exclusions, log_file=log_name)

    baseline_table = build_baseline_features(baseline_features)
    avg_spread = get_average_spread(baseline_errs)
    corr_sets = build_sets(corr_errs, avg_spread)
    corrs, ordered_sev_list, err_table, feat_table = build_corr_tables(corr_sets, corr_errs, corr_features)
    dists = build_distance_table(baseline_table, feat_table)

    # Each element of `chosen` is a list of (corruption index, severity index)
    # pairs whose mean error should match target_error.
    chosen = sample_matched_corruptions(err_table, target_error, precision, num_corr)

    out = []
    for aug_list in chosen:
        sub_aug_strings = []
        err = 0.0
        dist_rows = []  # one (1, D) distance row per chosen corruption
        for a in aug_list:
            corr = corrs[a[0]]
            sevs = ordered_sev_list[corr][a[1]]
            sub_aug_strings.append("--".join(["{}-{}".format(corr, s) for s in sevs]))
            err += err_table[a[0], a[1]]
            dist_rows.append(dists[a[0], a[1]].reshape(1, -1))
        err /= len(aug_list)
        # Concatenate once after the loop instead of growing the array each
        # iteration, which would re-copy the accumulated rows (O(n^2)).
        avg_dists = np.mean(np.concatenate(dist_rows, axis=0), axis=0)
        aug_string = "--".join(sub_aug_strings)
        data_out = ",".join([aug_string, str(err)] + [str(x) for x in avg_dists])
        out.append(data_out)

    # Header line records the run configuration (",,"-separated), followed by
    # one line per sampled combination.
    with open(out_file, 'w') as f:
        f.write(",,".join([data_dir, baseline_dir, str(precision), str(num_corr)]))
        f.write("\n")
        f.write("\n".join(out))