CUTLASS_HOST_DEVICE void operator()

in maga_transformer/cpp/cutlass/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp [185:424]
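At a high level, this epilogue stages each accumulator subtile through shared memory, applies the thread-level epilogue (optionally combining with the per-group source C and a bias), multiplies by the group's scale values, and scatters the finished rows into the final output tensor through params.scatter_index.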


    CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
        BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
        ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
    {
        using namespace cute;
        using X = Underscore;

        static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
        static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
        static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
        static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

        // Barrier over all threads of the tiled MMA, used to order the shared-memory staging phases below
        auto synchronize = [&]()
        { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };

        // Separate out problem shape for convenience
        auto M = get<0>(problem_shape_mnkl);
        auto N = get<1>(problem_shape_mnkl);
        auto L = get<3>(problem_shape_mnkl);

        auto mma_tile_m = tile_size<0>(tiled_mma);
        auto mma_tile_n = tile_size<1>(tiled_mma);
        auto epi_tile_m = size<0>(EpilogueTile{});
        auto epi_tile_n = size<1>(EpilogueTile{});

        CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
        CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");

        // Batches are managed by using appropriate pointers to C and D matrices
        int32_t const mock_L = 1;
        int32_t const mock_l_coord = 0;

        // Slice to get the tile this CTA is responsible for
        auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;

        // If scalar alpha/beta are provided, the same alpha/beta applies to all batches/groups.
        // If pointers to alpha/beta are provided, alpha/beta can differ between batches/groups,
        // and the correct values for the current batch/group are selected using the group index.
        ThreadEpilogueOp epilogue_op(params.thread, l_coord);

        SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
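        // Fold the shared-memory swizzle into the tensor layout (position-independent view) so partitioning composes correctly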
        Tensor sD = as_position_independent_swizzle_tensor(sD_);

        // Function to scatter output rows
        auto& num_rows = params.num_rows_in_final_output;
        auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
        auto get_scatter_idx = [&](auto i)
        {
            auto scatter = read_scatter_map(i);
            int quot, rem;
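            // Fast divmod against the number of rows in the final output; only the remainder is used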
            num_rows(quot, rem, scatter);
            return rem;
        };

        // Represent the full output tensor
        ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
        auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
        Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC);        // (m,n,l)
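        // D is addressed through a gather tensor: rows of this group's tile are written to their scattered
        // positions in the final output, as selected by get_scatter_idx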
        Tensor mD_mnl = make_gather_tensor(
            make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)

        // Use a fake shape for bias; it doesn't matter
        bool const is_bias_needed = params.ptr_bias != nullptr;
        Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
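        // Scale values for this group; the base pointer is advanced by the group's offset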
        Tensor mScale_mnl = make_tensor(
            make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);

        Tensor gC_mnl
            = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
        Tensor gD_mnl
            = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)

        Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord);                        // (BLK_M,BLK_N)
        Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord);                        // (BLK_M,BLK_N)

        Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
        Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)

        Tensor gBias_mnl
            = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N,m,n,l)
        Tensor gScale_mnl
            = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)

        Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord);                           // (BLK_M,BLK_N)
        Tensor gScale = gScale_mnl(_, _, m_coord, n_coord);                                  // (BLK_M,BLK_N)

        Tensor gBias_epi = flat_divide(gBias, EpilogueTile{});   // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
        Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)

        // Get the smallest tiled copy we can use to retile the accumulators
        TiledCopy tiled_copy_C_atom
            = make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
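        // Register-to-shared tiled copy used to stage accumulator subtiles into the smem buffer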
        TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);

        auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
        Tensor tRS_rAcc = thread_r2s.retile_S(accumulators);            // ((R2S,R2S_V),MMA_M,MMA_N)
        Tensor tRS_sD = thread_r2s.partition_D(sD);                     // ((R2S,R2S_V),R2S_M,R2S_N)
        Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)

        // Make a tiled copy vectorized along major direction of D
        auto tiled_s2r = [&]()
        {
            if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
            {
                constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
                constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
                return make_tiled_copy(CopyAtomS2R{},
                    Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
                    Layout<Shape<_1, Int<AlignmentD>>>{});
            }
            else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
            {
                constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
                constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
                return make_tiled_copy(CopyAtomS2R{},
                    Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
                    Layout<Shape<Int<AlignmentD>, _1>>{});
            }
            else
            {
                static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
            }
        }();

        auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
        Tensor tSR_sD = thread_s2r.partition_S(sD);             // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_gD = thread_s2r.partition_D(gD_epi);         // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
        Tensor tSR_gC = thread_s2r.partition_D(gC_epi);         // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
        Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi);   // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
        Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)

        // Allocate intermediate registers for a single subtile
        Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD)));        // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD));                        // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD));                              // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout());    // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)

        // Make an identity coordinate tensor for predicating our output MN tile
        Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
        Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
        Tensor tSR_cD = thread_s2r.partition_D(cD_epi);  // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)

        // Epilogue subtile loop: for each (EPI_TILE_M, EPI_TILE_N) subtile, stage through smem, finalize, and scatter to gmem
        CUTLASS_PRAGMA_UNROLL
        for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
        {
            CUTLASS_PRAGMA_UNROLL
            for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
            {
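                // Locate the MMA tile that produced this epilogue subtile and the value offset within it.
                // E.g. with epi_tile_n = 32 and mma_tile_n = 64, epi_n = 3 maps to mma_n = 1 and epi_n_in_mma = 1.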
                int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
                int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
                Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);

                int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
                int r2s_v = epi_n_in_mma * size(tRS_rD);
                CUTLASS_PRAGMA_UNROLL
                for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
                {
                    tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
                }

                // Stage this subtile through shared memory: registers -> smem, sync, then smem -> registers
                // retiled for vectorized stores; the second sync keeps the buffer intact until all threads have read it
                copy(tiled_r2s, tRS_rD, tRS_sD);
                synchronize();

                copy(tiled_s2r, tSR_sD, tSR_rD);
                synchronize();

                Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
                Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
                Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
                Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
                Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);

                if (epilogue_op.is_source_needed())
                {
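                    // Source C is needed: for each in-bounds element, load C, the optional bias and the scale,
                    // then compute D = scale * (epilogue_op(acc, C) + bias) and write it to the scattered location.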
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < size<1>(tSR_rD); ++m)
                    {
                        CUTLASS_PRAGMA_UNROLL
                        for (int n = 0; n < size<2>(tSR_rD); ++n)
                        {
                            if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
                            {
                                copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
                                if (is_bias_needed)
                                {
                                    copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
                                }
                                copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
                                CUTLASS_PRAGMA_UNROLL
                                for (int i = 0; i < size<0>(tSR_rD); ++i)
                                {
                                    auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
                                    if (is_bias_needed)
                                    {
                                        epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
                                    }
                                    tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
                                }
                                copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
                            }
                        }
                    }
                }
                else
                {
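                    // No source C: same as above, with D = scale * (epilogue_op(acc) + bias).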
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < size<1>(tSR_rD); ++m)
                    {
                        CUTLASS_PRAGMA_UNROLL
                        for (int n = 0; n < size<2>(tSR_rD); ++n)
                        {
                            if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
                            {
                                if (is_bias_needed)
                                {
                                    copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
                                }
                                copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
                                CUTLASS_PRAGMA_UNROLL
                                for (int i = 0; i < size<0>(tSR_rD); ++i)
                                {
                                    auto epi_value = epilogue_op(tSR_rD(i, m, n));
                                    if (is_bias_needed)
                                    {
                                        epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
                                    }
                                    tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
                                }
                                copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
                            }
                        }
                    }
                }
            }
        }
    }