def find_combinations_within_global_batch_size_range()

in bench_cluster/create_configs.py [0:0]


def find_combinations_within_global_batch_size_range(dp, seq_len, min_global_batch_size, max_global_batch_size, step, bapr_max=None):
    """Return the ``(bapr, mbs)`` pairs whose global batch size falls in a range.

    For every target ``i`` in
    ``range(min_global_batch_size, max_global_batch_size + 1, step)`` the
    remaining budget ``i // (dp * seq_len)`` is split into every ordered factor
    pair ``(bapr, mbs)`` (``bapr`` ascending). Each factor other than 1 is
    rounded up to the next multiple of 2, so the resulting product may
    overshoot the target. A pair is kept when
    ``dp * seq_len * bapr * mbs`` lies in
    ``[min_global_batch_size, max_global_batch_size + step]``.

    Args:
        dp: Multiplier of the global batch size (presumably the data-parallel
            degree -- confirm against caller).
        seq_len: Second multiplier of the global batch size (presumably the
            sequence length, i.e. the batch size is counted in tokens).
        min_global_batch_size: Inclusive lower bound of the range, and the
            first target.
        max_global_batch_size: Last possible target; the accepted upper bound
            is ``max_global_batch_size + step``.
        step: Stride between consecutive targets.
        bapr_max: Optional upper bound on ``bapr`` (checked after rounding);
            ``None`` disables the bound.

    Returns:
        A list of unique ``(bapr, mbs)`` tuples in first-seen order.

    Raises:
        ZeroDivisionError: if ``dp * seq_len == 0`` and the target range is
            non-empty.
        ValueError: if ``step == 0`` (raised by ``range``).
    """

    def round_up_unless_one(value, multiple=2):
        # 1 is kept as-is; any other factor is rounded up to the next
        # multiple using exact integer ceiling division (no float rounding).
        if value == 1:
            return value
        return -(-value // multiple) * multiple

    scale = dp * seq_len
    # Allow one extra `step` above the max so pairs whose rounded-up
    # product slightly overshoots the largest target are still included.
    upper_bound = max_global_batch_size + step

    combinations = []
    seen = set()

    for target in range(min_global_batch_size, max_global_batch_size + 1, step):
        remaining = target // scale
        # Enumerate ordered factor pairs (a, remaining // a) with `a`
        # ascending: O(remaining) instead of scanning remaining**2 pairs.
        for a in range(1, remaining + 1):
            if remaining % a:
                continue
            bapr = round_up_unless_one(a)
            mbs = round_up_unless_one(remaining // a)
            if bapr_max is not None and bapr > bapr_max:
                continue
            current_global_batch_size = scale * bapr * mbs
            if not (min_global_batch_size <= current_global_batch_size <= upper_bound):
                continue
            # Different targets / factor pairs can collapse to the same
            # rounded pair; report each pair only once.
            pair = (bapr, mbs)
            if pair not in seen:
                seen.add(pair)
                combinations.append(pair)

    return combinations