in tzrec/ops/triton/triton_hstu_attention.py [0:0]
def _get_bw_configs() -> List[triton.Config]:
    """Build the autotuning candidates for the HSTU attention backward kernel.

    When ``torch.version.hip`` is truthy (a ROCm/HIP build of PyTorch), the
    candidates are the full Cartesian product of tile sizes, pipeline stages,
    warp counts, the HIP-only ``matrix_instr_nonkdim`` / ``waves_per_eu``
    knobs, and sequence parallelism on/off. Otherwise a fixed, hand-picked
    list of configurations is returned.

    Every candidate uses ``_bwd_pre_hook`` as its ``pre_hook``.

    Returns:
        List[triton.Config]: candidate configs in a deterministic order.
    """
    if torch.version.hip:
        # The `for` clauses nest outermost-first, so the candidate order is
        # BLOCK_M -> BLOCK_N -> stages -> warps -> nonkdim -> waves -> SP.
        return [
            triton.Config(
                {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "matrix_instr_nonkdim": nonkdim,
                    "waves_per_eu": waves,
                    "SEQUENCE_PARALLEL": seq_parallel,
                    "UNROLL": 1,
                },
                num_stages=stages,
                num_warps=warps,
                pre_hook=_bwd_pre_hook,
            )
            for block_m in (32, 64)
            for block_n in (32, 64)
            for stages in (1, 2)
            for warps in (4, 8)
            for nonkdim in (16, 32)
            for waves in (0, 2, 4)
            for seq_parallel in (True, False)
        ]

    # Hand-picked non-HIP candidates, one row per config, in tuning order:
    # (BLOCK_M, BLOCK_N, SEQUENCE_PARALLEL, UNROLL, num_stages, num_warps)
    candidate_rows = (
        # SEQUENCE_PARALLEL disabled
        (16, 32, False, 1, 2, 2),
        (16, 16, False, 1, 2, 2),
        (16, 32, False, 1, 2, 4),
        (16, 32, False, 1, 1, 8),
        (16, 64, False, 1, 1, 4),
        (32, 32, False, 1, 1, 4),
        (32, 32, False, 1, 2, 4),
        (32, 64, False, 1, 1, 4),
        (32, 64, False, 1, 2, 4),
        (32, 64, False, 1, 1, 8),
        (32, 64, False, 1, 2, 8),
        (64, 64, False, 1, 1, 4),
        (64, 64, False, 1, 2, 4),
        (64, 64, False, 1, 1, 8),
        (64, 64, False, 1, 2, 8),
        (32, 128, False, 1, 2, 8),
        (32, 128, False, 1, 3, 8),
        (32, 128, False, 2, 2, 8),
        (32, 128, False, 4, 2, 8),
        # SEQUENCE_PARALLEL enabled
        (16, 32, True, 1, 2, 2),
        (32, 32, True, 1, 1, 4),
        (32, 32, True, 1, 2, 4),
        (32, 64, True, 1, 1, 4),
        (32, 64, True, 1, 2, 4),
        (32, 64, True, 1, 1, 8),
        (64, 64, True, 1, 1, 4),
        (64, 64, True, 1, 2, 4),
        (32, 128, True, 1, 3, 8),
    )
    return [
        triton.Config(
            {
                "BLOCK_M": block_m,
                "BLOCK_N": block_n,
                "SEQUENCE_PARALLEL": seq_parallel,
                "UNROLL": unroll,
            },
            num_stages=stages,
            num_warps=warps,
            pre_hook=_bwd_pre_hook,
        )
        for block_m, block_n, seq_parallel, unroll, stages, warps in candidate_rows
    ]