in experiments/momentum_regimes.py [0:0]
def filter_sketch_sizes(sketch_sizes, sketch_size_formulas):
    """Filter duplicate sketch sizes and merge the formulas.

    Sizes that compare equal (``==``) are collapsed into a single entry placed
    at the position of their *last* occurrence. That entry's formula is the
    formulas of all equal sizes joined in their original order with " = ".
    Unique sizes keep their formula unchanged.

    Examples:
        [a, a, b] -> [a, b],  formulas [f0 = f1, f2]
        [a, b, a] -> [b, a],  formulas [f1, f0 = f2]

    Args:
        sketch_sizes: List of sketch sizes. It is not modified.
        sketch_size_formulas: List of formula strings, one per sketch size.
            It is not modified.

    Returns:
        A tuple ``(filtered_sketch_sizes, filtered_sketch_size_formulas)`` of
        new lists.

    Raises:
        ValueError: If no sketch sizes are given, or if the two lists differ
            in length.
    """
    if len(sketch_sizes) == 0:
        raise ValueError("At least one sketch size is required")
    if len(sketch_sizes) != len(sketch_size_formulas):
        raise ValueError(
            "sketch_sizes and sketch_size_formulas must have the same length"
        )

    filtered_sketch_sizes = []
    filtered_sketch_size_formulas = []
    for position, size in enumerate(sketch_sizes):
        # A later equal size will absorb this one, so skip it here; this keeps
        # each merged entry at the position of its last occurrence.
        if any(later == size for later in sketch_sizes[position + 1:]):
            continue
        group = [
            formula
            for other, formula in zip(
                sketch_sizes[: position + 1], sketch_size_formulas
            )
            if other == size
        ]
        # Leave unique entries untouched; only merged groups are joined.
        merged_formula = group[0] if len(group) == 1 else " = ".join(group)
        filtered_sketch_sizes.append(size)
        filtered_sketch_size_formulas.append(merged_formula)
    return filtered_sketch_sizes, filtered_sketch_size_formulas