void run_fmha_fwd_loop()

in kernels/fmha_fwd_launch_template.h [47:91]


// Host-side launcher for fmha_fwd_loop_kernel.
//
// Selects the kernel instantiation matching the runtime flags in
// `launch_params` (dropout, causal, return_softmax), opts in to large dynamic
// shared memory when required, picks `params.num_splits` automatically when the
// caller left it <= 0, and launches on `launch_params.stream` with
//   grid  = (params.b, params.h, params.num_splits)
//   block = Kernel_traits::THREADS threads
//   dynamic shared memory = smem_size bytes (computed below).
//
// Side effect: may overwrite `launch_params.params.num_splits`.
// Errors from the CUDA runtime calls are reported through FMHA_CHECK_CUDA.
void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
    // Number of blocksize_c-wide steps needed to cover seqlen_k (ceil-division).
    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;

    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    // Don't need smem_size_softmax_lse if we're not looping
    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
        + (loop_steps > 1 ? smem_size_softmax_lse : 0);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21
    // Only the dropout flag goes through BOOL_SWITCH; the causal and
    // return_softmax flags are dispatched with plain ternaries instead.
    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ([&] {
        // Third template argument follows params.is_causal, fourth follows
        // launch_params.return_softmax.
        auto kernel = launch_params.params.is_causal
            ? (launch_params.return_softmax
               ? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
               : &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
            : (launch_params.return_softmax
               ? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
               : &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
        // Dynamic shared memory of 48 KB or more must be opted in per kernel.
        if( smem_size >= 48 * 1024 ) {
            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        // Automatically set num_splits to maximize occupancy
        if (launch_params.params.num_splits <= 0) {
            int ctas_per_sm = 0;
            // Check the result: on failure ctas_per_sm would otherwise be fed
            // to the heuristic without ever having been written.
            FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
            constexpr int M = Kernel_traits::Cta_tile_p::M;
            // Number of M-row tiles covering seqlen_q (ceil-division). The
            // parentheses matter: `seqlen_q + M - 1 / M` would evaluate 1 / M
            // first and degrade to `seqlen_q + M`.
            const int num_q_tiles = (launch_params.params.seqlen_q + M - 1) / M;
            // Cap the split count at 30 and at the tile count, but never pass
            // a limit below 1 to the heuristic.
            const int max_splits = std::max(1, std::min(30, num_q_tiles));
            launch_params.params.num_splits = num_splits_heuristic_fwd(
                launch_params.params.b * launch_params.params.h, launch_params.multi_processor_count,
                ctas_per_sm,
                /*max_splits=*/max_splits
            );
        }
        // printf("smem_size = %d\n", smem_size);
        dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
        kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
            launch_params.params);
        // Surfaces launch-configuration errors without clearing the error state.
        FMHA_CHECK_CUDA(cudaPeekAtLastError());
    }));
}