in source/backend/cuda/execution/plugin/FmhaCommon/fused_multi_head_attention/kernel_forward.h [654:1152]
static void CUTLASS_DEVICE attention_kernel(Params& p) {
// In this block, we will only ever:
// - read query[query_start:query_end, :]
// - write to output[query_start:query_end, :]
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& mi = shared_storage.mi;
auto& out_rescale = shared_storage.out_rescale;
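    // m_prime, mi, s_prime and out_rescale hold the per-row state of the
    // streaming ("online") softmax: m_prime is the row max seen in previous
    // key blocks, mi the row max including the current block, s_prime the
    // running sum of exponentials, and out_rescale the correction factor
    // (roughly exp(m_prime - mi)) applied to the partial output whenever the
    // row max increases.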
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
out_rescale[thread_id()] = accum_t(1.0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
}
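    // Running register accumulator for the output tile `attn @ V`. With
    // kKeepOutputInRF it persists across key blocks and is only written out
    // by the final epilogue; otherwise it is cleared and flushed to global
    // memory (or the accumulation buffer) after every key block.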
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)p.o_strideM},
p.output_ptr,
typename OutputTileIterator::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
auto createOutputAccumIter = [&](int col) ->
typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{
(int32_t)(p.head_dim_value * p.num_heads)},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
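    // createOutputIter targets the final output, while createOutputAccumIter
    // targets `output_accum_ptr`, a higher-precision staging buffer for
    // partial results that later key blocks re-read and rescale (see the
    // epilogue dispatch below).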
#ifdef HAS_PYTORCH
curandStatePhilox4_32_10_t curand_state_init;
if (kSupportsDropout && p.use_dropout) {
const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
      // Each element of the attention matrix P with shape
      // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
      // offset in the RNG sequence. We initialize the RNG state with an
      // offset that starts at the beginning of the (n_queries, n_keys) matrix
      // for this block's batch_id and head_id.
      // Initializing the RNG state is very expensive, so we do it once per
      // kernel rather than once per iteration; each iteration takes a copy of
      // the initialized RNG state and offsets it as needed.
curand_init(
std::get<0>(seeds),
0,
std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
&curand_state_init);
}
#endif
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
int32_t problem_size_0_m =
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
int32_t problem_size_0_n = cutlass::fast_min(
int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
int32_t const& problem_size_0_k = p.head_dim;
int32_t const& problem_size_1_n = p.head_dim_value;
int32_t const& problem_size_1_k = problem_size_0_n;
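      // prologueV kicks off the global->shared copy of the V tile for column
      // block `blockN`; with kPreloadV enabled this lets the load overlap
      // with the Q.K_t GEMM and the softmax below.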
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1,
iterator_V,
thread_id(),
problem_size_1_k);
};
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
// updated from end of prev iter
//
// MATMUL: Q.K_t
//
// Computes the block-matrix product of:
// (a) query[query_start:query_end, :]
// with
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
// and stores that into `shared_storage.si`
//
// Compute threadblock location
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
cutlass::MatrixCoord tb_offset_A{
tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
cutlass::MatrixCoord tb_offset_B{
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
// Construct iterators to A and B operands
typename MM0::IteratorA iterator_A(
typename MM0::IteratorA::Params(
typename MM0::MmaCore::LayoutA(p.q_strideM)),
p.query_ptr,
{problem_size_0_m, problem_size_0_k},
thread_id(),
tb_offset_A);
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(p.k_strideM)),
p.key_ptr + iter_key_start * p.k_strideM,
{problem_size_0_k, problem_size_0_n},
thread_id(),
tb_offset_B);
auto my_warp_id = warp_uniform(warp_id());
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
typename MM0::Mma mma(
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
typename MM0::Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
iteratorC_tile_offset = {
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
(my_warp_id % MM0::Mma::WarpCount::kM),
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};
// multiply by scaling factor
if (kSupportsBias) {
accum =
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
}
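      // When kSupportsBias is false, the scaling is instead folded into the
      // softmax (see the last argument of iterative_softmax below).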
// apply attention bias if applicable
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
// load bias tile Bij into shared memory
typename MM0::BiasLoader::GmemTileIterator bias_iter(
{cutlass::layout::RowMajor(p.bias_strideM)},
          // p.attn_bias_ptr points to a matrix of size (n_queries, n_keys)
          // for the relevant batch_id and head_id
p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start,
{problem_size_0_m, problem_size_0_n},
thread_id());
cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
shared_storage.after_mm0.bias.data(),
cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
bias_tensor_ref, thread_id());
MM0::BiasLoader::load(bias_iter, smem_tile_iter);
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
my_lane_id, my_warp_id, iteratorC_tile_offset);
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
}
},
[&](int accum_m) {});
}
      // Mask out the last block if causal.
      // This is only needed if the upper-right corner of the current
      // query/key block intersects the mask. The upper-right corner of the
      // current block is at y = query_start,
      // x = min(iter_key_start + kKeysPerBlock, num_keys). The first masked
      // element in a row is x = y + offset -> query_start + offset, so there
      // is an intersection (and we need to mask) if
      // min(iter_key_start + kKeysPerBlock, num_keys) >= query_start + offset.
if (p.custom_mask_type &&
cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
(query_start + p.causal_diagonal_offset)) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
my_lane_id, my_warp_id, iteratorC_tile_offset);
int32_t last_col;
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
// last absolute col is (last absolute query + offset)
// last local col is (last absolute query + offset -
// iter_key_start)
last_col = query_start + accum_m + p.causal_diagonal_offset -
iter_key_start;
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
accum_o,
accum,
mi,
m_prime,
s_prime,
out_rescale,
shared_storage.addition_storage,
my_lane_id,
thread_id(),
my_warp_id,
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
auto output_tile_coords = cutlass::MatrixCoord{
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
MM0::B2bGemm::accumToSmem(
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
__syncthreads();
#ifdef HAS_PYTORCH
// apply dropout (if applicable) after we've written Pij to smem.
// dropout is applied by multiplying each element of Pij by:
// - 0 with probability dropout_p
// - 1 / (1 - dropout_p) with probability 1 - dropout_p
//
      // For backward purposes we want to be able to map each element of the
      // attention matrix to the same random uniform number as the one used in
      // the forward pass, without needing to use the same iteration order or
      // having to store the dropout matrix. It is possible to do this in
      // registers, but it ends up being very slow: because each thread owns
      // noncontiguous strips of the Pij tile, we would have to skip around a
      // lot and generate a single random number at a time.
if (kSupportsDropout && p.use_dropout) {
auto si = shared_storage.after_mm0.si.accum_ref();
        // Each thread handles a contiguous sequence of elements from Sij, all
        // coming from the same row. They must come from the same row because
        // sampling random numbers from a contiguous part of the sequence is
        // much more efficient than jumping around, and the linear offset of
        // each element of S (the global matrix) maps to an offset in the
        // random number sequence. For S, the end of a row and the beginning
        // of the next have adjacent offsets, but for Sij this is not
        // necessarily the case.
const int num_threads = blockDim.x * blockDim.y * blockDim.z;
const int threads_per_row =
cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
const int elts_per_thread = cutlass::round_nearest(
cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
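        // Hypothetical example: with 128 threads, problem_size_0_m = 32 and
        // problem_size_0_n = 64, we get threads_per_row = min(128 / 32, 64)
        // = 4 and elts_per_thread = round_nearest(ceil_div(64, 4), 4) = 16,
        // i.e. each thread rescales 16 contiguous elements of one row of Sij.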
const int thread_i = thread_id() / threads_per_row;
const int thread_start_j =
(thread_id() % threads_per_row) * elts_per_thread;
if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead(
static_cast<unsigned long long>(
(query_start + thread_i) * p.num_keys_absolute +
(iter_key_start + thread_start_j)),
&curand_state);
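          // skipahead advances the Philox stream to the linear offset of
          // element (query_start + thread_i, iter_key_start + thread_start_j)
          // of the global attention matrix, so the backward pass can
          // regenerate exactly the same uniforms without storing the mask.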
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
// apply dropout scaling to elements this thread is responsible for,
// in chunks of 4
for (int sij_start_col_idx = thread_start_j; sij_start_col_idx <
cutlass::fast_min(thread_start_j + elts_per_thread,
problem_size_0_n);
sij_start_col_idx += 4) {
const float4 rand_uniform_quad = curand_uniform4(&curand_state);
CUTLASS_PRAGMA_UNROLL
for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
si.at({thread_i, sij_start_col_idx + quad_idx}) *=
static_cast<scalar_t>(
dropout_scale *
((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
}
}
}
__syncthreads(); // p.use_dropout should have same value kernel-wide
}
#endif
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
// `attn` is read from shared memory (in `shared_storage_si`)
// `V` is read from global memory (with iterator_B)
//
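      // With kSingleValueIteration a single pass over the value columns is
      // enough (head_dim_value fits in one MM1 tile); otherwise we iterate
      // over column blocks of size MM1::ThreadblockShape::kN.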
const int64_t nBlockN = kSingleValueIteration
? 1
: ceil_div(
(int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
for (int blockN = 0; blockN < nBlockN; ++blockN) {
int gemm_k_iterations =
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add and store it in accum
// (in registers)
if (!kPreloadV) {
__syncthreads(); // we share shmem between mma and epilogue
}
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
// operand A: Pij_dropped in shared memory
shared_storage.after_mm0.si.accum_ref(),
// operand B: shared memory staging area for Vj, which is loaded
// from global memory
shared_storage.after_mm0.mm1.operand_B_ref(),
(int)thread_id(),
(int)my_warp_id,
(int)my_lane_id);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
}
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
__syncthreads();
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
prologueV(blockN + 1);
}
if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
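          // DISPATCH_BOOL lifts the runtime "first/last key block" flags to
          // compile-time kIsFirst/kIsLast so the epilogue can specialize:
          // intermediate key blocks read and write the higher-precision
          // accumulation buffer, while the last one normalizes by s_prime and
          // writes the final output_t (see MemoryEfficientAttentionNormalize).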
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= p.num_keys,
kIsLast,
([&] {
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp =
typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
kIsLast,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
ElementCompute,
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // Read
// iterator
>;
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter = call_conditional<
kIsLast,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
my_warp_id,
my_lane_id);
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
if (!kSingleValueIteration) {
__syncthreads();
}
}
}
__syncthreads(); // we modify `m_prime` after
}
if (kKeepOutputInRF) {
constexpr bool kIsFirst = true;
constexpr bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
output_t, // output
output_accum_t, // source
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator, // accum
output_accum_t, // compute
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MM1::OutputTileIterator, // destination
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}
    // 7. Calculate logsumexp
    // To make the backward pass easier, we pad logsumexp with `inf`; this
    // avoids a few bound checks and is no more expensive during the forward.
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
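      // `mi` is kept in base-2 log scale (iterative_softmax uses exp2 with a
      // log2(e)-scaled exponent), so dividing by kLog2e converts the running
      // row max back to natural log before adding log(s_prime).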
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
cutlass::platform::numeric_limits<accum_t>::infinity();
}
}
}