in maga_transformer/cpp/cutlass/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h [331:627]
void operator()(
///< number of threadblock-scoped GEMM iterations along the K dimension
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterator over scale operand in global memory
IteratorScale iterator_scale,
///< initial value of accumulator
FragmentC const& src_accum)
{
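// Threadblock-scoped multistage GEMM that dequantizes operand B on the fly
// using the per-column scales loaded through iterator_scale.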
//
// Prologue
//
TransformBAfterLDS lds_converter;
// NOTE - switch to ldg.sts
// Issue this first, so cp.async.commit_group will commit this load as well.
// Note: we do not commit here and this load will commit in the same group as
// the first load of A.
FragmentScale tb_frag_scales;
tb_frag_scales.clear();
iterator_scale.load(tb_frag_scales);
this->smem_iterator_scale_.store(tb_frag_scales);
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
{
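// Disable the load predicates once the K extent is exhausted so any remaining
// prologue stages are zero-filled instead of reading out of bounds.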
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
{
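// kSrcBytes is the size of one cp.async transfer; predicated-off accesses
// are zero-filled in shared memory.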
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
{
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
// so that all accumulator elements outside the GEMM footprint are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
{
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
{
typename IteratorA::AccessType* dst_ptr
= reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
{
typename IteratorB::AccessType* dst_ptr
= reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Wait until all but kStages - 2 committed cp.async stages have completed.
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
typename Dequantizer::FragmentScale warp_frag_scales;
Operator warp_mma;
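// Prime the pipeline: load the first warp-level fragments of A and B along
// with the per-column scale fragment used to dequantize B.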
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
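// Track the write and read positions within the circular shared-memory buffer;
// the prologue has already filled kStages - 1 stages.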
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
//
// Mainloop
//
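// The loop bound lets gemm_k_iterations go negative so the kStages - 1 stages
// prefetched in the prologue are drained after global loads have been disabled.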
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);)
{
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping back to k-group 0 when
// this is the last group.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
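// Operand B fragments are loaded at a coarser granularity than A: one fragment
// covers kNumKIterationsPerWarpBLoad compute iterations, so the next fragment
// is fetched only on the last compute iteration of the current one.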
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
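// Convert the quantized B fragment to the compute type and apply the
// per-column scales before issuing the warp-level MMA.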
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
}
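// On the second-to-last compute iteration, issue the last copy group for the
// next stage, then fence, wait, and advance the pipeline.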
if (warp_mma_k + 2 == Base::kWarpGemmIterations)
{
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Wait until all but kStages - 2 committed cp.async stages have completed.
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1))
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
}
else
{
++smem_write_stage_idx;
}
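// Wrap the warp-level read iterators back to the first stage of the circular
// buffer once they pass the last stage.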
if (smem_read_stage_idx == (Base::kStages - 1))
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
smem_read_stage_idx = 0;
}
else
{
++smem_read_stage_idx;
}
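// One threadblock k-tile has been consumed; once the K extent is exhausted,
// clear the masks so the drain iterations do not fetch out-of-bounds data.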
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
}
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
{
// Commit and drain all pending and predicated LDGSTS (cp.async) instructions from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}