void operator()

in source/backend/cuda/execution/plugin/FmhaCommon/fused_multi_head_attention/fmha_grouped.h [494:884]
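Forward kernel for grouped (variable-length) fused multi-head attention. Each
threadblock persistently iterates over (problem, tile) pairs: per tile it
computes Q.K_t (MM0), folds the scores into an online softmax, and accumulates
attn @ V (MM1), with the final normalization applied in the epilogue.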


  void operator()(Params const &params, SharedStorage &shared_storage) {
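    // Per-query-row softmax state shared across key blocks:
    //   m_prime: running row max from previous key blocks
    //   mi:      row max including the current key block
    //   s_prime: running softmax denominator, sum of exp(s - mi)
    //   out_rescale: exp(m_prime - mi), rescales partial outputs whenever the
    //                row max grows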
    auto& m_prime = shared_storage.m_prime;
    auto& s_prime = shared_storage.s_prime;
    [[maybe_unused]] auto& si = shared_storage.after_mm0.si;
    auto& mi = shared_storage.mi;
    auto& out_rescale = shared_storage.out_rescale;

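    // The grouped problem visitor maps this threadblock (blockIdx.x) onto a
    // sequence of (problem, tile) pairs, so a fixed-size grid can service
    // attention problems of varying sizes.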
    ProblemVisitor problem_visitor(
      params.problem_visitor,
      shared_storage.problem_visitor,
      blockIdx.x);

    // Outer 'persistent' loop to iterate over tiles
    while (problem_visitor.next_tile()) {

      GemmCoord problem_size0 = problem_visitor.problem_size0();
      GemmCoord problem_size1 = problem_visitor.problem_size1();
      const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

      if (!TileParams::can_compute(threadblock_idx, problem_size0)) {
        problem_visitor.advance(gridDim.x);
        continue;
      }

      const int32_t problem_idx = problem_visitor.problem_index();

      if (thread_id() < kQueriesPerBlock) {
        s_prime[thread_id()] = ElementAccumulator(0);
        out_rescale[thread_id()] = accum_t(1.0);
        m_prime[thread_id()] =
            -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
        mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
      }

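      // Output pointers, offset to this tile's first query row. `ptr_O_accum`
      // points at a higher-precision staging buffer used when partial results
      // are written back between key blocks instead of kept in registers.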
      ElementO *ptr_O = params.ptr_O[problem_idx] +
          TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
      ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] +
          TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
      const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0);

      auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
        using OutputTileIterator = typename MM1::OutputTileIterator;
        return OutputTileIterator(
            typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]},
            ptr_O,
            typename OutputTileIterator::TensorCoord{
                num_queries, problem_size1.n()},
            thread_id(),
            {0, col});
      };

      auto createOutputAccumIter = [&](int col) ->
        typename MM1::OutputTileIteratorAccum {
          using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
          return OutputTileIteratorAccum(
              typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]},
              ptr_O_accum,
              typename OutputTileIteratorAccum::TensorCoord{
                  num_queries, problem_size1.n()},
              thread_id(),
              {0, col});
        };

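      // accum_o accumulates the output tile attn @ V in registers; with
      // kKeepOutputInRF it survives all key blocks and is written out once.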
      typename MM1::Mma::FragmentC accum_o;
      accum_o.clear();

      const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal);

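      // Stream over the keys in blocks of kKeysPerBlock, folding each block
      // into the running softmax statistics and the partial output
      // (FlashAttention-style online softmax).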
      for (int32_t iter_key_start = 0; iter_key_start < num_keys;
           iter_key_start += kKeysPerBlock) {
        int32_t problem_size_0_m =
            cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries);
        int32_t problem_size_0_n = cutlass::fast_min(
            (int32_t)kKeysPerBlock, num_keys - iter_key_start);
        int32_t const& problem_size_0_k = problem_size0.k();
        int32_t const& problem_size_1_n = problem_size1.n();
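        // GEMM1 (attn @ V) contracts over this block's keys, so its K extent
        // equals GEMM0's N extent: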
        int32_t const& problem_size_1_k = problem_size_0_n;

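        // Issues the cp.async loads that stage a tile of V into shared memory
        // ahead of the attn @ V mma.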
        auto prologueV = [&](int blockN) {
          typename MM1::Mma::IteratorB iterator_V(
              typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
              params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
              {problem_size_1_k, problem_size_1_n},
              thread_id(),
              cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

          MM1::Mma::prologue(
              shared_storage.after_mm0.mm1,
              iterator_V,
              thread_id(),
              problem_size_1_k);
        };

        __syncthreads(); // Need to have shared memory initialized, and `m_prime`
                         // updated from end of prev iter

        //
        // MATMUL: Q.K_t
        //
        // Computes the block-matrix product of:
        // (a) query[query_start:query_end, :]
        // with
        // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
        // and stores that into `shared_storage.si`
        //

        ElementQ *ptr_Q = params.ptr_Q[problem_idx] +
            TileParams::query_start(threadblock_idx) * params.ldq[problem_idx];

        // Construct iterators to A and B operands
        typename MM0::IteratorA iterator_A(
          typename MM0::IteratorA::Params(
              typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])),
          ptr_Q,
          {problem_size_0_m, problem_size_0_k},
          thread_id(),
          {0, 0});

        typename MM0::IteratorB iterator_B(
            typename MM0::IteratorB::Params(
                typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])),
            params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx],
            {problem_size_0_k, problem_size_0_n},
            thread_id(),
            {0, 0});

        // Construct threadblock-scoped matrix multiply
        typename MM0::Mma mma(
            shared_storage.mm0, thread_id(), warp_id(), lane_id());

        typename MM0::Mma::FragmentC accum;

        accum.clear();

        auto gemm_k_iterations =
            (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add
        mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
        __syncthreads();

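        // Either start prefetching V for the upcoming attn @ V mma, or drain
        // any cp.async transactions still in flight before shared memory is
        // reused.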
        if (kPreloadV) {
          prologueV(0);
        } else {
          MM1::Mma::drain_cp_asyncs();
        }

        typename MM0::Mma::Operator::IteratorC::TensorCoord
          iteratorC_tile_offset = {
              (warp_id() % MM0::Mma::WarpCount::kM),
              (warp_id() / MM0::Mma::WarpCount::kM)
            };

        // Causal masking: `num_keys` is already clamped per tile, so only the
        // final key block can straddle the diagonal. Within it, set scores
        // whose key index exceeds the query's global index to -infinity.
        if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
          auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
              lane_id(), warp_id(), iteratorC_tile_offset);
          int32_t last_col;
          MM0::AccumLambdaIterator::iterateRows(
              lane_offset,
              [&](int accum_m) {
                last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start;
              },
              [&](int accum_m, int accum_n, int idx) {
                if (accum_n > last_col) {
                  accum[idx] =
                      -cutlass::platform::numeric_limits<accum_t>::infinity();
                }
              },
              [&](int accum_m) {});
        }

        // Update `mi` from accum stored in registers
        // Also does accum[i] <- exp(accum[i] - mi)
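        // Sketch of the per-row update performed here:
        //   mi          <- max(m_prime, rowmax(accum))
        //   accum       <- exp(accum - mi)
        //   out_rescale <- exp(m_prime - mi)
        //   s_prime     <- s_prime * out_rescale + rowsum(accum)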
        iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
            accum_o,
            accum,
            mi,
            m_prime,
            s_prime,
            out_rescale,
            shared_storage.addition_storage,
            lane_id(),
            thread_id(),
            warp_id(),
            num_keys - iter_key_start,
            iter_key_start == 0,
            iteratorC_tile_offset,
            kSupportsBias ? 1.0f : params.scale);

        // Stage the (still unnormalized) attention weights exp(s - mi) into
        // shared memory, where they become the A operand of GEMM1
        int warp_idx_mn_0 = warp_id() %
            (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
        auto output_tile_coords = cutlass::MatrixCoord{
            warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
            warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};

        MM0::B2bGemm::accumToSmem(
            shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords);

        __syncthreads();

        //
        // MATMUL: Attn . V
        // Run the matmul `attn @ V` for a block of attn and V.
        // `attn` is read from shared memory (`shared_storage.after_mm0.si`)
        // `V` is read from global memory (with iterator_B)
        //

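        // With kKeepOutputInRF the whole output tile fits in registers and one
        // pass suffices; otherwise walk GEMM1's N dimension in
        // ThreadblockShape::kN-wide slices, each with its own epilogue.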
        const int64_t nBlockN = kKeepOutputInRF ? 1
                                                : ceil_div(
                                                      (int64_t)problem_size_1_n,
                                                      int64_t(MM1::ThreadblockShape::kN));

        // Iterate over the N dimension of GEMM1
        for (int blockN = 0; blockN < nBlockN; ++blockN) {
          int gemm_k_iterations =
              (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;

          // Compute threadblock-scoped matrix multiply-add and store it in accum
          // (in registers)
          if (!kPreloadV) {
            __syncthreads(); // we share shmem between mma and epilogue
          }

          typename MM1::Mma::IteratorB iterator_V(
            typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
            params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
            {problem_size_1_k, problem_size_1_n},
            thread_id(),
            cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

          typename MM1::Mma mma_pv(
            // operand A: Pij_dropped in shared memory
            shared_storage.after_mm0.si.accum_ref(),
            // operand B: shared memory staging area for Vj, which is loaded
            // from global memory
            shared_storage.after_mm0.mm1.operand_B_ref(),
            (int)thread_id(),
            (int)warp_id(),
            (int)lane_id());

          mma_pv.set_prologue_done(kPreloadV);
          if (!kKeepOutputInRF) {
            accum_o.clear();
          }

          mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
          __syncthreads();

          if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) {
            prologueV(blockN + 1);
          }

          if (!kKeepOutputInRF) {
            MM1::Mma::drain_cp_asyncs();
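            // Specialize the epilogue on whether this is the first and/or last
            // key block: the first pass has no prior partial output to read
            // back, and the last pass divides by s_prime and converts to
            // output_t; intermediate passes rescale and stay in output_accum_t.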
            DISPATCH_BOOL(
                iter_key_start == 0, kIsFirst, ([&] {
                  DISPATCH_BOOL(
                      (iter_key_start + kKeysPerBlock) >= num_keys,
                      kIsLast,
                      ([&] {
                        using DefaultEpilogue = typename MM1::DefaultEpilogue;
                        using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
                        using ElementCompute = typename DefaultOp::ElementCompute;
                        using EpilogueOutputOp = typename cutlass::epilogue::
                            thread::MemoryEfficientAttentionNormalize<
                                typename cutlass::platform::conditional<
                                    kIsLast,
                                    output_t,
                                    output_accum_t>::type,
                                output_accum_t,
                                DefaultOp::kCount,
                                typename DefaultOp::ElementAccumulator,
                                output_accum_t,
                                kIsFirst,
                                kIsLast,
                                cutlass::Array<ElementCompute, kQueriesPerBlock>>;
                        using Epilogue = typename cutlass::epilogue::threadblock::
                            EpiloguePipelined<
                                typename DefaultEpilogue::Shape,
                                typename MM1::Mma::Operator,
                                DefaultEpilogue::kPartitionsK,
                                typename cutlass::platform::conditional<
                                    kIsLast,
                                    typename MM1::OutputTileIterator,
                                    typename MM1::OutputTileIteratorAccum>::type,
                                typename DefaultEpilogue::
                                    AccumulatorFragmentIterator,
                                typename DefaultEpilogue::WarpTileIterator,
                                typename DefaultEpilogue::SharedLoadIterator,
                                EpilogueOutputOp,
                                typename DefaultEpilogue::Padding,
                                DefaultEpilogue::kFragmentsPerIteration,
                                true, // IterationsUnroll
                                typename MM1::OutputTileIteratorAccum // Read
                                                                      // iterator
                                >;

                        int col = blockN * MM1::Mma::Shape::kN;
                        auto source_iter = createOutputAccumIter(col);
                        auto dest_iter = gemm_kernel_utils::call_conditional<
                            kIsLast,
                            decltype(createOutputIter),
                            decltype(createOutputAccumIter)>::
                            apply(createOutputIter, createOutputAccumIter, col);
                        EpilogueOutputOp rescale(s_prime, out_rescale);
                        Epilogue epilogue(
                            shared_storage.epilogue_shared_storage(),
                            thread_id(),
                            warp_id(),
                            lane_id());
                        epilogue(rescale, dest_iter, accum_o, source_iter);
                      }));
                }));
            __syncthreads();
          }
        }
        __syncthreads(); // everyone must be done with `m_prime` before the
                         // next key block updates it
      }

      if (kKeepOutputInRF) {
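        // The output never left registers, so run a single epilogue that acts
        // as both the first and the last pass (read no source, normalize by
        // s_prime, write output_t).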
        const bool kIsFirst = true;
        const bool kIsLast = true;
        using DefaultEpilogue = typename MM1::DefaultEpilogue;
        using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
        using ElementCompute = typename DefaultOp::ElementCompute;
        using EpilogueOutputOp =
            typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
                output_t,       // output
                output_accum_t, // source
                DefaultOp::kCount,
                typename DefaultOp::ElementAccumulator, // accum
                output_accum_t, // compute
                kIsFirst,
                kIsLast,
                cutlass::Array<ElementCompute, kQueriesPerBlock>>;
        using Epilogue =
            typename cutlass::epilogue::threadblock::EpiloguePipelined<
                typename DefaultEpilogue::Shape,
                typename MM1::Mma::Operator,
                DefaultEpilogue::kPartitionsK,
                typename MM1::OutputTileIterator, // destination
                typename DefaultEpilogue::AccumulatorFragmentIterator,
                typename DefaultEpilogue::WarpTileIterator,
                typename DefaultEpilogue::SharedLoadIterator,
                EpilogueOutputOp,
                typename DefaultEpilogue::Padding,
                DefaultEpilogue::kFragmentsPerIteration,
                true, // IterationsUnroll
                typename MM1::OutputTileIteratorAccum // source tile
                >;
        auto dest_iter = createOutputIter(0);
        EpilogueOutputOp rescale(s_prime, out_rescale);
        Epilogue epilogue(
            shared_storage.epilogue_shared_storage(),
            thread_id(),
            warp_id(),
            lane_id());
        MM1::Mma::drain_cp_asyncs();
        epilogue(rescale, dest_iter, accum_o);
      }

      // Next tile
      problem_visitor.advance(gridDim.x);
      __syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
    }
  }