in tzrec/ops/triton/triton_hstu_attention.py [0:0]
def _hstu_attn_fwd_compute( # noqa C901
Q,
K,
V,
seq_offsets,
num_targets,
Out,
stride_qm,
stride_qh,
stride_kn,
stride_kh,
stride_vn,
stride_vh,
stride_om,
stride_oh,
alpha,
MAX_SEQ_LEN,
DeltaSize,
contextual_seq_len,
max_attn_len,
off_z,
off_h,
pid,
CAUSAL: tl.constexpr,
HAS_MULTIPLE_TARGETS: tl.constexpr,
IS_DELTA_Q: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_D_Q: tl.constexpr,
BLOCK_D_V: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
HAS_MAX_ATTN_LEN: tl.constexpr,