in tzrec/ops/triton/triton_hstu_attention.py [0:0]
def _get_fw_configs() -> List[triton.Config]:  # noqa: C901
    """Build the list of candidate ``triton.Config`` objects.

    When ``torch.version.hip`` is truthy, the result is the full cross
    product of the BLOCK_M, BLOCK_N, num_stages, num_warps and
    ``matrix_instr_nonkdim`` choices listed below, with ``waves_per_eu``
    fixed at 0 and ``kpack`` fixed at 2. Otherwise a hand-picked table of
    (BLOCK_M, BLOCK_N, num_stages, num_warps) combinations is used.

    NOTE(review): the name suggests these are the candidates for the forward
    kernel; confirm against the call site that consumes this list.

    Returns:
        The candidate configurations, in a fixed, deterministic order.
    """
    if torch.version.hip:
        # The order of the ``for`` clauses mirrors the nesting order of the
        # equivalent loops, so the resulting element order is stable.
        return [
            triton.Config(
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "matrix_instr_nonkdim": nonkdim,
                    "waves_per_eu": 0,
                    "kpack": 2,
                },
                num_stages=stages,
                num_warps=warps,
            )
            for block_m in (32, 64, 128)
            for block_n in (32, 64)
            for stages in (1, 2)
            for warps in (4, 8)
            for nonkdim in (16, 32)
        ]

    # Each row is (BLOCK_M, BLOCK_N, num_stages, num_warps).
    tuning_table = (
        (16, 32, 2, 2),
        (32, 32, 2, 2),
        (32, 32, 4, 2),
        (32, 32, 2, 4),
        (32, 32, 4, 4),
        (32, 64, 2, 4),
        (32, 64, 4, 4),
        (32, 64, 4, 8),
        (32, 128, 2, 4),
        (32, 128, 2, 8),
        (64, 32, 4, 2),
        (64, 32, 2, 4),
        (64, 32, 4, 4),
        (64, 32, 2, 8),
        (64, 64, 2, 2),
        (64, 64, 2, 4),
        (64, 64, 4, 4),
        (64, 64, 4, 8),
        (128, 32, 2, 2),
        (128, 32, 4, 2),
        (128, 32, 2, 4),
        (128, 32, 4, 4),
        (128, 32, 2, 8),
        (128, 32, 4, 8),
        (128, 64, 2, 4),
        (128, 64, 2, 8),
        (128, 64, 4, 8),
        (128, 128, 4, 4),
        (128, 128, 2, 8),
    )
    return [
        triton.Config(
            {"BLOCK_M": tile_m, "BLOCK_N": tile_n},
            num_stages=stages,
            num_warps=warps,
        )
        for tile_m, tile_n, stages, warps in tuning_table
    ]