void operator()

in maga_transformer/cpp/cutlass/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h [371:682]


    /// Threadblock-scoped multistage GEMM mainloop with fine-grained dequantization of the B operand.
    ///
    /// Outline:
    ///   1. Prologue: put `Base::kStages - 1` stages in flight. Each stage issues cp.async copies of the
    ///      A tile and the B tile into shared memory, copies the scales via copy_scales_and_advance(),
    ///      then commits one cp.async group (fence).
    ///   2. Initialize `accum` from `src_accum`. With SharedMemoryClearOption::kClearLastStage, also zero
    ///      the A/B shared-memory stage that will be written next.
    ///   3. Wait until the oldest stage has landed. Load the first A/B warp fragments and the first
    ///      scales/zeros.
    ///   4. Mainloop, for every warp-level k-group:
    ///        - prefetch the next A fragment, and the next B fragment once per kNumKIterationsPerWarpBLoad
    ///          k-groups;
    ///        - convert the current B fragment with TransformBAfterLDS and dequantize it with
    ///          warp_dequantizer_;
    ///        - run the warp MMA;
    ///        - interleave the global->shared copies for a future stage.
    ///      The shared-memory write and read stage indices wrap around as a circular buffer of
    ///      Base::kStages stages.
    ///   5. With SharedMemoryClearOption::kZfill, commit and drain every outstanding cp.async before
    ///      returning.
    ///
    /// The mainloop runs while gemm_k_iterations > 1 - Base::kStages, so stages already in flight are
    /// still consumed after the last K iteration has been issued. The global iterators get
    /// clear_mask(gemm_k_iterations == 0) applied. That presumably makes valid() false, so the
    /// remaining copies become zero-fill copies (confirm in the iterator implementation).
    void operator()(
        ///< number of threadblock-level K iterations to run (taken by value, decremented locally)
        int gemm_k_iterations,
        ///< destination accumulator tile; overwritten with src_accum before the mainloop
        FragmentC& accum,
        ///< iterator over A operand in global memory (taken by value, advanced locally)
        IteratorA iterator_A,
        ///< iterator over B operand in global memory (taken by value, advanced locally)
        IteratorB iterator_B,
        ///< iterator over scale operand in global memory (taken by value, advanced locally)
        IteratorScale iterator_scale,
        ///< initial value of accumulator
        FragmentC const& src_accum)
    {

        //
        // Prologue
        //

        // Applied to every B warp fragment after it is loaded from shared memory, before dequantization.
        TransformBAfterLDS lds_converter;

        // Issue several complete stages (Base::kStages - 1 of them).
        // gemm_k_iterations is decremented once per issued stage.
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
        {

            // When no K iterations remain, clear the global iterators' masks for the remaining stages.
            iterator_A.clear_mask(gemm_k_iterations == 0);
            iterator_B.clear_mask(gemm_k_iterations == 0);
            iterator_scale.clear_mask(gemm_k_iterations == 0);

            iterator_A.set_iteration_index(0);
            this->smem_iterator_A_.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    // Bytes moved by a single cp.async: one access' worth of A elements.
                    int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                        * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                    // valid() is passed as the copy predicate of cp_async_zfill.
                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                        dst_ptr + v, iterator_A.get(), iterator_A.valid());

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }

            iterator_B.set_iteration_index(0);
            this->smem_iterator_B_.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    // Bytes moved by a single cp.async: one access' worth of B elements.
                    int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                        * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                    // valid() is passed as the copy predicate of cp_async_zfill.
                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, iterator_B.get(), iterator_B.valid());

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }

            // Copy this stage's scales. The helper receives the stage index and the remaining K iterations.
            copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);

            // Move to the next stage
            iterator_A.add_tile_offset({0, 1});
            iterator_B.add_tile_offset({1, 0});

            this->smem_iterator_A_.add_tile_offset({0, 1});
            this->smem_iterator_B_.add_tile_offset({1, 0});

            // Defines the boundary of a stage of cp.async.
            cutlass::arch::cp_async_fence();
        }

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        //
        // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
        // so that all accumulator elements outside the GEMM footprint are zero.
        //
        // After the prologue, smem_iterator_A_/smem_iterator_B_ point at the next write stage
        // (index Base::kStages - 1). That stage's A and B tiles are zeroed with plain stores (not cp.async).
        // The scales stage is not cleared here.
        //

        if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
        {

            /// Iterator to write threadblock-scoped tile of A operand to shared memory
            SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

            typename IteratorA::AccessType zero_A;
            zero_A.clear();

            last_smem_iterator_A.set_iteration_index(0);

            // Zero-fill the A tile of the next write stage.
            // NOTE(review): the prologue writes kAccessesPerVector accesses per iteration (dst_ptr + v),
            // but only *dst_ptr is zeroed here. The two are equivalent when kAccessesPerVector == 1;
            // verify for other configurations.
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {

                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());

                *dst_ptr = zero_A;

                ++last_smem_iterator_A;
            }

            /// Iterator to write threadblock-scoped tile of B operand to shared memory
            SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
            typename IteratorB::AccessType zero_B;

            zero_B.clear();
            last_smem_iterator_B.set_iteration_index(0);

            // Zero-fill the B tile of the next write stage (same per-vector caveat as for A above).
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {

                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());

                *dst_ptr = zero_B;

                ++last_smem_iterator_B;
            }
        }

        // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
        cutlass::arch::cp_async_wait<Base::kStages - 2>();
        __syncthreads();

        // Pair of fragments used to overlap shared memory loads and math
        // instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];
        // Scales/zeros for the shared-memory stage currently being read. They are reloaded after each stage.
        typename Dequantizer::FragmentScale warp_frag_scales;
        typename Dequantizer::FragmentZero warp_frag_zeros;

        Operator warp_mma;

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;
        // Advance the dequantizer by one stage of scales (Shape::kN entries).
        warp_dequantizer_.add_pointer_offset(Shape::kN);

        iterator_A.clear_mask(gemm_k_iterations == 0);
        iterator_B.clear_mask(gemm_k_iterations == 0);
        iterator_scale.clear_mask(gemm_k_iterations == 0);

        // Circular-buffer positions in shared memory. The prologue filled stages [0, kStages - 2],
        // so the next write goes to the last stage and reading starts at stage 0.
        int smem_write_stage_idx = Base::kStages - 1;
        int smem_read_stage_idx = 0;

        //
        // Mainloop
        //

        // Keeps iterating after gemm_k_iterations reaches 0 so the Base::kStages - 1 stages
        // still in flight are consumed.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > (-Base::kStages + 1);)
        {
            //
            // Loop over GEMM K dimension
            //

            // Computes a warp-level GEMM on data held in shared memory
            // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {

                // Load warp-level tiles from shared memory, wrapping to k offset if
                // this is the last group as the case may be.

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                // A B fragment serves kNumKIterationsPerWarpBLoad consecutive k-groups.
                // The next B fragment is prefetched on the last k-group that uses the current one.
                const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                // Convert the current B fragment, then dequantize it with this stage's scales/zeros.
                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);

                // warp_tileB_k_compute_offset presumably selects the k-slice of converted_frag_B used by
                // this k-group; confirm in run_warp_mma.
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);

                // Issue global->shared copies for the this stage
                if (warp_mma_k < Base::kWarpGemmIterations - 1)
                {
                    int group_start_iteration_A, group_start_iteration_B;

                    group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(
                        iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B);

                    // This is the first group of a given stage, so we issue the loads for the B scales immediately.
                    if (group_start_iteration_B == 0)
                    {
                        copy_scales_and_advance(iterator_scale);
                    }
                }

                // Second-to-last k-group: issue the final copy group of the stage and close the stage.
                if (warp_mma_k + 2 == Base::kWarpGemmIterations)
                {
                    int group_start_iteration_A, group_start_iteration_B;
                    group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(
                        iterator_A, iterator_B, iterator_scale, group_start_iteration_A, group_start_iteration_B);

                    // Inserts a memory fence between stages of cp.async instructions.
                    cutlass::arch::cp_async_fence();

                    // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 -
                    // #committed)
                    arch::cp_async_wait<Base::kStages - 2>();
                    __syncthreads();

                    // Move to the next stage
                    iterator_A.add_tile_offset({0, 1});
                    iterator_B.add_tile_offset({1, 0});

                    this->smem_iterator_A_.add_tile_offset({0, 1});
                    this->smem_iterator_B_.add_tile_offset({1, 0});

                    // Add negative offsets to return iterators to the 'start' of the
                    // circular buffer in shared memory
                    if (smem_write_stage_idx == (Base::kStages - 1))
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                        this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
                        smem_write_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_write_stage_idx;
                    }

                    // Wrap the warp-level read iterators (and the dequantizer's scale pointer) back to stage 0.
                    if (smem_read_stage_idx == (Base::kStages - 1))
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                        warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
                        smem_read_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_read_stage_idx;
                    }

                    --gemm_k_iterations;
                    iterator_A.clear_mask(gemm_k_iterations == 0);
                    iterator_B.clear_mask(gemm_k_iterations == 0);
                    iterator_scale.clear_mask(gemm_k_iterations == 0);
                }
            }

            // Load the scale needed for the next tile iteration.
            warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
            // Update internal pointer to set of scales in shared memory.
            warp_dequantizer_.add_pointer_offset(Shape::kN);
        }

        if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
        {
            // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop
            cutlass::arch::cp_async_fence();
            cutlass::arch::cp_async_wait<0>();
            __syncthreads();
        }
    }