in src/nanotron/nn/ring_attention_lucidrain.py [0:0]
def _bwd_kernel_one_col_block(
start_n,
Q,
K,
V,
Bias,
DO,
DQ,
DK,
DV,
LSE,
D,
softmax_scale,
stride_qm,
stride_kn,
stride_vn,
stride_bm,
stride_dom,
stride_dqm,
stride_dkn,
stride_dvn,
seqlen_q,
seqlen_k,
headdim,
ATOMIC_ADD: tl.constexpr,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
CAUSAL_MASK_DIAGONAL: tl.constexpr,
SOFTCLAMP_QK_SIM: tl.constexpr,
SOFTCLAMP_VALUE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,