inline __device__ void device_1xN_loop(const Params &params)

in candle-flash-attn-v1/kernels/fmha_fprop_kernel_1xN.h [664:702]


template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
inline __device__ void device_1xN_loop(const Params &params) {

    // The block index for the batch.
    const int bidb = blockIdx.x;
    // The block index for the head.
    const int bidh = blockIdx.y;
    // The linearized block index across batch and head (only used by the
    // commented-out RNG-state store below).
    const int bidx = gridDim.x * bidh + bidb;
    // The thread index.
    const int tidx = threadIdx.x;

    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
    // them to have the same number of threads or to traverse the attention matrix
    // in the same order.
    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
    // (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
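    // The at::cuda::philox unpacking below comes from the original PyTorch extension and is
    // disabled in this port, so the Philox seed and base offset are hard-coded to 0.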
//    auto seeds = at::cuda::philox::unpack(params.philox_args);
//    if (bidx == 0 && tidx == 0) {
//        params.rng_state[0] = std::get<0>(seeds);
//        params.rng_state[1] = std::get<1>(seeds);
//    }
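    // Worked example of the offset term (hypothetical numbers): with params.h == 8,
    // bidb == 2, bidh == 3, and lane tidx % 32 == 5, the offset is
    // (2 * 8 + 3) * 32 + 5 == 613. The fwd and bwd kernels compute the same value
    // for that (batch, head, lane), so their Philox streams line up.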
    Philox ph(0, 0, 0 + (bidb * params.h + bidh) * 32 + tidx % 32);
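    // Each outer step processes M rows of the attention matrix.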
    constexpr int M = Kernel_traits::Cta_tile_p::M;
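    // Number of M-row steps needed to cover the query sequence (ceiling division).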
    const int STEPS = (params.seqlen_q + M - 1) / M;

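    // blocksize_c is the width (in key/value columns) of the tile each device_1xN_ call covers.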
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
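    // The two trailing template booleans of device_1xN_ appear to flag the first and the
    // last key tile, letting the callee run its prologue/epilogue only where needed.
    // If the whole key sequence fits in one tile, a single call is both first and last.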
    if (params.seqlen_k == blocksize_c) {
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph, 0);
    } else {
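        // Number of blocksize_c-column tiles needed to cover the key sequence.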
        const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
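        // First key tile.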
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph, 0);
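        // Middle key tiles.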
        for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph, loop_step_idx);
        }
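        // Last key tile.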
        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
    }
}
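
A minimal sketch of how a kernel built on this loop might be launched, inferred from the indexing above (blockIdx.x walks the batch, blockIdx.y walks the heads). The wrapper name fmha_fprop_kernel, the grid/smem_size setup, and params.b are illustrative assumptions, as is placing device_1xN_loop in namespace fmha alongside the device_1xN_ it calls:

// Hypothetical wrapper; the repo's actual kernel entry points live in its .cu files.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
__global__ void fmha_fprop_kernel(Params params) {
    fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}

// Host side (sketch): one block per (batch, head) pair, matching bidb/bidh above.
//   dim3 grid(params.b, params.h);   // assumes params.b = batch size, params.h = num heads
//   fmha_fprop_kernel<Kernel_traits, true, false, false>
//       <<<grid, Kernel_traits::THREADS, smem_size>>>(params);  // smem_size: assumed dynamic shared-memory bytes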