in tzrec/ops/triton/triton_hstu_attention.py [0:0]
def _hstu_attn_bwd_one_block( # noqa C901
start_m,
offs_n,
offs_m,
q_ptrs_trans,
dq_ptrs_trans,
mask_n,
do_ptrs,
dk,
dv,
k,
v,
pos_offs_n,
seq_len,
n_targets,
max_ids,
contextual_seq_len,
max_attn_len,
LOCK,
stride_qm,
stride_dom,
stride_dqm,
alpha,
MAX_SEQ_LEN,
CAUSAL: tl.constexpr,
HAS_MULTIPLE_TARGETS: tl.constexpr,
HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
HAS_MAX_ATTN_LEN: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
ATOMIC_ADD: tl.constexpr,