def filter_sketch_sizes()

in experiments/regularization_on_momentum.py [0:0]


def filter_sketch_sizes(sketch_sizes, sketch_size_formulas):
    """Collapse duplicate sketch sizes and merge their formulas.

    Sketch sizes that compare equal are reduced to a single entry. The
    surviving entry sits at the position of the *last* occurrence of that
    size, and its formula is the formulas of all occurrences, in their
    original order, joined with ``" = "``.

    The inputs are never modified.

    Args:
        sketch_sizes: List of 1, 2 or 3 sketch sizes (hashable values that
            support ``==``, e.g. numbers).
        sketch_size_formulas: List of formula strings; entry ``i`` belongs
            to ``sketch_sizes[i]``.

    Returns:
        A tuple ``(filtered_sketch_sizes, filtered_sketch_size_formulas)``
        with the duplicates collapsed. When there are no duplicates these
        are plain copies of the inputs.

    Raises:
        ValueError: If ``sketch_sizes`` does not hold 1, 2 or 3 entries.
    """
    if len(sketch_sizes) not in (1, 2, 3):
        raise ValueError("Works only with 1, 2 or 3 sketch sizes")

    # Map every distinct size to the positions where it occurs, in order.
    positions_by_size = {}
    for position, size in enumerate(sketch_sizes):
        positions_by_size.setdefault(size, []).append(position)

    if len(positions_by_size) == len(sketch_sizes):
        # Nothing to merge: hand back independent copies of the inputs.
        return sketch_sizes.copy(), sketch_size_formulas.copy()

    # Each group of equal sizes is represented by its last occurrence, so
    # ordering the groups by that position reproduces the output order.
    groups = sorted(
        positions_by_size.values(), key=lambda positions: positions[-1]
    )
    filtered_sketch_sizes = [sketch_sizes[positions[-1]] for positions in groups]
    filtered_sketch_size_formulas = [
        " = ".join(sketch_size_formulas[position] for position in positions)
        for positions in groups
    ]
    return filtered_sketch_sizes, filtered_sketch_size_formulas