in bench_cluster/create_configs.py [0:0]
def find_combinations_within_global_batch_size_range(dp, seq_len, min_global_batch_size, max_global_batch_size, step, bapr_max=None):
    """Enumerate ``(bapr, mbs)`` pairs whose global batch size falls in a range.

    ``bapr``/``mbs`` are presumably batch-accumulation-per-replica and
    micro-batch-size (inferred from names; confirm against the caller).

    For every candidate size ``i`` in
    ``range(min_global_batch_size, max_global_batch_size + 1, step)`` the value
    ``i // (dp * seq_len)`` is factored into every ordered pair ``(a, b)`` with
    ``a * b`` equal to it. Each factor other than 1 is then rounded up to the
    next multiple of 2, so the resulting size ``dp * seq_len * bapr * mbs`` may
    overshoot ``i``. A pair is kept only when that size lies in
    ``[min_global_batch_size, max_global_batch_size + step]``.

    Args:
        dp: Multiplier of the global batch size (data-parallel degree, by name).
        seq_len: Multiplier of the global batch size (sequence length, by name).
        min_global_batch_size: Start of the scan and lower bound of accepted sizes.
        max_global_batch_size: End of the scan (inclusive). Accepted sizes may
            exceed it by at most ``step``.
        step: Stride of the scan over candidate sizes.
        bapr_max: If not None, pairs whose rounded ``bapr`` exceeds it are
            skipped. Defaults to no limit.

    Returns:
        A list of unique ``(bapr, mbs)`` tuples in first-seen order.

    Raises:
        ZeroDivisionError: if ``dp * seq_len`` is 0.
        ValueError: if ``step`` is 0 (raised by ``range``).
    """

    def round_to_next_multiple_of(multiple, pair_list):
        # Round every factor except 1 up to the next multiple of `multiple`.
        # Integer ceil-division is exact for arbitrarily large ints.
        def round_up(n):
            return -(-n // multiple) * multiple

        return [
            (a if a == 1 else round_up(a), b if b == 1 else round_up(b))
            for a, b in pair_list
        ]

    combinations = []
    seen = set()
    dp_times_seq_len = dp * seq_len
    for i in range(min_global_batch_size, max_global_batch_size + 1, step):
        remaining_global_batch_size = i // dp_times_seq_len
        # Ordered factor pairs (a, n // a) with a ascending. This is the same
        # sequence the n x n product scan would yield, but in O(n).
        all_pairs = [
            (a, remaining_global_batch_size // a)
            for a in range(1, remaining_global_batch_size + 1)
            if remaining_global_batch_size % a == 0
        ]
        all_pairs = round_to_next_multiple_of(multiple=2, pair_list=all_pairs)
        for bapr, mbs in all_pairs:
            if bapr_max is not None and bapr > bapr_max:
                continue
            current_global_batch_size = dp * seq_len * bapr * mbs
            # Also accept sizes that overshoot max_global_batch_size by up to
            # one step, since rounding can push the product above the scan point.
            if min_global_batch_size <= current_global_batch_size <= max_global_batch_size + step:
                # Nearby scan points and the rounding step repeat pairs.
                # Keep only the first occurrence.
                if (bapr, mbs) not in seen:
                    seen.add((bapr, mbs))
                    combinations.append((bapr, mbs))
    return combinations