in experiments/sample_datasets.py [0:0]
def _matched_severity_sets(err_table, idxs, target, precision):
    '''
    Returns every severity assignment for the corruptions in 'idxs'
    (len(idxs) >= 2) whose summed error is within 'precision' (relative)
    of 'target'.

    Severities for all but the last two corruptions are enumerated
    exhaustively. The last two are searched from the outside in (lowest
    severity set of the second-to-last, highest of the last). This relies on
    each row of 'err_table' being ordered by increasing error, and typically
    saves one factor of num_severity_sets compared with brute force.

    Output:
      A list of candidate datasets, each a list of len(idxs) 2-tuples
      (corruption_index, severity_set_index), in discovery order.
    '''
    num_sev_sets = err_table.shape[1]
    num_fixed = len(idxs) - 2
    all_augs = []
    for sev_idxs in product(range(num_sev_sets), repeat=num_fixed):
        # zip stops after the num_fixed leading corruptions.
        prefix = list(zip(idxs, sev_idxs))
        err_sum = sum((err_table[c, s] for c, s in prefix), 0.0)
        stack = [(0, num_sev_sets - 1)]  # Start on the two ends
        seen = set()
        while stack:
            i, j = stack.pop()
            if (i, j) in seen or i >= num_sev_sets or j < 0:
                continue
            seen.add((i, j))
            final_err_sum = (err_sum + err_table[idxs[-2], i]
                             + err_table[idxs[-1], j])
            rel_gap = (target - final_err_sum) / target
            if abs(rel_gap) < precision:
                all_augs.append(prefix + [(idxs[-2], i), (idxs[-1], j)])
                # Neighbours in both directions may also be in the window.
                stack.append((i + 1, j))
                stack.append((i, j - 1))
            elif rel_gap >= precision:
                # Sum too low: raise the second-to-last corruption's error.
                stack.append((i + 1, j))
            else:
                # Sum too high: lower the last corruption's error.
                stack.append((i, j - 1))
    return all_augs


def sample_matched_corruptions(err_table, baseline_err, precision, num):
    '''
    Iterates over all 'num'-sized combinations of corruptions and selects
    a set of severities whose summed error is within 'precision' (relative)
    of 'num' times the baseline error, i.e. whose average error matches the
    baseline error. If multiple sets of severities fall within the precision
    window, it picks one at random (via np.random). If none do, it skips that
    combination of corruptions.

    For num >= 2, each row of 'err_table' is assumed to be ordered by
    increasing error. The search over the last two corruptions relies on
    this and may miss matches if it does not hold.

    The runtime is O((num_corruptions * num_severity_sets)^num), though
    in practice the algorithm below is usually
    O(num_corruptions^num * num_severity_sets^(num-1)).

    Inputs:
      err_table: 2d numpy array of shape [corruptions, severity_sets]
        listing the average error for each set.
      baseline_err: float giving the target error to match. Must be nonzero,
        since the acceptance window is relative to it.
      precision: float giving the allowed relative deviation from the
        target (e.g. 0.05 for 5%).
      num: int >= 1 listing the number of corruptions to combine.
    Output:
      A list of sampled datasets, where each sampled dataset is a list
      of 'num' 2-tuples (corruption_index, severity_set_index).
    Raises:
      ValueError: if num < 1.
    '''
    if num < 1:
        raise ValueError("num must be at least 1, got {}".format(num))
    num_corruptions, num_sev_sets = err_table.shape
    # The window is checked on the summed error, so scale the baseline once.
    target = baseline_err * num
    count = 0
    total = comb(num_corruptions, num, exact=True)
    chosen_augs = []
    for idxs in combinations(range(num_corruptions), num):
        count += 1
        if count % 1000 == 0:
            print("On iteration {}/{}".format(count, total))
        if num == 1:
            # A lone corruption has no partner for the two-ended search;
            # scan its severity sets directly.
            all_augs = [
                [(idxs[0], sev)]
                for sev in range(num_sev_sets)
                if abs((target - err_table[idxs[0], sev]) / target) < precision
            ]
        else:
            all_augs = _matched_severity_sets(err_table, idxs, target,
                                              precision)
        if all_augs:
            idx_choice = np.random.randint(low=0, high=len(all_augs))
            chosen_augs.append(all_augs[idx_choice])
    return chosen_augs