in tzrec/ops/triton/triton_addmm.py [0:0]
def _triton_version_at_least(min_version: str) -> bool:
    """Return whether the installed Triton is at least ``min_version``.

    Versions are compared component-wise as integers rather than as raw
    strings: a plain string comparison mis-orders releases such as
    ``"3.10.0"`` vs ``"3.2.0"``. Only the leading numeric release components
    are used, so suffixes like ``"+git1a2b"`` or ``"rc1"`` are ignored (a
    pre-release therefore compares equal to its final release), and missing
    trailing components count as 0.

    Args:
        min_version: Dotted minimum version, e.g. ``"3.2.0"``.

    Returns:
        True if ``triton.__version__`` >= ``min_version``.
    """
    import re

    def _release(version):
        parts = []
        for piece in version.split("."):
            match = re.match(r"\d+", piece)
            if match is None:
                break
            parts.append(int(match.group()))
            # A component such as "0rc1" or "0+git" ends the release part.
            if match.end() != len(piece):
                break
        return parts

    current = _release(triton.__version__)
    required = _release(min_version)
    width = max(len(current), len(required))
    current += [0] * (width - len(current))
    required += [0] * (width - len(required))
    return current >= required


def get_mm_configs() -> List[triton.Config]:
    """Return the candidate Triton configs for the addmm kernel.

    When ``torch.version.hip`` is truthy (ROCm build), the configs are the
    Cartesian product of the parameter ranges below, including the
    HIP-specific ``matrix_instr_nonkdim``, ``waves_per_eu`` and ``kpack``
    kernel options. If ``ENABLE_FULL_TURNING_SPACE`` is set, a wider space
    (256 configs) is generated; otherwise a single config is returned.
    On any other build, a fixed list of eight hand-picked configs is returned.

    Returns:
        List of ``triton.Config`` objects.
    """
    if torch.version.hip:
        # Triton >= 3.2.0 gets num_stages=2, older releases get 0.
        # NOTE(review): cutoff kept from the original code; the numeric
        # comparison fixes the previous lexicographic string comparison.
        num_stage_range = [2] if _triton_version_at_least("3.2.0") else [0]
        if ENABLE_FULL_TURNING_SPACE:
            block_m_range = [32, 64, 128, 256]
            block_n_range = [32, 64, 128, 256]
            block_k_range = [32, 64]
            group_m_range = [4, 8]
            matrix_instr_nonkdim_range = [16]
            waves_per_eu_range = [0]
            kpack_range = [1, 2]
            num_warps_range = [4, 8]
        else:
            # Single default config when the full tuning space is disabled.
            block_m_range = [256]
            block_n_range = [256]
            block_k_range = [32]
            group_m_range = [8]
            matrix_instr_nonkdim_range = [16]
            waves_per_eu_range = [0]
            kpack_range = [2]
            num_warps_range = [8]
        # Iteration order (block_m outermost ... num_warps innermost) is kept
        # so the generated config order is unchanged.
        return [
            triton.Config(
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "BLOCK_K": block_k,
                    "GROUP_M": group_m,
                    "matrix_instr_nonkdim": matrix_instr_nonkdim,
                    "waves_per_eu": waves_per_eu,
                    "kpack": kpack,
                },
                num_stages=num_stages,
                num_warps=num_warps,
            )
            for block_m in block_m_range
            for block_n in block_n_range
            for block_k in block_k_range
            for group_m in group_m_range
            for matrix_instr_nonkdim in matrix_instr_nonkdim_range
            for waves_per_eu in waves_per_eu_range
            for kpack in kpack_range
            for num_stages in num_stage_range
            for num_warps in num_warps_range
        ]
    else:
        return [
            triton.Config(
                {
                    "BLOCK_M": 32,
                    "BLOCK_N": 64,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 256,
                    "BLOCK_K": 64,
                    "GROUP_M": 8,
                },
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {
                    "BLOCK_M": 64,
                    "BLOCK_N": 256,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 128,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 64,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 64,
                    "BLOCK_N": 128,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 128,
                    "BLOCK_N": 32,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {
                    "BLOCK_M": 64,
                    "BLOCK_N": 32,
                    "BLOCK_K": 32,
                    "GROUP_M": 8,
                },
                num_stages=5,
                num_warps=2,
            ),
        ]