in src/nanotron/nn/ring_attention_lucidrain.py [0:0]
def _fwd_kernel(
Q,
K,
V,
Bias,
Out,
M,
Lse,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
HAS_BIAS: tl.constexpr,
IS_CAUSAL: tl.constexpr,
CAUSAL_MASK_DIAGONAL: tl.constexpr,
LOAD_ACCUMULATED: tl.constexpr,
RETURN_NORMALIZED_OUTPUT: tl.constexpr,
SOFTCLAMP_QK_SIM: tl.constexpr,
SOFTCLAMP_VALUE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,