__global__ void masked_multihead_attention_kernel()

in maga_transformer/cpp/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.h [1208:2542]


__global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params,
                                                  KVCacheBuffer                                     kvCacheBuffer)
{

    using Tk = typename kernel_type_t<T>::Type;
    // Use 8bit cache.
    static constexpr bool ENABLE_8BITS_CACHE = sizeof(Tcache) == 1;
    // FP8 KV Cache.
#ifdef ENABLE_FP8
    static constexpr bool FP8_KV_CACHE = std::is_same<Tcache, __nv_fp8_e4m3>::value;
#else
    static constexpr bool FP8_KV_CACHE = false;
#endif
    // INT8 KV Cache.
    static constexpr bool INT8_KV_CACHE = std::is_same<Tcache, int8_t>::value;

    // The size of a warp.
    constexpr unsigned WARP_SIZE{32};
    // The number of warps in a threadblock.
    constexpr unsigned WARPS_PER_BLOCK{THREADS_PER_BLOCK / WARP_SIZE};

    // The maximum hidden size per head.
    constexpr auto Dh_MAX    = dh_max(Dh);
    constexpr bool IS_Dh_MAX = Dh == Dh_MAX;
    static_assert(Dh_MAX >= WARP_SIZE);
    static_assert(Dh_MAX >= Dh);
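    // Illustrative note (assumption): dh_max() is assumed to round Dh up to a power of two of at least WARP_SIZE,
    // e.g. Dh = 80 -> Dh_MAX = 128, Dh = 128 -> Dh_MAX = 128. With THREADS_PER_BLOCK = 256, WARPS_PER_BLOCK = 8.
    // The padded elements in [Dh, Dh_MAX) are zeroed below so they do not affect the dot products.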

    // The maximum sequence length in the cyclic kv_cache, i.e., an upper bound on L.
    // Note that the maximum sequence length supported by the model might be greater than this.
    // Note that max_kv_cache_length is the maximum of cyclic_kv_cache_length among all layers.
    // By default, you can assume that they are the same.
    const auto cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_kv_cache_length);
    // The current timestep (including paddings).
    // It is only used to calculate the smem stride.
    const auto timestep = static_cast<unsigned>(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep);

#ifdef ENABLE_MULTI_BLOCK_OPTION
    constexpr bool MULTI_BLOCK_FLAG = DO_MULTI_BLOCK;
#else
    constexpr bool MULTI_BLOCK_FLAG = false;
#endif

    // Use smem_size_in_bytes (above) to determine the amount of shared memory.
    extern __shared__ char smem_[];

    // The shared memory for the Q*K^T values and partial logits in softmax.
    auto qk_smem = reinterpret_cast<float*>(smem_);

    __shared__ float qk_current_smem[1];

    // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
    char* logits_smem_ = smem_;
#ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS
    if (sizeof(Tk) != 4) {
        const auto max_timesteps = DO_CROSS_ATTENTION ? cyclic_kv_cache_len : min(timestep, cyclic_kv_cache_len);
        logits_smem_ += divUp(max_timesteps + 1, 4u) * 16;
    }
    Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
#else
    float* logits_smem = reinterpret_cast<float*>(logits_smem_);
#endif

    __shared__ Tk logits_current_smem[1];

    // The shared memory to do the final reduction for the output values. Reuse qk_smem.
    Tk* out_smem = reinterpret_cast<Tk*>(smem_);

    // The shared memory buffers for the block-wide reductions. One for max, one for sum.
    __shared__ float red_smem[WARPS_PER_BLOCK * 2];

    // A vector of Q or K elements for the current timestep.
    using Qk_vec_m = typename Qk_vec_m_<T, Dh_MAX>::Type;  // with memory-used precision
    using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type;  // with kernel-used precision
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
    using Qk_vec_accum = typename Qk_vec_accum_fp32_<Qk_vec_k>::Type;
#else
    using Qk_vec_accum = Qk_vec_k;
#endif

    // Make sure the hidden dimension per head is a multiple of the number of threads per key.
    static_assert(Dh_MAX % THREADS_PER_KEY == 0);  // trivially satisfied since THREADS_PER_KEY in {1, 2, 4}

    // The number of elements per vector.
    // Each thread will handle 16 bytes.
    constexpr int K_VEC_SIZE = 16u / sizeof(T);
    // Make sure the hidden size per head is a multiple of the vector size.
    static_assert(Dh_MAX % K_VEC_SIZE == 0);
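    // Worked example: each thread reads 16 bytes of K per vector, so K_VEC_SIZE = 16 / 2 = 8 for T = half
    // (or __nv_bfloat16) and K_VEC_SIZE = 16 / 4 = 4 for T = float.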
    // The type of queries and keys for the math in the Q*K^T product.
    using K_vec_k = typename K_vec_k_<T, K_VEC_SIZE>::Type;
    // Only used when key cache is quantized to 8 bits.
    using K_vec_m = typename packed_type<Tcache, num_elems<K_vec_k>::value>::type;
#ifdef MMHA_USE_FP32_ACCUM_FOR_FMA
    using K_vec_accum = typename Qk_vec_accum_fp32_<K_vec_k>::Type;
#else
    using K_vec_accum = K_vec_k;
#endif

    // Use alignment for safely casting the shared buffers as Qk_vec_k and K_vec_k.
    // Shared memory to store Q inputs.
    __shared__ __align__(const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk q_smem[Dh_MAX];
    __shared__ __align__(const_max(sizeof(Qk_vec_k), sizeof(K_vec_k))) Tk k_smem[Dh_MAX];

    // Make sure the hidden dimension per head is a multiple of the number of threads per value.
    static_assert(Dh_MAX % THREADS_PER_VALUE == 0);  // trivially satisfied since THREADS_PER_VALUE == Dh_MAX / p

    // The number of elements per vector.
    constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
    // A vector of V elements for the current timestep.
    using V_vec_k = typename V_vec_k_<T, V_VEC_SIZE>::Type;
    // Only used when value cache is quantized to 8 bits.
    using V_vec_m = typename packed_type<Tcache, num_elems<V_vec_k>::value>::type;
    static_assert(V_VEC_SIZE == sizeof(V_vec_k) / sizeof(T));
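    // Worked example (assuming THREADS_PER_VALUE is chosen so each value thread also covers 16 bytes,
    // i.e. THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16): for T = half and Dh_MAX = 128, THREADS_PER_VALUE = 16
    // and V_VEC_SIZE = 128 / 16 = 8 elements per thread.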

    // This could be one of the reasons to have a separate kernel for cross attention
    constexpr auto bias_smem_size = DO_CROSS_ATTENTION ? Dh_MAX : 1u;
    __shared__     __align__(const_max(const_max(sizeof(Qk_vec_k), sizeof(K_vec_k)), sizeof(V_vec_k)))
        Tk         bias_smem[bias_smem_size];

    // The number of elements per vector.
    constexpr unsigned QK_VEC_SIZE{sizeof(Qk_vec_m) / sizeof(T)};
    // Make sure the hidden size per head is a multiple of the vector size.
    static_assert(Dh_MAX % QK_VEC_SIZE == 0);
    // We will use block wide reduction if needed
    // The number of vectors per Dh_MAX.
    constexpr unsigned QK_VECS_PER_Dh_MAX{Dh_MAX / QK_VEC_SIZE};
    static_assert(THREADS_PER_BLOCK >= QK_VECS_PER_Dh_MAX);
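    // Worked example (assuming Qk_vec_m is a 16-byte vector at this Dh_MAX): for T = half and Dh_MAX = 128,
    // QK_VEC_SIZE = 16 / 2 = 8 and QK_VECS_PER_Dh_MAX = 128 / 8 = 16, i.e. 16 threads cooperatively load one
    // Q (or K) head vector for the current timestep.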

    // The batch/beam idx
    const auto batch_beam_idx = blockIdx.y;
    if (params.finished != nullptr && params.finished[batch_beam_idx]) {
        return;
    }

    // The head.
    const unsigned hi{blockIdx.x};
    // The head index of keys and values adjusted for MQA/GQA.
    const int      qhead_per_kv{params.num_heads / params.num_kv_heads};
    const unsigned hi_kv{hi / qhead_per_kv};
    // The number of heads.
    const auto num_heads = static_cast<unsigned>(params.num_heads);
    // The number of heads for keys and values adjusted for MQA/GQA.
    const auto num_heads_kv = static_cast<unsigned>(params.num_kv_heads);
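    // Worked example for MQA/GQA: with num_heads = 32 and num_kv_heads = 8, qhead_per_kv = 4, so Q heads 0..3 share
    // KV head 0 and e.g. hi = 13 maps to hi_kv = 13 / 4 = 3. For MHA, hi_kv == hi; for MQA (num_kv_heads == 1),
    // hi_kv == 0 for every Q head.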

    // The thread in the block.
    const unsigned tidx{threadIdx.x};

    // The column tile along L dimension on K^T -- noted as T_c in flash-attention paper
    const unsigned c_tile{MULTI_BLOCK_FLAG ? blockIdx.z : 0};

    // Indicates whether we need to compute the K/V cache element (add KV bias, IA3, RoPE, etc.) and update the cache.
    // For Self-Attention, it's always required.
    // For Cross-Attention, everything is pre-computed in the context phase of the encoder,
    // so it's not needed in this kernel.
    // Therefore, HANDLE_KV is !DO_CROSS_ATTENTION and independent of the timestep.
    static constexpr bool HANDLE_KV{!DO_CROSS_ATTENTION};

    // While doing the product Q*K^T for the different keys we track the max.
    float qk_max = -FLT_MAX;

    float qk = 0.0F;

    // Do we have a relative attention bias?
    bool has_relative_attention_bias = params.relative_attention_bias != nullptr;
    // Compute relative attention bias on the fly, with relative attention table [head_num/TP, num_buckets] passed in.
    // num_buckets is passed as relative_attention_bias_stride, max_distance as params.max_distance.
    // This is a common optimization for both self attention and cross attention.
    const bool implicit_rel_attn_bias = params.max_distance != 0 && has_relative_attention_bias;
    int        relative_attention_bias_stride =
        params.relative_attention_bias_stride;  // num_buckets might be modified below, save it beforehand
    int max_distance = params.max_distance;

    // The actual sequence length excluding the paddings.
    // Subtract 1 because it includes the current timestep, while tlength denotes the kv cache length.
    const int tlength             = DO_CROSS_ATTENTION ? params.memory_length_per_sample[batch_beam_idx] - 1 :
                                                         (params.length_per_sample ?
                                                              (params.length_per_sample[batch_beam_idx] + params.max_prefix_prompt_length) :
                                                              static_cast<int>(timestep));
    bool      count_prefix_length = params.count_prefix_length;
    // We will use cyclic kv cache when it exceeds the limit.
    // The length position for storing new key and value.
    int const cyclic_tlength = tlength % cyclic_kv_cache_len;
    // The actual kv cache length.
    // tlength is actually the past length.
    const int kv_loop_length = min(tlength, cyclic_kv_cache_len);
    // The context length for beam searching optimization (all points to beam 0).
    // TODO: with cyclic kv cache, we set it 0 for now (will optimize in the future)
    // as context kv cache might be overwritten by the new kv cache
    const int beam0_context_length =
        HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx];
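    // Worked example for the cyclic kv cache: with cyclic_kv_cache_len = 2048 and tlength = 3000, the new K/V is
    // written at cyclic_tlength = 3000 % 2048 = 952 and the attention loops only cover kv_loop_length =
    // min(3000, 2048) = 2048 cached positions. With beam search and tlength > cyclic_kv_cache_len, the context part
    // may already be overwritten, so beam0_context_length falls back to 0.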

    // The offset in the Q and K buffer also accounts for the batch.
    const auto qk_vec_idx      = tidx * QK_VEC_SIZE;
    const auto is_valid_qk_vec = qk_vec_idx < Dh;

    const bool load_qkv_quant        = params.qkv_scale_quant_orig != nullptr;
    const bool write_attention_quant = params.attention_out_scale_orig_quant != nullptr;

    // Quant/Dequant scales for 8bits kv cache.
    using T_scale = typename kv_cache_scale_type_t<T, Tcache>::Type;
    T_scale     kv_scale_orig_quant, kv_scale_quant_orig;
    const float kv_scale_quant_orig_f = (ENABLE_8BITS_CACHE && params.kv_scale_quant_orig != nullptr ? params.kv_scale_quant_orig[0] : 1.0f);
    convert_from_float(&kv_scale_quant_orig, kv_scale_quant_orig_f);
    convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_CACHE && params.kv_scale_quant_orig != nullptr ? params.kv_scale_orig_quant[0] : 1.0f));

    // Up to QK_VECS_PER_Dh_MAX threads load Q and K + the bias values for the current timestep.
    // Trigger the loads from the Q and K buffers.
    Qk_vec_k q, k, q_bias, k_bias;
    zero(q);
    zero(k);
    zero(q_bias);
    zero(k_bias);
    if (is_valid_qk_vec) {
        // Query
        // The stride between tokens. We may be able to always use params.stride.
        uint32_t q_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads * Dh);
        // The offset.
        const auto q_offset = flat_index_strided3(batch_beam_idx, hi, qk_vec_idx, q_stride, Dh);

        if (load_qkv_quant) {
            using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
            using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
            const auto q_scaling = params.qkv_scale_quant_orig[0];
            const auto q_quant =
                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
            convert_from_float(&q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
        }
        else {
            q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[q_offset]));
        }

        if constexpr (DO_CROSS_ATTENTION) {
            const auto k_idx      = QK_VEC_SIZE * tidx;
            const int  inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi, Dh, k_idx);
            Tcache*    k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));

            k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&k_cache[inBlockIdx]));
        }
        else {
            // Key
            // The stride between tokens. We may be able to always use params.stride.
            uint32_t k_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads_kv * Dh);
            // The offset.
            const auto k_offset = flat_index_strided3(batch_beam_idx, hi_kv, qk_vec_idx, k_stride, Dh);

            if (load_qkv_quant) {
                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
                const auto k_scaling = params.qkv_scale_quant_orig[1];
                const auto k_quant =
                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);

                convert_from_float(&k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
            }
            else {
                k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[k_offset]));
            }
        }

        if (params.q_bias != nullptr) {
            const auto q_bias_offset = flat_index2(hi, qk_vec_idx, Dh);
            q_bias =
                vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[q_bias_offset]));
        }
        if (HANDLE_KV && params.k_bias != nullptr) {
            const auto k_bias_offset = flat_index2(hi_kv, qk_vec_idx, Dh);
            k_bias =
                vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[k_bias_offset]));
        }
    }

    // Computes the Q/K values with bias.
    q = add(q, q_bias);
    if (HANDLE_KV) {
        k = add(k, k_bias);
    }

    // The width of the beam.
    const auto beam_width = static_cast<unsigned>(params.beam_width);
    // The batch idx.
    const int batch_idx = batch_beam_idx / beam_width;
    // Do we apply IA3?
    const bool do_ia3 = HANDLE_KV && params.ia3_tasks != nullptr;
    // Compute the IA3 task. One per batch index.
    const auto ia3_ti_hi = do_ia3 ? flat_index2(static_cast<unsigned>(params.ia3_tasks[batch_idx]), hi, num_heads) : 0;

    if (do_ia3 && is_valid_qk_vec) {
        k = mul<Qk_vec_k, Qk_vec_k, Qk_vec_k>(k,
                                              vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(
                                                  &params.ia3_key_weights[flat_index2(ia3_ti_hi, qk_vec_idx, Dh)])));
    }
    const int input_len            = (params.input_lengths == nullptr) ? 0 : params.input_lengths[batch_beam_idx];
    int       prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[batch_beam_idx];
    const int position_id = params.position_ids == nullptr ? -1 : params.position_ids[batch_beam_idx * params.rope_config.index_factor];
    attention_rope<T, Qk_vec_k, ROPE_STYLE>(
            params.rope_config,
            q,
            k,
            reinterpret_cast<T*>(smem_),
            tidx,
            tlength,
            timestep,
            params.length_per_sample[batch_beam_idx],
            position_id,
            input_len,
            prefix_prompt_length,
            count_prefix_length,
            HANDLE_KV);

    __syncthreads();

    if (params.use_logn_attn && is_valid_qk_vec) {
        logn_attention(q, tlength, params.rope_config.max_pos);
    }

    // For the same reason as HANDLE_KV, no compute needed in Cross-Attention's 1st step
    // Store Q K vectors to shared memory, and calculate QK.
    if (qk_vec_idx < Dh_MAX) {

        // Store the Q values to shared memory.
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
        if constexpr (FP8_KV_CACHE) {
            // There are many more elements from K than elements from Q so we pre-scale Q instead
            // of scaling all the elements from K. It helps reduce the number of ops.
            Qk_vec_k scaled_q;
            zero(scaled_q);
            if (is_valid_qk_vec) {
                scaled_q = mul<Qk_vec_k, Tk, Qk_vec_k>(kv_scale_quant_orig, q);
            }
            reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = scaled_q;
        }
        else
#endif
        {
            // Set padded Dh to 0 for the correctness of QK (when Dh != Dh_Max).
            Qk_vec_k zero_q;
            zero(zero_q);
            reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = is_valid_qk_vec ? q : zero_q;
        }

        // Store the K values to shared memory.
        // We store K values from shared memory to global memory
        //  when the target position of K cache in global memory has been accessed (in the case of cyclic kv cache)
        reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx])[0] = k;

        // Compute \sum_i Q[i] * K^T[i] for the current timestep.
        qk = dot<Qk_vec_accum, Qk_vec_k>(q, k);
        if (QK_VECS_PER_Dh_MAX <= WARP_SIZE) {
#pragma unroll
            for (int mask = QK_VECS_PER_Dh_MAX / 2; mask >= 1; mask /= 2) {
                qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_Dh_MAX), qk, mask);
            }
        }
    }

    if (QK_VECS_PER_Dh_MAX > WARP_SIZE) {
        constexpr int WARPS_PER_RED = (QK_VECS_PER_Dh_MAX + WARP_SIZE - 1) / WARP_SIZE;
        qk                          = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
    }

    // Pre-compute the pointer for the relative attention bias.
    const T* relative_attention_bias_ptr       = nullptr;
    const T* relative_attention_bias_ptr_fixed = nullptr;  // record the base for offset
    if (has_relative_attention_bias) {
        // "hi" is unsigned, subtracting int from unsigned int causes underflow. Cast to int
        int64_t offset                    = implicit_rel_attn_bias ?
                                                ((int64_t)hi * relative_attention_bias_stride - tlength) :
                                                ((int64_t)hi * relative_attention_bias_stride + tlength) * relative_attention_bias_stride;
        relative_attention_bias_ptr       = &params.relative_attention_bias[offset];
        relative_attention_bias_ptr_fixed = &params.relative_attention_bias[offset];
    }
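    // Indexing note: in the explicit case, the bias table appears to be laid out as [num_heads, stride, stride],
    // so relative_attention_bias_ptr[ki] reads the bias for query position tlength and key position ki.
    // In the implicit case, the "- tlength" baked into the offset cancels against the "[time_now]" index used in the
    // key loops below, so after adding relative_buckets the lookup lands on bias[hi][relative_bucket].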

    // Load the value.
    float relative_attention_bias = 0.f;
    if (has_relative_attention_bias && tidx == 0) {
        // TODO: Use a better way to convert from T to float.
        relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[tlength]);
    }

    // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
    if (tidx == 0) {
        // Normalize qk.
        qk = qk * params.inv_sqrt_dh + relative_attention_bias;

        // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
        qk_max = qk;

        // Store Q*K^T to shared memory.
        if (MULTI_BLOCK_FLAG) {
            qk_current_smem[0] = qk;
        }
        else {
            // We need to store the qk result to the end of the qk_smem for cyclic kv cache (+ 1 for smem memory
            // allocation) because the previous cache will still write to the new_cache_pos of qk_smem.
            qk_smem[kv_loop_length] = qk;
        }
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    constexpr unsigned K_ELTS_PER_CHUNK{THREADS_PER_KEY * K_VEC_SIZE};

    // The positions of the cache buffer (for this B * H) and the vector within that chunk associated with this
    // thread.
    const auto k_idx = chunk_index<T, K_vec_k, THREADS_PER_KEY>(tidx);

    // The number of vectors per thread.
    constexpr unsigned K_VECS_PER_THREAD{Dh_MAX / K_ELTS_PER_CHUNK};
    static_assert(Dh_MAX == K_ELTS_PER_CHUNK * K_VECS_PER_THREAD);
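    // Worked example (example template values: T = half so K_VEC_SIZE = 8, THREADS_PER_KEY = 4):
    // K_ELTS_PER_CHUNK = 32, and for Dh_MAX = 128 each thread holds K_VECS_PER_THREAD = 128 / 32 = 4 K vectors
    // (and the matching 4 Q vectors below) per timestep it processes.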

    // Load the Q values from shared memory. The values are reused during the loop on K.
    K_vec_accum q_vec[K_VECS_PER_THREAD];
#pragma unroll
    for (unsigned ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
        q_vec[ii] = vec_conversion<K_vec_accum, K_vec_k>(
            *reinterpret_cast<const K_vec_k*>(&q_smem[flat_index2(ii, k_idx.y, K_ELTS_PER_CHUNK)]));
    }

    // The number of timesteps loaded per iteration, i.e., (THREADS_PER_BLOCK * THREADS_PER_BLOCK) / 256 <= 256
    constexpr unsigned K_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_KEY};
    // The number of keys per warp.
    constexpr unsigned K_PER_WARP{WARP_SIZE / THREADS_PER_KEY};
    // The number of unrolled keys per warp.
    constexpr unsigned UNROLLED_K_PER_WARP = K_PER_WARP * K_LOOP_UNROLL;
    // The number of unrolled keys per iteration.
    constexpr unsigned UNROLLED_K_PER_ITER = K_PER_ITER * K_LOOP_UNROLL;
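    // Worked example (example template values: THREADS_PER_BLOCK = 256, THREADS_PER_KEY = 4, K_LOOP_UNROLL = 4):
    // K_PER_ITER = 256 / 4 = 64 timesteps per iteration, K_PER_WARP = 32 / 4 = 8, UNROLLED_K_PER_WARP = 32, and
    // UNROLLED_K_PER_ITER = 256, i.e. each unrolled iteration of the context key loop covers 256 timesteps.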

    // Base pointer for the row of pointers to k cache blocks

    const auto timesteps_per_block = static_cast<unsigned>(params.timesteps_per_block);

    // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
    // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
    const int context_length =
        DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length);
    // Clarifications:
    // - in self attn, input_length is input text length, tlength is current timestep
    // - in cross attn, input_length is *decoder* input length (usually 1), tlength is *encoder* input context length
    // - in beam search, since the cache during generation is organized differently, the following KV compute needs
    // split into context cache compute and generation cache compute
    // - for self attn, no-beam search: entire cache can be treated as context cache --> context_length = tlength
    // - for self attn, beam search: cache of input text length is context cache, other are generation cache -->
    // context_length = input_length
    // - for cross attn, no-beam/beam search: cache length is fixed, no context/generation cache distinction -->
    // context_length = tlength
    // Suggestion: we could have a flag HANDLE_GEN_CACHE

    const auto context_ti_end =
        MULTI_BLOCK_FLAG ? divUp(timesteps_per_block, UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP :
                           divUp(static_cast<unsigned>(context_length), UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP;

    // The generation ti_end.
    const auto generation_ti_end = MULTI_BLOCK_FLAG ?
                                       divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP :
                                       divUp(static_cast<unsigned>(kv_loop_length), K_PER_WARP) * K_PER_WARP;

    // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
    // Note that max_kv_cache_length is the maximum of cyclic_kv_cache_length among all layers.
    // By default, you can assume that they are the same.
    const auto bi_seq_len_offset = static_cast<std::size_t>(batch_beam_idx) * params.max_kv_cache_length;
    // Beam indices are based on the max_kv_cache_length while each layer may have different cyclic_kv_cache_length
    // So we need to rebuild the beam_indices if max_kv_cache_length is not equal to cyclic_kv_cache_length.
    const int* beam_indices = HAS_BEAMS ? &params.cache_indir[bi_seq_len_offset] : nullptr;

    const auto c_tile_times_timesteps_per_block = c_tile * timesteps_per_block;  // 0 if !MULTI_BLOCK_FLAG

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Key cache loops for dot(Q, K).

    // Is it the leader?
    const bool is_leader = Qk_dot<T, THREADS_PER_KEY>::is_leader(tidx);

    // The slope for ALiBi.
    float linear_bias_slope = 0.f;
    if (params.linear_bias_slopes != nullptr) {
        linear_bias_slope = params.linear_bias_slopes[hi];
    }

    // Handle only context key cache with beam searching.
    // Handle both context and generation key cache without beam searching.
    // Explicit batching of LDGs (by K_LOOP_UNROLL) as it doesn't depend on indirection tables.
    for (int ti = k_idx.x; ti < context_ti_end; ti += UNROLLED_K_PER_ITER) {
        const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;

        // The keys loaded from the key cache.
        K_vec_m k_vec_cache[K_LOOP_UNROLL][K_VECS_PER_THREAD];

#pragma unroll
        for (int k_loop = 0; k_loop < K_LOOP_UNROLL; ++k_loop) {
#pragma unroll
            for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i) {
                // Make sure we read data within the bound.
                // Dh OOB values will be handled by zero_q.
                // Seq OOB values will be masked out when storing back to smem.
                auto const jj             = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
                const int  valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
                const int  seqIdx         = batch_idx * beam_width;

                // Base pointer to k cache block for beam's batch
                Tcache* k_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));

                int inBlockIdx               = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
                k_vec_cache[k_loop][k_vec_i] = *reinterpret_cast<const K_vec_m*>(&k_cache_batch[inBlockIdx]);
            }
        }

#pragma unroll
        for (int k_loop = 0; k_loop < K_LOOP_UNROLL; ++k_loop) {
            const int local_time_now = time_now + k_loop * K_PER_ITER;
            const int local_ti       = ti + k_loop * K_PER_ITER;

            float k_scale = 1.f;

            if constexpr (ENABLE_8BITS_CACHE) {
                const int valid_time_now = min(time_now + k_loop * K_PER_ITER, context_length - 1);
                const int seqIdx         = batch_idx * beam_width;
                float*    k_scale_ptr    = reinterpret_cast<float*>(kvCacheBuffer.getKScalePtr(seqIdx, valid_time_now));
                int       inScaleIdx     = kvCacheBuffer.getKVScaleLocalIdx(valid_time_now, hi_kv);
                k_scale                  = k_scale_ptr[inScaleIdx];
            }

            // Perform the dot product and normalize qk.
            //
            // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
            K_vec_m k_vec[K_VECS_PER_THREAD];
#pragma unroll
            for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i) {
                k_vec[k_vec_i] = *reinterpret_cast<K_vec_m*>(&k_vec_cache[k_loop][k_vec_i]);
            }

            // Is it active?
            const bool is_active = local_time_now < context_length;

            if (implicit_rel_attn_bias) {
                // Compute bias value on the fly (See bert_preprocess_kernels.cu::buildRelativeAttentionBias)
                int relative_buckets  = 0;
                int relative_position = local_time_now - tlength;
                int num_buckets       = relative_attention_bias_stride;
                // Special logic for T5 relative attention; both encoder & decoder use this because
                // relative_attention_bias is pre-computed once and passed around.
                num_buckets /= 2;
                relative_buckets += relative_position > 0 ? num_buckets : 0;
                relative_position = abs(relative_position);
                int  max_exact    = num_buckets / 2;
                bool is_small     = relative_position < max_exact;
                int  relative_position_if_large =
                    max_exact
                    + (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact)
                            * (num_buckets - max_exact));
                relative_position_if_large = min(relative_position_if_large, num_buckets - 1);
                relative_buckets += is_small ? relative_position : relative_position_if_large;
                relative_attention_bias_ptr =
                    relative_attention_bias_ptr_fixed + (tlength - local_time_now) + relative_buckets;
            }
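            // Worked example of the bucketing above (example values: relative_attention_bias_stride = 64 so
            // num_buckets = 32 and max_exact = 16, max_distance = 128): for a key 40 positions in the past,
            // relative_position = -40 -> abs = 40, is_small = false, and
            // relative_position_if_large = 16 + (int)(logf(40.f / 16) / logf(128.f / 16) * 16) = 16 + 7 = 23,
            // so relative_buckets = 23 (past keys do not receive the +num_buckets offset).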

            // Prefetch the relative attention bias.
            float relative_attention_bias = 0.f;
            if (is_active && has_relative_attention_bias) {
                // TODO: Use a better way to convert from T to float.
                relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[local_time_now]);
            }

            // Compute the dot product between Q and K.
            // Note that dot will convert 8bit vec to the accumulation data type (float by default).
            float qk_ = 0.f;
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
            if constexpr (FP8_KV_CACHE) {
                qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
            }
            else
#endif  // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
            {
                if constexpr (ENABLE_8BITS_CACHE) {
                    qk_ =
                        Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, k_scale) * params.inv_sqrt_dh;
                }
                else {
                    qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
                }
            }

            // For multi-block mode, we need to make sure it will not be OOB.
            if (MULTI_BLOCK_FLAG && local_ti >= timesteps_per_block) {
                continue;
            }

            // Add the ALiBi bias. (ki - qi) * slope[hi].
            //
            // The padding tokens are located between the input context and the generated tokens.
            // We need to remove the correct number of padding tokens in the distance computation.
            //
            //   ti   : 0 1 2 3 4 5 6 7 8 9(tlength)
            //   token: i i i i p p p o o o where i=input, p=pad, o=output.
            // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
            //
            // All the threads do the work even if it's not relevant to avoid divergence.
            qk_ += linear_bias_slope * (local_time_now - tlength) + relative_attention_bias;

            // There's one qk value per timestep.
            // Make sure only leader threads store the qk value within the bound.
            if (is_active && is_leader) {
                // Calculate the max for softmax.
                qk_max = fmaxf(qk_max, qk_);
                // Store the product to shared memory.
                qk_smem[local_ti] = qk_;
            }
        }
    }

    // Handle generation key cache with beam searching.
    // Note that it may overlap with the context key loop, but it won't impact the correctness.
    // Can skip in cross attention mode.
    if (HAS_BEAMS && !DO_CROSS_ATTENTION
        && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length)) {
        // The input length.
        const int input_length_ = MULTI_BLOCK_FLAG ? beam0_context_length % timesteps_per_block : beam0_context_length;
        // The beginning of the generation.
        const int generation_start_ti = k_idx.x + input_length_ / K_PER_WARP * K_PER_WARP;

        // Iterate over the output tokens.
        for (int ti = generation_start_ti; ti < generation_ti_end; ti += K_PER_ITER) {
            const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;

            // The keys loaded from the key cache.
            K_vec_m k_vec[K_VECS_PER_THREAD];

#pragma unroll
            for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i) {
                const int jj             = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE);
                const int valid_time_now = min(time_now, kv_loop_length - 1);
                int       beam_offset    = beam_indices[valid_time_now];
                const int seqIdx         = batch_idx * beam_width + beam_offset;
                // Base pointer to k cache block for beam's batch, before offsetting with indirection buffer
                Tcache* k_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(seqIdx, valid_time_now));

                int inBlockIdx = kvCacheBuffer.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj);
                if constexpr (ENABLE_8BITS_CACHE && !FP8_KV_CACHE) {
                    float* k_scale_ptr = reinterpret_cast<float*>(kvCacheBuffer.getKScalePtr(seqIdx, valid_time_now));
                    int    inScaleIdx  = kvCacheBuffer.getKVScaleLocalIdx(valid_time_now, hi_kv);
                    load_8bits_kv_cache_vec(
                        &k_vec[k_vec_i], k_cache_batch, inBlockIdx, k_scale_ptr[inBlockIdx - inScaleIdx]);
                }
                else {
                    k_vec[k_vec_i] = (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[inBlockIdx]));
                }
            }

            // Is it active?
            const bool is_active = time_now >= context_length && time_now < kv_loop_length;

            if (implicit_rel_attn_bias) {
                // Compute bias value on the fly (See bert_preprocess_kernels.cu::buildRelativeAttentionBias)
                int relative_buckets  = 0;
                int relative_position = time_now - tlength;
                int num_buckets       = relative_attention_bias_stride;
                // Special logic for T5 relative attention; both encoder & decoder use this because
                // relative_attention_bias is pre-computed once and passed around.
                num_buckets /= 2;
                relative_buckets += relative_position > 0 ? num_buckets : 0;
                relative_position = abs(relative_position);
                int  max_exact    = num_buckets / 2;
                bool is_small     = relative_position < max_exact;
                int  relative_position_if_large =
                    max_exact
                    + (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact)
                            * (num_buckets - max_exact));
                relative_position_if_large = min(relative_position_if_large, num_buckets - 1);
                relative_buckets += is_small ? relative_position : relative_position_if_large;
                relative_attention_bias_ptr =
                    relative_attention_bias_ptr_fixed + (tlength - time_now) + relative_buckets;
            }

            // Prefetch the relative attention bias.
            float relative_attention_bias = 0.f;
            if (is_active && has_relative_attention_bias) {
                // TODO: Use a better way to convert from T to float.
                relative_attention_bias = add(relative_attention_bias, relative_attention_bias_ptr[time_now]);
            }

            // Perform the dot product and normalize qk.
            //
            // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
            // Note that dot will convert 8bit vec to the accumulation data type (float by default).
            float qk_ = 0.f;
#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
            if constexpr (FP8_KV_CACHE) {
                qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
            }
            else
#endif  // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
            {
                if constexpr (ENABLE_8BITS_CACHE) {
                    qk_ =
                        Qk_dot<T, THREADS_PER_KEY>::scale_dot(q_vec, k_vec, kv_scale_quant_orig_f) * params.inv_sqrt_dh;
                }
                else {
                    qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
                }
            }
            // Add the ALiBi bias. (ki - qi) * slope[hi].
            //
            // The padding tokens are located between the input context and the generated tokens.
            // We need to remove the correct number of padding tokens in the distance computation.
            //
            //   ti   : 0 1 2 3 4 5 6 7 8 9(tlength)
            //   token: i i i i p p p o o o where i=input, p=pad, o=output.
            // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
            //
            // All the threads perform that step to avoid divergence.
            qk_ += linear_bias_slope * (time_now - tlength) + relative_attention_bias;

            // There's one qk value per timestep.
            // Make sure only leader threads store the qk value within the bound.
            if (is_active && is_leader) {
                // Calculate the max for softmax.
                qk_max = fmaxf(qk_max, qk_);
                // Store the product to shared memory.
                qk_smem[ti] = qk_;
            }
        }
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Softmax.

    // Perform the final reduction to compute the max inside each warp.
    //
    // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
    // group so it's not needed to run the reduction inside the group (again).

#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA)
    // Leader threads will be on the diagonal when using HMMA.
    if (THREADS_PER_KEY <= 4) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(unsigned(-1), qk_max, 4));
    }
    if (THREADS_PER_KEY <= 8) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(unsigned(-1), qk_max, 9));
    }
    if (THREADS_PER_KEY <= 16) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(unsigned(-1), qk_max, 18));
    }
#else
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }
#endif  // defined MMHA_USE_HMMA

    // Decompose the thread index into warp and lane.
    const auto warp = tidx / WARP_SIZE;
    const auto lane = tidx % WARP_SIZE;

    // The warp leader writes the max to shared memory.
    if (lane == 0) {
        red_smem[warp] = qk_max;
    }

    // Make sure the products are in shared memory.
    __syncthreads();

    // After the syncthreads, the target k position (cyclic kv cache) should also have been used by the k loop.
    // Write the K values to the global memory cache.
    //
    // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
    // system. We designed it this way as it allows much better memory loads (and there are many
    // more loads) + the stores are really "write and forget" since we won't need the ack before
    // the end of the kernel. There's plenty of time for the transactions to complete.

    // For MQA/GQA mode, write only with the first Q head of each group per KV head.
    if (HANDLE_KV && hi == (hi_kv * qhead_per_kv) && qk_vec_idx < Dh) {
        // Trigger the stores to global memory.
        Qk_vec_k   k_vec      = *reinterpret_cast<Qk_vec_k*>(&k_smem[qk_vec_idx]);
        const auto k_idx      = QK_VEC_SIZE * tidx;
        const int  inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, k_idx);
        // The base pointer for the value in the cache buffer.
        Tcache* k_cache = reinterpret_cast<Tcache*>(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength));

        if constexpr (ENABLE_8BITS_CACHE) {
            float k_max = vector_abs_max(k);
            float scale = 1.0f;
            if constexpr (std::is_same<Tcache, int8_t>::value) {
#pragma unroll
                for (int mask = QK_VECS_PER_Dh_MAX / 2; mask >= 1; mask /= 2) {
                    k_max = fmaxf(k_max, __shfl_xor_sync(shfl_mask(QK_VECS_PER_Dh_MAX), k_max, mask));
                }
                scale = float(1 << (8 - 1)) / k_max;
            }
            store_8bits_kv_cache_vec(reinterpret_cast<Tcache*>(k_cache), k, inBlockIdx, scale);
            if (k_idx == 0) {
                float* k_scale_ptr = reinterpret_cast<float*>(kvCacheBuffer.getKScalePtr(batch_beam_idx, tlength));
                int    inScaleIdx  = kvCacheBuffer.getKVScaleLocalIdx(tlength, hi_kv);
                *reinterpret_cast<float*>(&k_scale_ptr[inScaleIdx]) = 1.0f / scale;
            }
        }
        else {
            *reinterpret_cast<Qk_vec_m*>(&k_cache[inBlockIdx]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k_vec);
        }
    }
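    // Worked example of the INT8 K-cache path above: if the reduced per-head max is k_max = 4.0f, then
    // scale = 128 / 4.0 = 32, each element is converted to int8 after multiplying by 32, and the stored dequant
    // factor is 1.0f / scale = 0.03125, which the 8-bit key loops above read back and pass to scale_dot().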

    // The warps finalize the reduction.
    qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }

    // Broadcast to all the threads in the warp.
    qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

    // Compute the logits and start the sum.
    float sum = 0.f;

    // Each thread will handle one float (either qk_smem/logit).
    const int logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
    for (int ti = tidx; ti <= logit_loop_end; ti += THREADS_PER_BLOCK) {

        const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;

        // For single-block mode, we don't need the mask since it has been skipped.
        if (!MULTI_BLOCK_FLAG) {
            float logit = __expf(qk_smem[time_now] - qk_max);
            sum += logit;
            qk_smem[time_now] = logit;
        }
        else {
            // Not supported yet: multi-block mode with FP8_MHA
            if (time_now < kv_loop_length && ti != timesteps_per_block) {
                float logit = __expf(qk_smem[ti] - qk_max);
                sum += logit;
                qk_smem[ti] = logit;
            }
            else if (time_now == kv_loop_length) {
                float logit = __expf(qk_current_smem[0] - qk_max);
                sum += logit;
                qk_current_smem[0] = logit;
            }
        }
    }

    // Compute the sum.
    sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);

// Normalize the logits.
#ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V
    // float logit_scale = (FP8_KV_CACHE ? kv_scale_quant_orig_f : 1.0f);
    float logit_scale = 1.0f;
#else
    float logit_scale = 1.f;
#endif  // MMHA_FP8_SCALE_P_INSTEAD_OF_V
    float inv_sum = __fdividef(logit_scale, sum + 1.e-6f);

    const int normalization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
    for (int ti = tidx; ti <= normalization_loop_end; ti += THREADS_PER_BLOCK) {
        const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti;

        if (!MULTI_BLOCK_FLAG) {
            convert_from_float(&logits_smem[ti], qk_smem[ti] * inv_sum);
        }
        else {
            // no scaling factor inv_sum applied here, will apply the scaling factor after all blocks finished
            if (time_now < kv_loop_length && ti != timesteps_per_block) {
                convert_from_float(&logits_smem[ti], qk_smem[ti]);
            }
            else if (time_now == kv_loop_length) {
                convert_from_float(&logits_current_smem[0], qk_current_smem[0]);
            }
        }
    }

    // Put Values part below so we leverage __syncthreads
    // from the previous step

    const auto v_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
    // The value computed by this thread.
    const auto vo = v_idx.x;
    // The hidden dimensions computed by this particular thread.
    const auto vi = v_idx.y;

    // The number of values processed per iteration of the loop.
    constexpr unsigned V_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_VALUE};
    // The number of unrolled values per iteration.
    constexpr unsigned UNROLLED_V_PER_ITER = V_PER_ITER * V_LOOP_UNROLL;
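    // Worked example (example template values: THREADS_PER_BLOCK = 256, THREADS_PER_VALUE = 16, V_LOOP_UNROLL = 8):
    // V_PER_ITER = 256 / 16 = 16 and UNROLLED_V_PER_ITER = 128, i.e. 16 groups of 16 threads each accumulate
    // logits[t] * V[t] for different timesteps t in parallel.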

    bool const is_valid_vi = IS_Dh_MAX || vi < Dh;

    // One group of threads computes the product(s) for the current timestep.
    V_vec_k v_bias;
    zero(v_bias);
    // if( vo == params.timestep % V_PER_ITER ) {
    if (is_valid_vi && HANDLE_KV && vo == kv_loop_length % V_PER_ITER) {
        // Trigger the loads from the V bias buffer.
        if (params.v_bias != nullptr) {
            const auto v_bias_offset = flat_index2(hi_kv, vi, Dh);
            v_bias                   = *reinterpret_cast<const V_vec_k*>(&params.v_bias[v_bias_offset]);
        }

        if (DO_CROSS_ATTENTION) {
            *reinterpret_cast<V_vec_k*>(&bias_smem[vi]) = v_bias;
        }
    }

    // From previous, before values, step
    // Also make sure the logits are in shared memory.
    __syncthreads();

    ////////////////////////////////////////////////////////////////////////////////////////////////
    // Value cache loops.

#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
    using V_vec_accum = typename V_vec_accum_fp32_<V_vec_k>::Type;
#else
    using V_vec_accum = V_vec_k;
#endif
    // The partial outputs computed by each thread.
    V_vec_accum out;
    zero(out);

    // Loop over the timesteps to compute the partial outputs.
    if (is_valid_vi) {
        // Handle only context value cache with beam searching.
        // Handle both context and generation value cache without beam searching.
        // Explicit batching of LDGs (by V_LOOP_UNROLL) as it doesn't depend on indirection tables.
        // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible.
        const int context_length =
            DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length);
        int context_v_loop_end    = MULTI_BLOCK_FLAG ? timesteps_per_block : context_length;
        int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
        for (int ti = vo; ti < context_v_loop_end; ti += UNROLLED_V_PER_ITER) {
            V_vec_m v_vec_cache[V_LOOP_UNROLL];
#pragma unroll
            for (int v_loop = 0; v_loop < V_LOOP_UNROLL; v_loop++) {
                // Fetch offset based on cache_indir when beam sampling
                int time_idx = ti + v_loop * V_PER_ITER + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
                time_idx     = min(time_idx, kv_loop_length - 1);
                int rowIdx   = batch_idx * beam_width;

                const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
                // The base pointer for the value in the cache buffer.
                Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));

                v_vec_cache[v_loop] = *reinterpret_cast<const V_vec_m*>(&v_cache_batch[inBlockIdx]);
            }

#pragma unroll
            for (int v_loop = 0; v_loop < V_LOOP_UNROLL; v_loop++) {
                float v_scale = 1.f;
                V_vec_m v_vec = reinterpret_cast<V_vec_m*>(&v_vec_cache[v_loop])[0];
                int rowIdx   = batch_idx * beam_width;

                int local_time_idx = ti + v_loop * V_PER_ITER;
                int time_idx       = local_time_idx + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
                const bool is_mask =
                    (MULTI_BLOCK_FLAG && local_time_idx >= timesteps_per_block) || (time_idx >= context_length);

                if constexpr (ENABLE_8BITS_CACHE) {
                    int    local_time  = min(time_idx, tlength - 1);
                    float* v_scale_ptr = reinterpret_cast<float*>(kvCacheBuffer.getVScalePtr(rowIdx, local_time));
                    int    inScaleIdx  = kvCacheBuffer.getKVScaleLocalIdx(local_time, hi_kv);
                    v_scale            = v_scale_ptr[inScaleIdx];
                }

                // Load the logits from shared memory.
                // Note that fma will convert 8bit vec to the accumulation data type (float by default).
                Logit_value_fma<Tk, V_vec_accum, V_vec_m, INT8_KV_CACHE, FP8_KV_CACHE>(
                    out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, v_scale, is_mask);
            }
        }

        // Handle generation value cache with beam searching.
        if (HAS_BEAMS && !DO_CROSS_ATTENTION) {
            const auto generation_start_ti =
                MULTI_BLOCK_FLAG ? vo : (vo + (beam0_context_length / V_PER_ITER) * V_PER_ITER);
            // Only the last few blocks need to handle the generation value cache.
            if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length) {
                for (int ti = generation_start_ti; ti < generation_v_loop_end; ti += V_PER_ITER) {
                    // Fetch offset based on cache_indir when beam sampling
                    int time_idx       = ti + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0);
                    int local_time_idx = ti;
                    if (time_idx < beam0_context_length || (MULTI_BLOCK_FLAG && time_idx >= kv_loop_length)) {
                        continue;
                    }
                    int rowIdx = batch_idx * beam_width + beam_indices[time_idx];

                    const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi);
                    // The base pointer for the value in the cache buffer.
                    Tcache* v_cache_batch = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx));
                    // V_vec_m v_vec;
                    V_vec_m    v_vec = reinterpret_cast<const V_vec_m*>(&v_cache_batch[inBlockIdx])[0];
                    float v_scale = 1.f;
                    if (ENABLE_8BITS_CACHE) {
                        float* v_scale_ptr = reinterpret_cast<float*>(kvCacheBuffer.getVScalePtr(rowIdx, time_idx));
                        int    inScaleIdx  = kvCacheBuffer.getKVScaleLocalIdx(time_idx, hi_kv);
                        v_scale =  v_scale_ptr[inScaleIdx];
                    }

                    // Load the logits from shared memory.
                    // Note that fma will convert 8bit vec to the accumulation data type (float by default).
                    Logit_value_fma<Tk, V_vec_accum, V_vec_m, INT8_KV_CACHE, FP8_KV_CACHE>(
                        out, reinterpret_cast<Tk*>(logits_smem + local_time_idx), v_vec, v_scale, false);
                }
            }
        }
    }

    // Make sure we can overwrite the v cache if using cyclic kv cache.
    __syncthreads();

    // Get the c_tile_id that handles the current timestep.
    const int ctile_idx = tlength / timesteps_per_block;

    // One group of threads computes the product(s) for the current timestep.
    if (vo == kv_loop_length % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx))) {
        const int tokenIdx   = cyclic_tlength;
        const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(tokenIdx, hi_kv, Dh, vi);
        // The base pointer for the value in the cache buffer.
        Tcache* v_cache_base = reinterpret_cast<Tcache*>(kvCacheBuffer.getVBlockPtr(batch_beam_idx, cyclic_tlength));

        V_vec_k v;
        if (DO_CROSS_ATTENTION) {
            v = vec_conversion<V_vec_k, V_vec_k>(*reinterpret_cast<const V_vec_k*>(&v_cache_base[inBlockIdx]));
        }
        else {
            // Trigger the loads from the V buffer.
            // The stride between tokens. We may be able to always use params.stride.
            uint32_t v_stride = params.stride ? static_cast<uint32_t>(params.stride) : (num_heads_kv * Dh);
            // The offset.
            const auto v_offset = flat_index_strided3(batch_beam_idx, hi_kv, vi, v_stride, Dh);

            if (load_qkv_quant) {
                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<V_vec_k>::value>::type;
                const auto v_scaling = params.qkv_scale_quant_orig[2];
                const auto v_quant =
                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);

                convert_from_float(&v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
            }
            else {
                v = *reinterpret_cast<const V_vec_k*>(&params.v[v_offset]);
            }
        }

        if (HANDLE_KV) {
            // Compute the V values with bias.
            v = add(v, v_bias);

            if (do_ia3) {
                v = mul<V_vec_k, V_vec_k, V_vec_k>(
                    v, *reinterpret_cast<const V_vec_k*>(&params.ia3_value_weights[flat_index2(ia3_ti_hi, vi, Dh)]));
            }
        }

        // Store the values with bias back to global memory in the cache for V.
        //*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
        // For MQA/GQA mode, write only with the first Q head of each group per KV head.
        if (hi == (hi_kv * qhead_per_kv)) {
            if constexpr (ENABLE_8BITS_CACHE) {
                float v_max = vector_abs_max(v);
                float scale = 1.0f;
                if constexpr (std::is_same<Tcache, int8_t>::value) {
#pragma unroll
                    for (int mask = THREADS_PER_VALUE / 2; mask >= 1; mask /= 2) {
                        v_max = fmaxf(
                            v_max,
                            __shfl_xor_sync(
                                shfl_mask_and_index(THREADS_PER_VALUE, vo % (32 / THREADS_PER_VALUE)), v_max, mask));
                    }
                    scale = float(1 << (8 - 1)) / v_max;
                }
                store_8bits_kv_cache_vec(v_cache_base, v, inBlockIdx, scale);
                if (vi == 0) {
                    float* v_scale_ptr = reinterpret_cast<float*>(kvCacheBuffer.getVScalePtr(batch_beam_idx, tokenIdx));
                    int    inScaleIdx  = kvCacheBuffer.getKVScaleLocalIdx(tokenIdx, hi_kv);
                    *reinterpret_cast<float*>(&v_scale_ptr[inScaleIdx]) = 1 / scale;
                }
            }

            else {
                *reinterpret_cast<V_vec_k*>(&v_cache_base[inBlockIdx]) = v;
            }
        }

        // Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACCUM_FOR_LOGITS)
        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
        if (!MULTI_BLOCK_FLAG) {
            out = fma(logits_smem[kv_loop_length], cast_to_float(v), out);
        }
        else {
            out = fma(logits_current_smem[0], cast_to_float(v), out);
        }
#else   // MMHA_USE_FP32_ACCUM_FOR_LOGITS
        // out = fma(logits_smem[params.timestep], v, out);
        if (!MULTI_BLOCK_FLAG) {
            out = fma(logits_smem[kv_loop_length], v, out);
        }
        else {  // MULTI_BLOCK_FLAG // Not supported yet: multi-block mode with FP8_MHA
            out = fma(logits_current_smem[0], v, out);
        }
#endif  // MMHA_USE_FP32_ACCUM_FOR_LOGITS
    }
    // Make sure we can start writing to shared memory.
    __syncthreads();

    // Run the final reduction amongst the different groups computing different partial outputs.
#pragma unroll
    for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {

        // The midpoint in the number of active groups.
        int midpoint = active_groups / 2;

        // The upper part of active threads store to shared memory.
        if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
            convert_from_float(reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
            *reinterpret_cast<V_vec_k*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
        }
        __syncthreads();

        // The bottom warps update their values.
        if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
            out = add(*reinterpret_cast<const V_vec_k*>(&out_smem[vo * Dh + vi]), out);
        }
        __syncthreads();
    }
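    // Illustrative note: with e.g. V_PER_ITER = 16 groups, the loop above halves the number of active groups
    // 16 -> 8 -> 4 -> 2; in each step the upper half writes its partial sums to out_smem and the lower half adds
    // them in, so after the loop the threads with vo == 0 hold the fully reduced output vector.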
#ifdef ENABLE_FP8
    // Quantized output only supports fp8 currently, which should be used together with FP8 Context FMHA.
    using Quantized_t = __nv_fp8_e4m3;
    using Quantized_vec = typename packed_type<__nv_fp8_e4m3, num_elems<V_vec_accum>::value>::type;
#endif
    const auto bhi              = flat_index2(batch_beam_idx, hi, num_heads);
    const auto bhi_seq_len_tile = bhi * params.seq_len_tile;
    // Output the final values.
    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
        const auto bhvi = flat_index2(bhi, vi, Dh);
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
        if (!MULTI_BLOCK_FLAG)
        {
            if (write_attention_quant)
            {
                out = mul<V_vec_accum, float>(*params.attention_out_scale_orig_quant, out);
#ifdef ENABLE_FP8
                Quantized_vec final_out;
                convert_to_fp8(&final_out, out);
                *reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhvi) = final_out;
#endif
            }
            else
            {
                // This makes sure we have coalesced memory access.
                V_vec_k final_out;
                convert_from_float(&final_out, out);
                *reinterpret_cast<V_vec_k*>(&params.out[bhvi]) = final_out;
            }
        }
        else {
            // Offset for writing this tile's partial output to partial_out.
            int partial_out_offset = c_tile * params.batch_size * num_heads * params.hidden_size_per_head;
            // Offset for writing this tile's partial statistics to partial_max and partial_sum.
            int partial_stats_offset = bhi_seq_len_tile + c_tile;

            // This makes sure we have coalesced memory access.
            V_vec_k partial_out;
            convert_from_float(&partial_out, out);
            *reinterpret_cast<V_vec_k*>(&params.partial_out[partial_out_offset + bhvi]) = partial_out;
            convert_from_float(reinterpret_cast<float*>(&params.partial_max[partial_stats_offset]), qk_max);
            convert_from_float(reinterpret_cast<float*>(&params.partial_sum[partial_stats_offset]), sum);
        }

#else   // MMHA_USE_FP32_ACCUM_FOR_OUT
        *reinterpret_cast<V_vec_accum*>(&params.out[bhvi]) = out;
#endif  // MMHA_USE_FP32_ACCUM_FOR_OUT
    }

#ifdef ENABLE_MULTI_BLOCK_OPTION
    if (MULTI_BLOCK_FLAG) {
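        // Every threadblock increments the counter once after writing its partial results; the block whose
        // fetch_add returns gridDim.z - 1 (i.e. the last of the gridDim.z blocks) takes over the final
        // reduction. The acq_rel ordering makes the other blocks' partial writes visible to it.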

        cuda::atomic_ref<int, cuda::thread_scope_device> count_ref{params.block_counter[bhi]};
        bool                                             last_block{false};
        if (tidx == 0) {
            if (count_ref.fetch_add(1, cuda::memory_order_acq_rel) == (gridDim.z - 1)) {
                last_block = true;
            }
        }

        ////////////////////
        // Make sure every threadblock has finished writing its partial results. Only the last threadblock
        // (for each batch/head pair) enters the section below and performs the final reduction, combining
        // all the partial max/sum values and partial outputs.
        ////////////////////
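        // The combination follows the usual split-softmax identity: with per-tile max m_i, exp-sum s_i and
        // (un-normalized) partial output o_i, the global result is
        //   m = max_i m_i,  s = sum_i exp(m_i - m) * s_i,  out = (1 / s) * sum_i exp(m_i - m) * o_i,
        // which is what the code below computes.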
        if (__syncthreads_or(last_block)) {

            ////////////////////
            // Find the global max from all partial max -> use CUB BlockReduce
            ////////////////////

            float final_max          = -FLT_MAX;
            float thread_partial_max = params.partial_max[bhi_seq_len_tile + min(tidx, gridDim.z - 1)];

            // Make sure we can start writing to shared memory.
            __syncthreads();

            // Specialize BlockReduce for a 1D block of THREADS_PER_BLOCK threads of type float
            typedef cub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
            // Allocate shared memory for BlockReduce
            __shared__ typename BlockReduce::TempStorage temp_storage;
            // Compute the block-wide max over the first gridDim.z partial max values (gridDim.z is passed as
            // the number of valid inputs); the result is only guaranteed to be valid in thread 0.
            final_max = BlockReduce(temp_storage).Reduce(thread_partial_max, cub::Max(), gridDim.z);

            __shared__ float final_max_smem;
            if (tidx == 0) {
                final_max_smem = final_max;
            }
            __syncthreads();

            // Finish the final_max computation
            final_max = final_max_smem;

            ////////////////////
            // Reduction for global sum over all partial sum (scaled by the exponential term from global max) -> use
            // gridDim.z threads
            ////////////////////

            float final_sum = 0.f;
            if (tidx < gridDim.z) {
                thread_partial_max            = params.partial_max[bhi_seq_len_tile + tidx];
                const auto thread_partial_sum = params.partial_sum[bhi_seq_len_tile + tidx];
                final_sum += __expf(thread_partial_max - final_max) * thread_partial_sum;
            }

            // Compute the final_sum.
            final_sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], final_sum);

            ////////////////////
            // Reduction for final output (scaled by the exponential term from global max) -> use THREADS_PER_VALUE
            // * gridDim.z threads
            ////////////////////

            // Shared memory to store the partial outputs for each group/oi. -> size: gridDim.z * Dh * sizeof(T) bytes. Reuse qk_smem.
            T* out_oi_smem = reinterpret_cast<T*>(smem_);

            const auto o_idx = chunk_index<T, V_vec_k, THREADS_PER_VALUE>(tidx);
            // The partial output region this thread takes care of
            const auto oo = o_idx.x;
            // The hidden dimensions computed by this particular thread. (refer to vi)
            const auto oi = o_idx.y;

            // Within the bound.
            const bool within_bound = oo < gridDim.z;
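            // Each group of THREADS_PER_VALUE threads (indexed by oo) handles the partial output of one tile;
            // groups with oo >= gridDim.z have no tile and contribute zeros below.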

            // Offset of this thread's partial output.
            int thread_partial_out_offset = oo * params.batch_size * num_heads * params.hidden_size_per_head;
            // Load the partial max (different from thread_partial_max above, since the thread indexing rule changes here).
            float thread_partial_max_for_out = within_bound ? params.partial_max[bhi_seq_len_tile + oo] : final_max;

            // Load the partial outputs.
            V_vec_k zero_k;
            zero(zero_k);
            V_vec_k thread_partial_out =
                within_bound ?
                    *reinterpret_cast<const V_vec_k*>(&params.partial_out[thread_partial_out_offset + bhi * Dh + oi]) :
                    zero_k;

            Tk factor_compute;
            convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max));
            thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(factor_compute, thread_partial_out);

            // Make sure we can start writing to shared memory.
            __syncthreads();

            // The tree reduction must start from a power-of-two number of active groups.
            const auto reduction_iteration = static_cast<int>(cuda::std::bit_ceil(gridDim.z));
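            // bit_ceil pads the group count up to the next power of two so the halving below covers all
            // gridDim.z tiles; the padded groups start from zero_k and therefore do not affect the sum.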

            // Run the final reduction amongst the different groups computing different partial outputs.
#pragma unroll
            for (int active_groups = reduction_iteration; active_groups >= 2; active_groups /= 2) {

                // The midpoint in the number of active groups.
                int midpoint = active_groups / 2;

                // The upper half of the active groups stores its partial outputs to shared memory.
                if (oo >= midpoint && oo < active_groups && (Dh == Dh_MAX || oi < Dh)) {
                    *reinterpret_cast<V_vec_k*>(&out_oi_smem[(oo - midpoint) * Dh + oi]) = thread_partial_out;
                }
                __syncthreads();

                // The lower half of the active groups accumulates the values stored by the upper half.
                if (oo < midpoint && (Dh == Dh_MAX || oi < Dh)) {
                    thread_partial_out =
                        add(thread_partial_out, *reinterpret_cast<const V_vec_k*>(&out_oi_smem[oo * Dh + oi]));
                }
                __syncthreads();
            }

            ////////////////////
            // Final output O * inv_sum
            ////////////////////

            if (oo == 0 && (Dh == Dh_MAX || oi < Dh)) {
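                // Normalize by the global sum; the small epsilon guards against division by zero.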
                const auto inv_sum = __fdividef(1.f, final_sum + 1.e-6f);

                Tk inv_sum_compute;
                convert_from_float(&inv_sum_compute, inv_sum);

                thread_partial_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_partial_out);
                if (write_attention_quant) {
                    thread_partial_out = mul<V_vec_k, float>(*params.attention_out_scale_orig_quant, thread_partial_out);
#ifdef ENABLE_FP8
                    Quantized_vec final_out;
                    convert_to_fp8(&final_out, thread_partial_out);
                    *reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhi * Dh + oi)
                        = final_out;
#endif
                }
                else {
                    *reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_partial_out;
                }
            }

            // Reset the block counter for the next timestep.
            if (tidx == 0) {
                params.block_counter[bhi] = 0;
            }
        }
    }
#endif  // ENABLE_MULTI_BLOCK_OPTION
}