in experimental/ragged_inference/triton_v2_qk_dotprod.py [0:0]
def get_all_configs():
    """Return the full list of candidate ``triton.Config`` objects.

    The list starts with nine fixed tile configurations (the basic configs
    for compute-bound matmuls) and is followed by whatever
    ``get_configs_io_bound()`` returns.

    Returns:
        A new list: the fixed configs, in the order of the table below,
        concatenated with the result of ``get_configs_io_bound()``.
    """
    # One row per fixed config: (BLOCK_M, BLOCK_K, num_stages, num_warps).
    # BLOCK_D is 32 for every row, so it is not repeated in the table.
    tile_table = (
        (128, 256, 3, 8),
        (256, 128, 3, 8),
        (256, 64, 4, 4),
        (64, 256, 4, 4),
        (128, 128, 4, 4),
        (128, 64, 4, 4),
        (64, 128, 4, 4),
        (128, 32, 4, 4),
        (64, 32, 5, 2),
    )

    # basic configs for compute-bound matmuls
    compute_bound = [
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_D": 32},
            num_stages=stages,
            num_warps=warps,
        )
        for block_m, block_k, stages, warps in tile_table
    ]

    # The IO-bound candidates come from a helper defined elsewhere in the file.
    return compute_bound + get_configs_io_bound()