inline __device__ Gmem_tile_qkv()

in kernels/fmha/gmem_tile.h [69:96]


    inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
                                    const uint32_t head_stride_in_elts, const int headdim,
                                    const BInfo &binfo, const int tidx, bool use_seqlen_q)
        : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
        , actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
        , ptr(reinterpret_cast<char *>(ptr_))
        , tidx_(tidx)
        , col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {
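        // col_predicate is true iff this thread's first element column,
        // (tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT),
        // lies inside the head dimension; it disables loads past headdim.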

        // Compute the position in the sequence (within the CTA for the moment).
        int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        int col = tidx % THREADS_PER_ROW;

        // We need the row to disable out-of-bounds loads.
        // TD [2022-04-16]: To minimize registers, we recompute row_ on use instead of storing it.
        // row_ = row;

        // The row offset in the batched GEMM. For each sequence element, Q, K, and V are stored in that order.
        // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
        uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
        // Add the block index.
        // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
        row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);

        // Assemble the final pointer.
        ptr += row_offset + col * BYTES_PER_LDG;
    }
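
To make the addressing concrete, here is a minimal host-side sketch of the same offset arithmetic. The constants THREADS_PER_ROW, BYTES_PER_LDG, and BYTES_PER_ELEMENT are hypothetical example values; in the kernel they come from the Gmem_tile_qkv template configuration, and sum_s stands in for binfo.sum_s_q / binfo.sum_s_k.

    #include <cstdint>
    #include <cstdio>

    // Hypothetical example values; the real ones are compile-time
    // parameters of the tile.
    constexpr int THREADS_PER_ROW   = 8;   // threads cooperating on one row
    constexpr int BYTES_PER_LDG     = 16;  // one 128-bit load per thread
    constexpr int BYTES_PER_ELEMENT = 2;   // fp16 / bf16

    // Mirrors the constructor's pointer math. sum_s is the cumulative
    // sequence offset of this batch entry, bidh the head index.
    uint32_t qkv_byte_offset(uint32_t sum_s, uint32_t bidh,
                             uint32_t row_stride_in_elts,
                             uint32_t head_stride_in_elts, int tidx) {
        uint32_t row_stride_in_bytes = row_stride_in_elts * BYTES_PER_ELEMENT;
        int row = tidx / THREADS_PER_ROW;  // sequence position within the CTA
        int col = tidx % THREADS_PER_ROW;  // position of the thread in the row
        uint32_t row_offset = (sum_s + row) * row_stride_in_bytes;
        row_offset += bidh * head_stride_in_elts * BYTES_PER_ELEMENT;  // select the head
        return row_offset + col * BYTES_PER_LDG;  // one LDG-wide chunk per thread
    }

    int main() {
        // Example: head 1 of a layout with 64-element heads packed 3 wide, thread 9.
        printf("byte offset = %u\n",
               qkv_byte_offset(/*sum_s=*/0, /*bidh=*/1,
                               /*row_stride_in_elts=*/3 * 64,
                               /*head_stride_in_elts=*/64, /*tidx=*/9));
        return 0;
    }

With these example values, thread 9 maps to row 1, col 1, and the sketch prints byte offset = 528: one full row (384 bytes), plus the head offset (128 bytes), plus one 16-byte load-width column.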