in deep_gemm/jit_kernels/gemm.py [0:0]
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                     is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
    """Choose block sizes, stage count, TMA multicast and SM count for a GEMM.

    The selection runs in four steps:

    1. Pick ``(block_m, block_n)`` minimizing the number of waves, where a
       wave is ``num_sms`` output tiles; ties are broken by how full the last
       wave is, then by block-shape preferences.
    2. Pick the first stage count, scanning from large to small, whose shared
       memory config (from ``get_smem_config``) fits the capacity constant.
    3. Decide whether to multicast by 2 and on which operand.
    4. Recompute the smallest SM count that keeps the same number of waves.

    Args:
        m, n, k: problem dimensions.
        num_groups: multiplier applied to the tile count.
        num_sms: number of SMs available; one wave holds this many tiles.
        is_grouped_contiguous: if true, ``block_m`` is fixed to the value of
            ``get_m_alignment_for_contiguous_layout()``.
        is_grouped_masked: if true, multicast on ``'B'`` is disabled and the
            flag is forwarded to the ``'A'`` legality check.

    Returns:
        ``(num_min_sms, block_m, block_n, num_stages, (num_multicast,
        is_multicast_on_a), smem_config)``, where ``smem_config`` is whatever
        ``get_smem_config`` returned for the chosen stage count.

    Raises:
        AssertionError: if no block pair, no stage candidate or no fitting
            shared memory config is found, or if the recomputed SM count
            exceeds ``num_sms``.
    """
    if is_grouped_contiguous:
        m_candidates = (get_m_alignment_for_contiguous_layout(), )
    else:
        m_candidates = (64, 128, 256)
    n_candidates = tuple(range(16, 129, 8)) + (144, 160)

    def count_tiles(bm, bn):
        # Total number of output tiles over all groups
        return ceil_div(m, bm) * ceil_div(n, bn) * num_groups

    def count_waves(bm, bn):
        # A falsy `bm` (nothing chosen yet) has no wave count
        return ceil_div(count_tiles(bm, bn), num_sms) if bm else None

    def last_wave_util(bm, bn):
        # Tiles in the final wave; an exact multiple means the last wave is full
        leftover = count_tiles(bm, bn) % num_sms
        return leftover if leftover != 0 else num_sms

    # Step 1: decide block sizes by waves
    best_block_m, best_block_n = None, None
    for cand_m in m_candidates:
        for cand_n in n_candidates:
            # NOTES: skip pairs where both block dims exceed 128
            if cand_m > 128 and cand_n > 128:
                continue

            cand_waves = count_waves(cand_m, cand_n)
            best_waves = count_waves(best_block_m, best_block_n)
            if best_block_m is None or best_block_n is None:
                take = True
            elif cand_waves < best_waves:
                take = True
            elif cand_waves == best_waves:
                # Same wave count: compare how well the last wave is filled
                cand_util = last_wave_util(cand_m, cand_n)
                best_util = last_wave_util(best_block_m, best_block_n)
                if cand_util > best_util:
                    take = True
                elif cand_util == best_util:
                    # Same `block_m`: the smaller `block_n` wastes less
                    narrower_n = cand_m == best_block_m and cand_n < best_block_n
                    # Same `block_n`: the smaller `block_m` wastes less
                    narrower_m = cand_n == best_block_n and cand_m < best_block_m
                    # Different `block_m`: a larger `block_n` is preferred
                    wider_n = cand_m != best_block_m and cand_n > best_block_n
                    take = narrower_n or narrower_m or wider_n
                else:
                    take = False
            else:
                take = False

            if take:
                best_block_m, best_block_n = cand_m, cand_n
    assert best_block_m is not None and best_block_n is not None

    # Step 2: take the largest stage count whose shared memory usage fits
    # NOTES: for double B scales, the best number of stages may be reduced
    sm90_capacity = 232448
    max_stages_for_k = k // 128
    stage_candidates = tuple(s for s in (8, 7, 6, 5, 4, 3) if s <= max_stages_for_k)
    # NOTE(review): with `k < 384` the tuple above is empty and, unless the
    # override below applies, the assertions after the loop fail - confirm
    # that this is the intended way to reject small `k`
    if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
        # Unrolling both stages and `num_former_iters` will cause large code size
        # NOTE(review): this replaces the `k`-filtered candidates, so the
        # `k // 128` bound is not applied on this path - confirm intended
        stage_candidates = (4, 3)

    best_num_stages, best_smem_config = None, None
    for stage_count in stage_candidates:
        # The last config tried is kept even if it does not fit; in that case
        # `best_num_stages` stays `None` and the second assertion fires
        best_smem_config = get_smem_config(stage_count, k, best_block_m, best_block_n)
        if best_smem_config[0] <= sm90_capacity:
            best_num_stages = stage_count
            break
    assert best_smem_config is not None
    assert best_num_stages is not None

    # Step 3: decide the number of TMA multicast and whether broadcast on A
    # NOTES: currently, grouped masked GEMM only supports multicast on A and
    # requires the number of blocks in the N-direction to be even
    is_multicast_legal = {
        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
    }
    # Try to multicast on the larger block side first
    side_order = ('A', 'B') if best_block_m > best_block_n else ('B', 'A')
    best_tma_multicast_config = (1, True)
    if m >= 512:
        for side in side_order:
            if is_multicast_legal[side]:
                best_tma_multicast_config = (2, side == 'A')
                break

    # Step 4: recompute the minimal number of SMs required
    # NOTES: less L2 cache usage and less GPU frequency drop
    total_waves = count_waves(best_block_m, best_block_n)
    num_min_sms = ceil_div(count_tiles(best_block_m, best_block_n), total_waves)
    # Round up to a multiple of the multicast count
    num_multicast = best_tma_multicast_config[0]
    num_min_sms = ceil_div(num_min_sms, num_multicast) * num_multicast
    assert num_min_sms <= num_sms

    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config