def get_best_configs()

in deep_gemm/jit_kernels/gemm.py [0:0]


def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                     is_grouped_contiguous: bool = False, is_grouped_masked: bool = False) -> \
        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
    """
    Heuristically choose launch parameters for the GEMM kernel.

    Block sizes are chosen to minimize the number of waves over `num_sms`.
    Ties are broken by last-wave utilization, then by the preferences noted inline.
    The stage count is the largest candidate whose `get_smem_config` result fits within
    `sm90_capacity`. TMA multicast with factor 2 is enabled when `m >= 512` and it is
    legal on the side being tried.

    Arguments:
        m, n, k: GEMM problem sizes. `k` normally caps the stage count at `k // 128`.
        num_groups: number of groups; multiplies the block count when counting waves.
        num_sms: number of SMs available to the kernel.
        is_grouped_contiguous: if set, `block_m` is fixed to `get_m_alignment_for_contiguous_layout()`.
        is_grouped_masked: grouped masked GEMM; multicast on B is disabled and multicast on A
            is checked with the extra `is_tma_multicast_legal` flag.

    Returns:
        `(num_min_sms, block_m, block_n, num_stages, (num_tma_multicast, multicast_on_a), smem_config)`.
        - `num_min_sms` is the smallest SM count that keeps the chosen number of waves,
          rounded up to a multiple of `num_tma_multicast`.
        - `smem_config` is the `get_smem_config` result for the chosen stage count.

    Raises:
        AssertionError: if no stage candidate is allowed for `k`, if no stage count fits in
            shared memory, or if the recomputed SM count exceeds `num_sms`
            (only when assertions are enabled).
    """
    if not is_grouped_contiguous:
        block_ms = (64, 128, 256)
    else:
        block_ms = (get_m_alignment_for_contiguous_layout(), )
    block_ns = tuple(range(16, 129, 8)) + (144, 160, )

    # A remainder of 0 means the last wave is completely full
    fix_wave_saturate = lambda x: num_sms if x == 0 else x
    # Returns None when `bm` is None (i.e. no best block chosen yet)
    get_num_waves = lambda bm, bn: (ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms) if bm else None)
    get_last_wave_util = lambda bm, bn: fix_wave_saturate((ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms)

    # Decide block sizes by waves
    best_block_m, best_block_n = None, None
    for block_m in block_ms:
        # NOTES: the block sizes can not be too large, so at least one dim must be <= 128
        for block_n in filter(lambda bn: block_m <= 128 or bn <= 128, block_ns):
            success = False
            num_waves, best_num_waves = get_num_waves(block_m, block_n), get_num_waves(best_block_m, best_block_n)
            if best_block_m is None or best_block_n is None:
                success = True
            elif num_waves < best_num_waves:
                success = True
            elif num_waves == best_num_waves:
                # Check last wave utilization
                util = get_last_wave_util(block_m, block_n)
                best_util = get_last_wave_util(best_block_m, best_block_n)
                success = util > best_util
                if util == best_util:
                    # Case 1: same `block_m`, smaller `block_n` (wasted)
                    success |= block_m == best_block_m and block_n < best_block_n
                    # Case 2: same `block_n`, smaller `block_m` (wasted)
                    success |= block_n == best_block_n and block_m < best_block_m
                    # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
                    success |= block_m != best_block_m and block_n > best_block_n
            best_block_m, best_block_n = (block_m, block_n) if success else (best_block_m, best_block_n)
    assert best_block_m is not None and best_block_n is not None

    # Always pick the longest one
    # NOTES: for double B scales, the best number of stages may be reduced
    # 232448 bytes = 227 KiB, presumably the SM90 per-block shared-memory limit
    best_num_stages, best_smem_config, sm90_capacity = None, None, 232448
    stage_candidates = tuple(filter(lambda s: s <= k // 128, (8, 7, 6, 5, 4, 3)))
    if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
        # Unrolling both stages and `num_former_iters` will cause large code size
        stage_candidates = (4, 3)
    # With `k < 384` every candidate exceeds `k // 128`, so no stage count can be chosen.
    # Fail here with a clear reason instead of on a bare `None` check after the loop.
    assert stage_candidates, \
        f'No pipeline stage candidate for k={k}: at least 3 stages are required, but k // 128 = {k // 128}'
    for num_stages in stage_candidates:
        best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n)
        # The first element of the config is compared against the shared-memory capacity
        if best_smem_config[0] <= sm90_capacity:
            best_num_stages = num_stages
            break
    assert best_num_stages is not None, \
        (f'No stage count in {stage_candidates} fits in {sm90_capacity} bytes of shared memory '
         f'(block_m={best_block_m}, block_n={best_block_n}, k={k})')

    # Decide the number of TMA multicast and whether broadcast on A
    best_tma_multicast_config = (1, True)

    # Try to multicast on the larger block side first
    # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
    is_multicast_legal = {
        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
    }
    for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'):
        if m >= 512 and is_multicast_legal[i]:
            best_tma_multicast_config = (2, i == 'A')
            break

    # Recompute the minimal number of SMs required
    # NOTES: less L2 cache usage and less GPU frequency drop
    num_waves = get_num_waves(best_block_m, best_block_n)
    num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
    # Round up so the SM count is a multiple of the multicast factor
    num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
    assert num_min_sms <= num_sms, \
        f'Required SMs ({num_min_sms}) exceed available SMs ({num_sms}) with multicast {best_tma_multicast_config}'

    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config