in tzrec/ops/triton/triton_hstu_attention.py [0:0]
def _hstu_attn_bwd_one_col_block( # noqa C901
start_n,
seq_len,
n_targets,
contextual_seq_len,
max_attn_len,
Q,
K,
V,
DOut,
DQ,
DK,
DV,
LOCK,
stride_qm,
stride_kn,
stride_vn,
stride_dom,
stride_dqm,
stride_dkn,
stride_dvn,
alpha,
MAX_SEQ_LEN,
CAUSAL: tl.constexpr,
HAS_MULTIPLE_TARGETS: tl.constexpr,
HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
HAS_MAX_ATTN_LEN: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_D_Q: tl.constexpr,
BLOCK_D_V: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
UNROLL: tl.constexpr,
ATOMIC_ADD: tl.constexpr,