inline __device__ void device_1xN_()

in candle-flash-attn-v1/kernels/fmha_fprop_kernel_1xN.h [200:659]


inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {

#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
    using elem_type = typename Kernel_traits::elem_type;
#else
    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
    assert(is_fp16_type);
    using elem_type = __half;
#endif
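    // On SM80 and newer the element type comes from Kernel_traits (e.g. fp16 or bf16);
    // older architectures only support the fp16 path, which the assert above enforces.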

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;

    using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS, elem_type>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    // How many steps to jump per iteration, which is the same as params.num_splits.
    const int step_stride = gridDim.z;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
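    // Exit early when there is no work for this (batch, head) at this K/V column block,
    // e.g. because the block's starting column is past the actual key sequence length.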
    // if( binfo.stop_early() ) return;
    if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;

    Gemm1 gemm_q_k(smem_, tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
                       params.d, binfo, tidx, true);
    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
                       params.d, binfo, tidx);
    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts,
                               params.o_tmp_head_stride_in_elts, params.d, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);
    Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

    // Wind gmem tiles to the correct position.
    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
    int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
    // We want begin to be a multiple of gridDim.z
    // This is because the row indices processed by each threadblock must align between the
    // loop steps, otherwise we have a dependency between the blocks.
    // For example, threadblock with blockIdx.z == 1 must process row indices that are
    // k * gridDim.z + 1 for integer k.
    const int begin_mod_z = begin % gridDim.z;
    begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
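    // Example with gridDim.z == 4 and begin == 6: blockIdx.z == 0 rounds up to 8 (2 > 0),
    // blockIdx.z == 2 rounds down to 4 (2 <= 2), so the first tiles processed are 8 and
    // 4 + 2 == 6 respectively. Each block thus starts at the smallest tile index that is
    // >= the causal begin and congruent to blockIdx.z modulo gridDim.z.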
    // Otherwise we'd be reading out-of-bounds memory before the loop.
    if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return;
    const int steps_og = steps;
    steps -= begin;
    gmem_q.move(begin + blockIdx.z);
    gmem_o.move(begin + blockIdx.z);
    gmem_o_tmp.move(begin + blockIdx.z);
    if (Return_softmax) {
        gmem_s.move(begin + blockIdx.z);
    }
    gmem_softmax_lse.move(begin + blockIdx.z);
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //     printf("begin = %d, steps = %d\n", begin, steps);
    // }

    fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
                       params.d, binfo, tidx, false);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
                       params.d, binfo, tidx, false);
    // The base pointer of smem_v.
    char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
    
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);

    if (!Is_first) {
        gmem_k.move(loop_step_idx);
        gmem_v.move(loop_step_idx);
        if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
    }

    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for V.
    gmem_v.load();

    if (!Is_first) { __syncthreads(); }

    float p_prev_lse[Mma_tile_p::MMAS_M * 2];
    if (!Is_first) {
        gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
    }
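    // For any K/V block after the first, p_prev_lse now holds the log-sum-exp accumulated
    // by the previous blocks; it seeds the running softmax max/sum and the output rescaling below.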

    // Commit the data for Q and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_v.commit(smem_v);

    // const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
    // #pragma unroll
    // for(int it=0;it < Gmem_tile_k::LDGS;it++){
    //     gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    // }

    // Commit the data for K to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();

        // Commit the data to shared memory for V.
        gmem_k.commit(gemm_q_k.smem_k);

        // Make sure the data is in shared memory.
        __syncthreads();
    }
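    // Note the ordering when K and V share the same shared memory: the V fragments were
    // loaded into registers above before K was committed and overwrote that buffer.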

    // Load the fragments for K. 
    gemm_q_k.load_k();

    // Create the object to do the softmax.
    Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);

    Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
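    // The previous LSE values are staged in shared memory because the thread-to-row mapping
    // changes between the softmax and the O reduction; each thread re-reads them for its
    // output rows later in the loop.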

    // Load over the entire sequence length.
    for (int l = blockIdx.z; l < steps; l += step_stride) {
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) {
        //     printf("l = %d\n", l);
        // }
        if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P = Q * K^T.
        gemm_q_k(acc_p);

        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0))  {
        //     printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1));
        // }

        uint4 out[Gmem_tile_o::STGS_PER_LOOP];
        if (!Is_first) { gmem_o_tmp.load(out, 0); }

        // Trigger the load for the next Q values.
        if (l + step_stride < steps) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move(step_stride);
            gmem_q.load();
        }
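        // The next Q tile is prefetched above and committed to shared memory near the end of
        // the iteration, overlapping the global load with the softmax and the second GEMM.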

        // Load the mask for that iteration.
        mask.load(begin + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) {
            // If we share K and V, it could be that V has not been fully read yet when we write into that shared memory for the softmax reduction.
            __syncthreads();
        }
        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0))  {
        //         printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
        //     }
        // }
        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        if (!Is_first) {
            smem_softmax_lse.store_pair(p_prev_lse);
            // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
            for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
        }

        // Trigger the load for the next LSE values.
        if (l + step_stride < steps) {
            if (!Is_first) {
                gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
                                           step_stride);
            }
        }

        softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
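        // For the first K/V block the row max is computed from scratch; otherwise it is
        // seeded with p_max, initialized above from the previous blocks' LSE divided by the
        // softmax scale (scale_apply_exp below multiplies by that scale again).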

        // if ((threadIdx.x == 0) && (l == 38)) {
        //     printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
        // }

        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0))  {
        //         printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
        //     }
        // }

        // Compute the exponential value.
        // softmax.apply_exp(p_max);
        softmax.scale_apply_exp(p_max, params.scale_bmm1f);
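        // The softmax scaling factor params.scale_bmm1f is fused into the exponential here
        // instead of being applied to the Q*K^T accumulators (see unpack_noscale above).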

        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0))  {
        //         printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
        //     }
        // }

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        // if (!Is_first) {
        //     int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
        //     int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
        //     for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
        //         p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
        //     }
        // }
        // softmax.reduce_sum(p_sum);
        softmax.reduce_sum_before_sync_(p_sum);
        // softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);

        // float p_sum_log[Mma_tile_p::MMAS_M * 2];
        // for (int mi = 0; mi  < Mma_tile_p::MMAS_M * 2; ++mi) {
        //     float sum = p_sum[mi];
        //     // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
        //     constexpr float kLog2e = M_LOG2E;
        //     p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
        // }
        // // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
        // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
        // gmem_softmax_lse.move();

        // // Finalize softmax on the accumulators of P^T.
        // softmax.scale(p_sum);

        constexpr bool encode_dropout_in_sign_bit = Return_softmax;
        if (Is_dropout) {
            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint);
            // softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint);
            // softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint16_t);
            unsigned int warp_idx = threadIdx.x / 32;
            // TODO: this should change after we rearrange the warps (e.g. cutlass branch)
            unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
            // We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded
            // differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k
            // to multiples of 256 while bwd rounds seqlen_k to multiples of 128.
            unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
            softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
        }
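        // The Philox subsequence index above depends only on the tile's position in the full
        // attention matrix (row tile and 16-wide column block), so the backward pass can
        // regenerate exactly the same dropout pattern.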

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
        static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
        softmax.template pack<elem_type>(frag_p);
        if (Return_softmax) {
            gmem_s.store(frag_p, mask);
            gmem_s.move(step_stride);
        }

        // Commit the values for Q into shared memory.
        if (l + step_stride < steps) {
            gmem_q.commit(gemm_q_k.smem_q);
        }

        if (Is_dropout && encode_dropout_in_sign_bit) {
            #pragma unroll
            for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
                #pragma unroll
                for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                    frag_p[ki][mi].template hrelu_<elem_type>();
                }
            }
        }

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm_cl<elem_type>(acc_o, frag_p[ki], frag_v[ki]);
            // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0))  {
            //     float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
            //     float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
            //     printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
            // }
        }

        // if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0))  {
        //     printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
        // }

        // The mapping from tidx to rows changes between the softmax and the
        // O-reduction. So we recalculate the max.
        float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        int rows[Gmem_tile_o::STGS_PER_LOOP];
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
        }
        softmax.reduce_max_after_sync_(p_max_o, rows);
        static_assert(Mma_tile_o::MMAS_M == 1);
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            p_max_o[jj][0] *= params.scale_bmm1f;
        }
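        // The recomputed max is rescaled by the softmax scale above so it is directly
        // comparable with the stored LSE values loaded into p_prev_scale_o below.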
        float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
        if (!Is_first) {
            smem_softmax_lse.load(p_prev_scale_o, rows);
        }
        // if (!Is_first) {
        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0))  {
        //         printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
        //     }
        // }

        static_assert(Gmem_tile_o::LOOPS == 1);

        // Swizzle the elements and do the final reduction.
        smem_o.store(acc_o, 0);

        // Make sure the data is in shared memory.
        __syncthreads();

        static_assert(Mma_tile_o::MMAS_M == 1);
        float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        softmax.reduce_sum_after_sync_(p_sum_o, rows);
        if (!Is_first) {
            for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
                p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
                p_sum_o[jj][0] += p_prev_scale_o[jj];
            }
        }
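        // Online-softmax correction across K/V blocks: exp(lse_prev - max_new) equals
        // sum_prev * exp(max_prev - max_new), so adding it to the new sum, and multiplying
        // the previously stored (sum-normalized) partial output by the same factor below,
        // rescales everything to the new running max.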

        float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
        #pragma unroll
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            float sum = p_sum_o[jj][0];
            p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
            // if (sum == 0.f || sum != sum) {
            //     printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
            // }
            // if (Is_first) {
            //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0))  {
            //         printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
            //     }
            // }
            if (tidx % Gmem_tile_o::THREADS_PER_ROW == 0) {
                gmem_softmax_lse.store_row(
                    reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
            }
        }
        gmem_softmax_lse.move(step_stride);

        // Load from shared memory.
        if (!Is_first) {
            for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
                out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
            }
        }
        smem_o.template load</*zero_init=*/Is_first>(out);

        const bool is_final_write =
            Is_last
            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
            || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
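        // The write is final if this is the last K/V block overall, if this block already
        // reaches the end of the key sequence, or if causal masking guarantees that no later
        // K/V block can contribute to these query rows.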
        #pragma unroll
        for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
            float sum = p_sum_o[jj][0];
            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            if (Is_dropout && is_final_write) {
                inv_sum *= params.rp_dropout;
            }
            out[jj] = fmha::fmul4(out[jj], inv_sum);
        }
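        // Normalize by the running softmax sum; the dropout correction params.rp_dropout is
        // applied only on the final write so that it is not applied more than once.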

        // if (Is_dropout && Is_last) {
        //     for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
        //         out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
        //     }
        // }

        // Output the values.
        if (is_final_write) {
            gmem_o.template store<elem_type>(out, 0);
            gmem_o.move(step_stride);
        } else {
            gmem_o_tmp.store(out, 0);
        }
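        // Partial results go to o_tmp and are reloaded and rescaled when the next K/V block
        // processes the same query rows; only the final result is converted to elem_type and
        // written to O.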

        // Move to the next part of the output.
        if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); }
        gemm_q_k.reload_k();

        // Make sure we are reading from the correct buffer.
        gemm_q_k.smem_q.move_to_next_read_buffer();
        // Trigger the load from shared memory for the next series of Q values.
        if (l + step_stride < steps) {
            gemm_q_k.reload_q();
        }
    }  // Outer loop over the sequence length.
}