def sample_matched_corruptions()

in experiments/sample_datasets.py [0:0]


def sample_matched_corruptions(err_table, baseline_err, precision, num):
    '''
    Iterates over all 'num'-sized combinations of corruptions and selects
    a set of severities whose mean error is within a relative 'precision'
    of the baseline error (equivalently, whose summed error is within
    'precision' of 'num * baseline_err').  If multiple sets of severities
    fall within the precision window, it picks one at random using
    np.random.  If none do, it skips that combination of corruptions.

    The runtime is O((num_corruptions * num_severity_sets)^num), though
    in practice the algorithm below is usually
    O(num_corruptions^num * num_severity_sets^(num-1)).

    Inputs:
    err_table: 2d numpy array of shape [corruptions, severity_sets]
      listing the average error for each set.  Each row is assumed to be
      ordered by ascending average error; the two-pointer search over the
      last two corruptions relies on this.
    baseline_err: float giving the target error to match.  The window is
      relative to it, so it must be nonzero.
    precision: float giving the allowed relative deviation from the
      baseline error, as a fraction (e.g. 0.05 for 5%).
    num: int listing the number of corruptions to combine (must be >= 1).

    Output:
    A list of sampled datasets, where each sampled dataset is a list
      of 'num' 2-tuples (corruption_index, severity_set_index).

    Raises:
    ValueError: if num < 1.
    '''
    if num < 1:
        raise ValueError("num must be at least 1, got {}".format(num))
    num_corruptions = err_table.shape[0]
    num_sevs = err_table.shape[1]
    # The target summed error is the same for every combination.
    target = baseline_err * num
    count = 0
    total = comb(num_corruptions, num, exact=True)
    chosen_augs = []
    for idxs in combinations(range(num_corruptions), num):
        count += 1
        if count % 1000 == 0:
            print("On iteration {}/{}".format(count, total))
        if num == 1:
            # A single corruption: just scan its severity sets directly.
            all_augs = [
                [(idxs[0], s)]
                for s in range(num_sevs)
                if abs((target - err_table[idxs[0], s]) / target) < precision
            ]
        else:
            all_augs = _matched_severity_sets(err_table, idxs, target, precision)
        if all_augs:
            idx_choice = np.random.randint(low=0, high=len(all_augs))
            chosen_augs.append(all_augs[idx_choice])
    return chosen_augs


def _matched_severity_sets(err_table, idxs, target, precision):
    '''
    Return every severity assignment for corruptions 'idxs' (len >= 2)
    whose summed error is within relative 'precision' of 'target'.

    Each assignment is a list of (corruption_index, severity_set_index)
    tuples in the same order as 'idxs'.
    '''
    num_sevs = err_table.shape[1]
    row_a = err_table[idxs[-2]]
    row_b = err_table[idxs[-1]]
    matches = []
    # Brute-force the severities of all chosen corruptions except the last
    # two.  Since the severity sets are ordered by average error, the last
    # two can be searched from the outside in, which typically saves one
    # factor of 'num_severity_sets' in calculation time.
    for sev_idxs in product(range(num_sevs), repeat=len(idxs) - 2):
        prefix = list(zip(idxs[:-2], sev_idxs))
        err_sum = sum((err_table[c, s] for c, s in prefix), 0.0)
        for i, j in _matching_pairs(row_a, row_b, err_sum, target,
                                    precision, num_sevs):
            matches.append(prefix + [(idxs[-2], i), (idxs[-1], j)])
    return matches


def _matching_pairs(row_a, row_b, base_sum, target, precision, num_sevs):
    '''
    Return all (i, j) such that base_sum + row_a[i] + row_b[j] is within
    relative 'precision' of 'target', in discovery order.

    Assumes row_a and row_b are sorted ascending.  Starts at the two ends
    (smallest of row_a, largest of row_b).  A sum that is too small can
    only be fixed by increasing i; a sum that is too large can only be
    fixed by decreasing j.  On a match, both neighbours are explored,
    because further matches may lie in either direction.
    '''
    pairs = []
    stack = [(0, num_sevs - 1)]  # Start on the two ends
    seen = set()
    while stack:
        i, j = stack.pop()
        if (i, j) in seen or i >= num_sevs or j < 0:
            continue
        seen.add((i, j))
        rel_gap = (target - (base_sum + row_a[i] + row_b[j])) / target
        if abs(rel_gap) < precision:
            pairs.append((i, j))
            stack.append((i + 1, j))
            stack.append((i, j - 1))
        elif rel_gap >= precision:
            # Sum too small: every j' <= j is even smaller, so advance i.
            stack.append((i + 1, j))
        else:
            # Sum too large: every i' >= i is even larger, so retreat j.
            stack.append((i, j - 1))
    return pairs