inline __device__ Softmax_base()

in kernels/fmha/softmax.h [168:196]


    // Ctor. Stores the scratch-buffer base (viewed as floats) and the thread index, then derives
    // the per-thread write and read pointers into that buffer.
    //
    //   params : not referenced by the active code (only by the disabled packed-mask setup).
    //   smem   : base address of the scratch buffer; reinterpreted as float*.
    //   tidx   : index of the calling thread, split into warp and lane below.
    inline __device__ Softmax_base(const Params &params, void *smem, int tidx)
        : smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {

        // Disabled packed-mask pointer setup, kept for reference. It would initialise the pointer
        // from params and advance it to the first mask word loaded by this thread. Note that it
        // refers to `bidb`, which is not a parameter of this constructor.
        //   packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr))
        //   packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);

        // Split the thread index into the warp it belongs to and its lane inside that warp.
        const int warp_idx = tidx / Cta_tile::THREADS_PER_WARP;
        const int lane_idx = tidx % Cta_tile::THREADS_PER_WARP;

        // Warp coordinates: the remainder by WARPS_M is the M index, the quotient is the N index.
        const int m_idx = warp_idx % Cta_tile::WARPS_M;
        const int n_idx = warp_idx / Cta_tile::WARPS_M;

        // Row offset common to both pointers. Every 4 consecutive lanes share one row
        // (presumably the MMA accumulator layout -- TODO confirm).
        const int row_in_tile = m_idx * Mma_tile::M_PER_MMA + lane_idx / 4;

        // The N index is further split into a group and a column within a row of ELEMENTS_PER_ROW.
        const int group = n_idx / ELEMENTS_PER_ROW;
        const int column = n_idx % ELEMENTS_PER_ROW;

        // Each group owns ROWS / GROUPS consecutive rows of the buffer.
        const int row = group * (ROWS / GROUPS) + row_in_tile;

        // Write location: (row, column) in a row-major layout with ELEMENTS_PER_ROW floats per row.
        smem_write_ = smem_ + (row * ELEMENTS_PER_ROW + column);

        // Read location: offset by the in-tile row only; no group or column term is applied here.
        smem_read_ = smem_ + row_in_tile;
    }