in maga_transformer/cpp/cutlass/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp [185:424]
CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
{
using namespace cute;
using X = Underscore;
static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");
auto synchronize = [&]()
{ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
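// The named barrier spans all threads of the tiled MMA; it is used below to order the
// shared-memory write (R2S) against the shared-memory read (S2R) of each epilogue subtile.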
// Separate out problem shape for convenience
auto M = get<0>(problem_shape_mnkl);
auto N = get<1>(problem_shape_mnkl);
auto L = get<3>(problem_shape_mnkl);
auto mma_tile_m = tile_size<0>(tiled_mma);
auto mma_tile_n = tile_size<1>(tiled_mma);
auto epi_tile_m = size<0>(EpilogueTile{});
auto epi_tile_n = size<1>(EpilogueTile{});
CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");
// Batches/groups are handled by offsetting the C and D pointers per group, so the tensors below use a batch extent of 1
int32_t const mock_L = 1;
int32_t const mock_l_coord = 0;
// Slice to get the tile this CTA is responsible for
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
// If scalar alpha/beta are provided, the same alpha/beta applies to all batches/groups.
// If pointers to alpha/beta are provided, alpha/beta can differ between batches/groups;
// in that case the values for the current batch/group are selected using the group index.
ThreadEpilogueOp epilogue_op(params.thread, l_coord);
SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
Tensor sD = as_position_independent_swizzle_tensor(sD_);
// Lambda mapping a row of this group's output to its scattered row in the final output
auto& num_rows = params.num_rows_in_final_output;
auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
auto get_scatter_idx = [&](auto i)
{
auto scatter = read_scatter_map(i);
int quot, rem;
num_rows(quot, rem, scatter);
return rem;
};
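// Note: num_rows_in_final_output is used as a divmod functor here (likely a cutlass::FastDivmod);
// taking the remainder wraps the gathered index into the row range of the final output tensor.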
// Represent the full output tensor
ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l)
Tensor mD_mnl = make_gather_tensor(
make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l)
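// D is a gather tensor: every store below goes through get_scatter_idx, so each row of this
// group's tile is written directly to its destination row of the final output.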
// Use a placeholder shape for the bias; only the stride matters
bool const is_bias_needed = params.ptr_bias != nullptr;
Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
Tensor mScale_mnl = make_tensor(
make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);
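// Per-row scale, offset to this group's rows; in an MoE finalize epilogue this is typically
// the router weight applied to each token's partial result (an assumption about the caller).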
Tensor gC_mnl
= local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl
= local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gBias_mnl
= local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gScale_mnl
= local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N)
Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
// Get the smallest tiled copy we can use to retile the accumulators
TiledCopy tiled_copy_C_atom
= make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);
auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N)
Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)
// Make a tiled copy vectorized along the major direction of D
auto tiled_s2r = [&]()
{
if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
Layout<Shape<_1, Int<AlignmentD>>>{});
}
else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
{
constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
return make_tiled_copy(CopyAtomS2R{},
Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
Layout<Shape<Int<AlignmentD>, _1>>{});
}
else
{
static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
}
}();
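// Sketch of the resulting thread arrangement (illustrative numbers, not taken from this configuration):
// for a K-major D with epi_tile_n = 64 and AlignmentD = 8, NumThreadsMajor = 8 threads span the
// N extent of a subtile and each thread moves 8 contiguous elements per copy, so consecutive
// thread ids cover consecutive 8-element chunks along the contiguous dimension of D.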
auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
// Allocate intermediate registers for a single subtile
Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)
// Make an identity coordinate tensor for predicating our output MN tile
Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
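// tSR_cD carries the (m,n) coordinate of every element this thread touches; comparing it against
// residue_mnk below masks off out-of-bounds rows/columns of a partial tile.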
// Epilogue subtile loop
CUTLASS_PRAGMA_UNROLL
for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
{
CUTLASS_PRAGMA_UNROLL
for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
{
int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);
int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
int r2s_v = epi_n_in_mma * size(tRS_rD);
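// Each MMA tile spans mma_tile_n / epi_tile_n epilogue subtiles along N; epi_n_in_mma selects
// which of them we are in, and r2s_v is the offset of that subtile's fragment within the
// accumulator vector, so the loop below copies exactly size(tRS_rD) values.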
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
{
tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
}
copy(tiled_r2s, tRS_rD, tRS_sD);
synchronize();
copy(tiled_s2r, tSR_sD, tSR_rD);
synchronize();
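// Round trip through shared memory: accumulators are staged in the MMA-friendly layout (R2S) and
// read back in the layout of the vectorized gmem copy (S2R); the two barriers keep the write and
// read of the shared buffer from overlapping across subtile iterations.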
Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);
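// Store path: predicate on residue_mnk, apply the thread epilogue (with or without the C source
// in the two branches below), optionally add bias, scale the result, and write the converted
// values directly to their gathered locations in gmem.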
if (epilogue_op.is_source_needed())
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
else
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < size<1>(tSR_rD); ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < size<2>(tSR_rD); ++n)
{
if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
{
if (is_bias_needed)
{
copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
}
copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size<0>(tSR_rD); ++i)
{
auto epi_value = epilogue_op(tSR_rD(i, m, n));
if (is_bias_needed)
{
epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
}
tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
}
copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
}
}
}
}
}
}
}