AttentionBlockOutputs GptModel::forwardAttentionBlock()

in maga_transformer/cpp/models/GptModel.cc [849:1013]


AttentionBlockOutputs GptModel::forwardAttentionBlock(
        const GptLayerInputs& inputs,
        const int32_t layer_id,
        rtp_llm::lora::LoraModelInputPtr lora_model_input,
        const LastLayerDeferedParams& last_layer_defered_params)
{
    auto hidden = inputs.hidden;
    auto pre_decoder_residual = inputs.pre_decoder_residual;
    auto attention_common_inputs = move(inputs.attention_common_inputs);
    DevicePerfWrapper wrapper(device_, "attention_block_layer_" + std::to_string(layer_id) + "_bs_" + std::to_string(hidden->shape()[0]));

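    // if the previous layer deferred its combine output, gather it now to form this layer's input hidden states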
    if (last_layer_defered_params.combine_output) {
        printBufferData(*(last_layer_defered_params.combine_output.value().all_output), "layer_" + to_string(layer_id - 1) + "_combine_output_defered");
        hidden = device_->gatherCombineOutput(last_layer_defered_params.combine_output.value()).hidden_states;
    }

    attention_common_inputs.layer_id = layer_id;
    const auto& layer = weights_.layers[layer_id];
    bool enable_sp = inputs.enable_sp;

    // here hidden->dtype may be int8, so use the dtype of the embedding lookup result instead
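    // with sequence parallelism, the padded tokens are sharded evenly across the tp_size ranks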
    size_t rank_pad_token_num = enable_sp ? inputs.pad_token_num / device_props_.tp_size : hidden->shape()[0];
    BufferPtr attn_out_buf = device_->allocateBuffer({inputs.dtype, {rank_pad_token_num, hidden->shape()[1]}}, {"attn_out_buf"});
    if (!enable_sp) {
        // Note: for custom all reduce, prepareAllReduce will replace the original attn_out_buf with
        // a new custom_ar_comm buffer. We must make sure that attn_out_buf is not released or replaced by
        // another buffer before the actual allreduce operation; otherwise custom all reduce will raise an error.
        attn_out_buf = device_->prepareAllReduce({std::move(attn_out_buf), ReduceOp::Sum}).buffer;
    }
    auto residual = pre_decoder_residual ? pre_decoder_residual : hidden;
    printBufferData(*residual, "in residual");
    BufferPtr residual2 = nullptr;
    BufferPtr hidden_to_slice = nullptr; // for sp and overlap comm type 2
    if (layer.pre_layernorm) {
        // TODO(wangyin.yx): fuse this clone branch into layernorm(rmsnorm)
        residual = last_layer_defered_params.residual ? device_->allocateBufferLike(*hidden, AllocationType::DEVICE, {"residual"})
                                                      : device_->clone({*hidden, AllocationType::DEVICE, {"residual"}});
        int m_split = device_props_.m_split;
        size_t overlap_comm_type = device_props_.overlap_comm_type;
        auto pre_layernorm_output = device_->layernorm(LayernormParams(hidden,
                                                                        residual,
                                                                        *layer.pre_layernorm,
                                                                        rtp_llm::mayGetRef(last_layer_defered_params.residual),
                                                                        rtp_llm::mayGetRef(last_layer_defered_params.shared_expert_output),
                                                                        std::nullopt,
                                                                        0.f,
                                                                        description_.layernorm_eps,
                                                                        !(enable_sp && overlap_comm_type == 2),
                                                                        false,
                                                                        description_.norm_type,
                                                                        description_.act_qscheme,
                                                                        layer_id > 0,
                                                                        false));

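        // at the first layer with sequence parallelism, narrow the residual from the full padded token range to this rank's shard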
        if (enable_sp && layer_id == 0) {
            if (overlap_comm_type == 1 && m_split > 0) {
                vector<int> selected_indices;
                selected_indices.reserve(rank_pad_token_num);
                size_t m = inputs.pad_token_num;
                size_t m_chunk = m / m_split;
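                // for more than 128 total tokens, align the per-split chunk size up to a multiple of 128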
                if (m > 128) {
                    m_chunk = (m / m_split + 127) & ~127;
                }
                size_t tp_rank = device_props_.tp_rank;
                size_t round = m_chunk / device_props_.tp_size;

                size_t offset = tp_rank * round;
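                // from each m_chunk-sized chunk, this rank selects the contiguous round-sized slice starting at offset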
                for (size_t i = 0; i < rank_pad_token_num; i++) {
                    selected_indices.push_back((i / round) * m_chunk + i % round + offset);
                }
                // printBufferData(*vector2Buffer(selected_indices), "selected_indices");
                residual = device_->select({*residual, *device_->clone({*vector2Buffer(selected_indices)})});
            } else {
                hidden_to_slice = residual;
                residual = residual->slice(rank_pad_token_num * device_props_.tp_rank, rank_pad_token_num);
            }
        }

        hidden = std::move(pre_layernorm_output.output);
    } else if (last_layer_defered_params.residual || last_layer_defered_params.shared_expert_output) {
        // NOTE(wangyin): this branch is not used for now, might be erroneous
        residual = device_->clone({*hidden, AllocationType::DEVICE, {"residual"}});
        auto prev_ffn_layernorm_output = device_->layernorm({
            hidden,
            nullptr,
            std::nullopt, // post_ffn_layernorm_weights
            rtp_llm::mayGetRef(last_layer_defered_params.residual),
            rtp_llm::mayGetRef(last_layer_defered_params.shared_expert_output),
        });
        hidden = std::move(prev_ffn_layernorm_output.output);
    }

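    // bind this layer's KV cache buffers (and per-layer quantization scales, if present) for the attention kernels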
    if (k_cache_buffer_ && attention_common_inputs.kv_cache) {
        attention_common_inputs.kv_cache->k_cache_buffer = k_cache_buffer_->index(layer_id);
        attention_common_inputs.kv_cache->v_cache_buffer = v_cache_buffer_->index(layer_id);
        if (k_scale_buffer_) {
            attention_common_inputs.kv_cache->k_scale_buffer = k_scale_buffer_->index(layer_id);
            attention_common_inputs.kv_cache->v_scale_buffer = v_scale_buffer_->index(layer_id);
        }
    }
    if (lora_model_input) {
        attention_common_inputs.lora_input = lora_model_input->getAttentionLayerLoraInput(layer_id);
    }
    AttentionLayerOutput attn_output;
    auto attn_params = AttentionLayerParams({
            layer_id,
            *hidden,
            move(attn_out_buf),
            description_.attention_conf,
            layer.self_attention_weights,
            attention_common_inputs,
            device_props_.attn_fuse_add_residual ? (OptionalConstBufferRef)*residual : nullopt,
            {description_.layernorm_eps, description_.norm_type},
            description_.act_qscheme,
            enable_sp,
            inputs.pad_token_num
    });
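    // use the dedicated MLA path when MLA is configured and not falling back to MHA ops; otherwise run the generic attention layer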
    if (description_.attention_conf.use_mla && device_->mla_ops_type != rtp_llm::MlaOpsType::MHA) {
        attn_output = device_->mlaAttentionLayer(attn_params);
    } else {
        attn_output = device_->attentionLayer(attn_params);
    }

    auto attn_hidden = std::move(attn_output.hidden_states);
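    // with tensor parallelism and sequence parallelism disabled, sum the partial attention outputs across ranks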
    if (device_props_.tp_size > 1 && !enable_sp) {
        // Note: for custom all reduce, allReduce will allocate a new buffer and replace the original attn_hidden with it
        auto wrapper = DevicePerfWrapper(device_, "allReduce, sizeBytes=%ld", (long)attn_hidden->sizeBytes());
        attn_hidden = device_->allReduce({std::move(attn_hidden), ReduceOp::Sum}).buffer;
    }
    if (residual_scale_) {
        attn_hidden = device_->multiply({*residual_scale_, *attn_hidden});
    }
    printBufferData(*attn_hidden, "layer_" + to_string(layer_id) + "_attn_output");

    if (layer.post_layernorm) {
        // attn_hidden = attn_hidden + residual
        // hidden = layernorm(attn_hidden)
        printBufferData(*residual, "before post layernorm residual");
        auto post_layernorm_params = LayernormParams(attn_hidden,
                                                        attn_hidden,
                                                        rtp_llm::mayGetRef(layer.post_layernorm),
                                                        device_props_.attn_fuse_add_residual ? nullopt : (OptionalConstBufferRef)*residual,
                                                        nullopt,
                                                        rtp_llm::mayGetRef(layer.self_attention_weights.output_weight->bias),
                                                        0.f,
                                                        description_.layernorm_eps,
                                                        false,
                                                        description_.post_layernorm,
                                                        description_.norm_type,
                                                        description_.act_qscheme,
                                                        false,
                                                        true);

        auto post_layernorm_output = device_->layernorm(post_layernorm_params);
        hidden = std::move(post_layernorm_output.output);
        attn_hidden = std::move(post_layernorm_output.before_norm_output);
        residual = attn_hidden;
        printBufferData(*residual, "after post layernorm residual");
    } else {
        residual2 = attn_hidden;
    }

    printBufferData(*hidden, "layer_" + to_string(layer_id) + "_ffn_input");
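    // hidden is the FFN input; residual and residual2 carry the residual streams for the rest of the layer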
    return {hidden, residual, residual2};
}
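
Below is a minimal standalone sketch of the overlap_comm_type == 1 index mapping used above. The helper name, the main() driver, and the toy configuration are illustrative assumptions, not part of GptModel.cc; the arithmetic mirrors the selected_indices loop: each TP rank takes a contiguous round-sized slice, at offset tp_rank * round, from every m_chunk-sized chunk of the padded token dimension.

#include <cstddef>
#include <iostream>
#include <vector>

using std::size_t;

// Hypothetical helper mirroring the selected_indices computation in
// forwardAttentionBlock() for overlap_comm_type == 1 (illustrative only).
std::vector<int> computeSelectedIndices(size_t m, size_t m_split, size_t tp_size, size_t tp_rank) {
    size_t m_chunk = m / m_split;
    if (m > 128) {
        m_chunk = (m / m_split + 127) & ~static_cast<size_t>(127);  // align up to a multiple of 128
    }
    const size_t round  = m_chunk / tp_size;  // tokens this rank takes from each chunk
    const size_t offset = tp_rank * round;    // this rank's start position inside each chunk
    const size_t rank_pad_token_num = m / tp_size;

    std::vector<int> indices;
    indices.reserve(rank_pad_token_num);
    for (size_t i = 0; i < rank_pad_token_num; i++) {
        indices.push_back(static_cast<int>((i / round) * m_chunk + i % round + offset));
    }
    return indices;
}

int main() {
    // Toy configuration: 512 padded tokens, m_split = 2, tp_size = 2, looking at rank 1.
    // Rank 1 ends up with the token ranges [128, 256) and [384, 512).
    const auto idx = computeSelectedIndices(512, 2, 2, 1);
    std::cout << idx.front() << "-" << idx[127] << ", " << idx[128] << "-" << idx.back() << std::endl;
    return 0;
}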