def main()

in experiments/sample_datasets.py [0:0]


def main():
    """Sample corruption combinations around a target error and save them.

    Settings come from the module-level ``parser``. ``data_dir`` and
    ``baseline_dir`` are split on commas before being handed to ``get_data``.

    The output file starts with one line holding the raw ``data_dir``,
    ``baseline_dir``, ``precision`` and ``num`` values joined by ``",,"``.
    Each following line describes one sampled combination: its name string,
    the mean of its ``err_table`` entries, and the column-wise mean of its
    ``dists`` entries, comma-separated.
    """
    args = parser.parse_args()
    data_dir, baseline_dir = args.data_dir, args.baseline_dir
    precision, num_corr = args.precision, args.num
    out_file, log_name = args.out, args.log_name
    target_error = args.target

    baseline_exclusions = ['saturate', 'spatter', 'gaussian_blur', 'speckle_noise']
    corr_exclusions = []

    print("Loading data...")
    data_dirs = data_dir.split(",")
    baseline_dirs = baseline_dir.split(",")
    corr_errs, corr_features = get_data(
        data_dirs, corr_exclusions, log_file=log_name
    )
    baseline_errs, baseline_features = get_data(
        baseline_dirs, exclusions=baseline_exclusions, log_file=log_name
    )

    baseline_table = build_baseline_features(baseline_features)
    avg_spread = get_average_spread(baseline_errs)

    corr_sets = build_sets(corr_errs, avg_spread)
    corrs, ordered_sev_list, err_table, feat_table = build_corr_tables(
        corr_sets, corr_errs, corr_features
    )
    dists = build_distance_table(baseline_table, feat_table)

    chosen = sample_matched_corruptions(err_table, target_error, precision, num_corr)

    rows = []
    for aug_list in chosen:
        name_parts = []
        dist_rows = []
        err = 0.0
        for entry in aug_list:
            # Each entry indexes the tables as (corruption index, severity-set index).
            corr_idx, sev_idx = entry[0], entry[1]
            corr = corrs[corr_idx]
            sevs = ordered_sev_list[corr][sev_idx]
            name_parts.append("--".join("{}-{}".format(corr, s) for s in sevs))
            err += err_table[corr_idx, sev_idx]
            dist_rows.append(dists[corr_idx, sev_idx].reshape(1, -1))
        err /= len(aug_list)

        # One row per entry; average down the rows to get a single vector.
        avg_dists = np.mean(np.concatenate(dist_rows, axis=0), axis=0)

        fields = ["--".join(name_parts), str(err)]
        fields.extend(str(x) for x in avg_dists)
        rows.append(",".join(fields))

    header = ",,".join([data_dir, baseline_dir, str(precision), str(num_corr)])
    with open(out_file, 'w') as f:
        f.write(header + "\n" + "\n".join(rows))