in tzrec/ops/triton/triton_hstu_attention.py [0:0]
def _hstu_attn_fwd_persistent( # noqa C901
Q,
K,
V,
sort_by_length_indices,
seq_offsets,
num_targets,
Out,
stride_qm,
stride_qh,
stride_kn,
stride_kh,
stride_vn,
stride_vh,
stride_om,
stride_oh,
alpha,
Z,
AUTOTUNE_Z,
H,
MAX_SEQ_LEN,
AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key
DimQ,
DimV,
DeltaSize,
contextual_seq_len,
max_attn_len,
CAUSAL: tl.constexpr,
HAS_MULTIPLE_TARGETS: tl.constexpr,
IS_DELTA_Q: tl.constexpr,
ALLOW_TF32: tl.constexpr,
BLOCK_D_Q: tl.constexpr,
BLOCK_D_V: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr,
HAS_MAX_ATTN_LEN: tl.constexpr,
HAS_SORT_BY_LENGTH_INDICES: tl.constexpr,